Running Huggingface Roberta-Base with TensorFlow-NeuronX#
This tutorial demonstrates how to compile the Huggingface roberta-base model and run inference on a trn1.2xlarge instance with tensorflow-neuronx. To compile larger models such as roberta-large, consider using an inf2 instance.
Setup#
To run this tutorial, please follow the instructions in the TensorFlow-NeuronX Setup guide and the Jupyter Notebook Quickstart, and set your kernel to “Python (tensorflow-neuronx)”.
Next, install some additional dependencies.
[ ]:
# Suppresses tokenizer warnings, making errors easier to detect
%env TOKENIZERS_PARALLELISM=True
!pip install transformers
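If the kernel is configured correctly, the Neuron packages should already be importable. As an optional sanity check (not part of the original setup steps), you can verify the environment before continuing:
[ ]:
# Importing tensorflow_neuronx confirms the Neuron SDK packages are installed in this kernel
import tensorflow as tf
import tensorflow_neuronx as tfnx
import transformers
print("tensorflow:", tf.__version__)
print("transformers:", transformers.__version__)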
Download From Huggingface and Compile for AWS-Neuron#
[ ]:
import tensorflow as tf
import tensorflow_neuronx as tfnx
from transformers import RobertaTokenizer, TFRobertaModel
from transformers import BertTokenizer, TFBertModel
# Create a wrapper for the roberta model that will accept inputs as a list
# instead of a dictionary. This allows the compiled model to be saved
# to disk with the model.save() function.
class RobertaWrapper(tf.keras.Model):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def __call__(self, example_inputs):
        return self.model({'input_ids': example_inputs[0], 'attention_mask': example_inputs[1]})
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaWrapper(TFRobertaModel.from_pretrained('roberta-base'))
batch_size = 16
# create example inputs with a batch size of 16
text = ["Paris is the <mask> of France."] * batch_size
encoded_input = tokenizer(text, return_tensors='tf', padding='max_length', max_length=64)
# turn inputs into a list
example_input = [encoded_input['input_ids'], encoded_input['attention_mask']]
# compile the model for Neuron
model_neuron = tfnx.trace(model, example_input)
print("Running on neuron:", model_neuron(example_input))
# save the model to disk to save recompilation time for next usage
model_neuron.save('./roberta-neuron-b16')
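Optionally, you can compare the compiled model's output against the original CPU model on the same inputs. The cell below is a minimal sketch that is not part of the original tutorial; it assumes the traced model preserves the 'last_hidden_state' entry of the Hugging Face output, and small numerical differences between CPU and Neuron results are expected.
[ ]:
import numpy as np
# Assumption: both the CPU and Neuron outputs expose a 'last_hidden_state' entry
cpu_hidden = np.asarray(model(example_input)['last_hidden_state'])
neuron_hidden = np.asarray(model_neuron(example_input)['last_hidden_state'])
print("Max absolute difference:", np.max(np.abs(cpu_hidden - neuron_hidden)))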
Run Basic Inference Benchmarking#
[ ]:
import numpy as np
import concurrent.futures
import time
reloaded_neuron_model = tf.keras.models.load_model('./roberta-neuron-b16')
print("Reloaded model running on neuron:", reloaded_neuron_model(example_input))
num_threads = 4
num_inferences = 1000
latency_list = []
def inference_with_latency_calculation(example_input):
    global latency_list
    start = time.time()
    result = reloaded_neuron_model(example_input)
    end = time.time()
    latency_list.append((end - start) * 1000)
    return result

start = time.time()
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
    futures = []
    for i in range(num_inferences):
        futures.append(executor.submit(inference_with_latency_calculation, example_input))
    for future in concurrent.futures.as_completed(futures):
        get_result = future.result()
end = time.time()
total_time = end - start
print(f"Throughput was {(num_inferences * batch_size)/total_time} samples per second.")
print(f"Latency p50 was {np.percentile(latency_list, 50)} ms")
print(f"Latency p90 was {np.percentile(latency_list, 90)} ms")
print(f"Latency p99 was {np.percentile(latency_list, 99)} ms")