Running Huggingface DistilBERT with TensorFlow-Neuron#

In this tutorial you will compile and deploy the DistilBERT version of the HuggingFace 🤗 Transformers BERT model for Inferentia using TensorFlow-Neuron. The full list of HuggingFace’s pretrained BERT models can be found in the BERT section of this page: https://huggingface.co/transformers/pretrained_models.html. You can also read about HuggingFace’s pipeline feature here: https://huggingface.co/transformers/main_classes/pipelines.html.
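
The pipeline feature referenced above offers a quick, CPU-based way to try the same sentiment-analysis checkpoint before any Neuron compilation. The cell below is a minimal sketch using the standard transformers pipeline API; it assumes the dependencies installed in the Setup section and is independent of the compilation flow that follows.

[ ]:
from transformers import pipeline

# Plain CPU inference through the high-level pipeline API (no Neuron involved);
# this is the same checkpoint that is compiled for Inferentia later in this tutorial.
classifier = pipeline("sentiment-analysis",
                      model="distilbert-base-uncased-finetuned-sst-2-english",
                      framework="tf")
print(classifier(["Paris is the capital of France."]))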

This Jupyter notebook should be run on an inf1.6xlarge instance or larger. In a real-life scenario, however, compilation should be done on a compute instance and deployment on an inf1 instance to save costs.
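
If you follow that split workflow, only the compilation step needs the larger instance; the saved artifact can then be copied to the inf1 instance and loaded without recompiling. A minimal sketch, assuming the ./distilbert-neuron-b16 directory produced later in this tutorial has already been copied to the deployment instance:

[ ]:
import tensorflow as tf

# On the inf1 deployment instance: load the precompiled SavedModel
# (created by model_neuron.save() in the compilation section of this tutorial).
neuron_model = tf.keras.models.load_model('./distilbert-neuron-b16')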

Setup#

To run this tutorial, please follow the instructions for TensorFlow-Neuron Setup and the Jupyter Notebook Quickstart, and set your kernel to “Python (tensorflow-neuron)”.

Next, install some additional dependencies.

[ ]:
# Suppress tokenizer warnings, making errors easier to detect
%env TOKENIZERS_PARALLELISM=True
!pip install transformers==4.30.2
!pip install ipywidgets

Download From Huggingface and Compile for AWS-Neuron#

[ ]:
import tensorflow as tf
import tensorflow_neuron as tfn
from transformers import DistilBertTokenizer, TFDistilBertModel

# Create a wrapper for the DistilBERT model that will accept inputs as a list
# instead of a dictionary. This will allow the compiled model to be saved
# to disk with the model.save() function.
class DistilBertWrapper(tf.keras.Model):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def __call__(self, example_inputs):
        return self.model({'input_ids' : example_inputs[0], 'attention_mask' : example_inputs[1]})


tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english')
model = DistilBertWrapper(TFDistilBertModel.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english'))

batch_size = 16

# create example inputs with a batch size of 16
text = ["Paris is the <mask> of France."] * batch_size
encoded_input = tokenizer(text, return_tensors='tf', padding='max_length', max_length=64)

# turn inputs into a list
example_input = [encoded_input['input_ids'], encoded_input['attention_mask']]

# compile the wrapped model for Inferentia
model_neuron = tfn.trace(model, example_input)

print("Running on neuron:", model_neuron(example_input))

# save the model to disk to avoid recompilation on the next run
model_neuron.save('./distilbert-neuron-b16')
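
Note that tfn.trace uses the example inputs to determine the tensor shapes the model is compiled for, which is why the benchmark below feeds the reloaded model inputs with the same batch size (16) and padded sequence length (64).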

Run Basic Inference Benchmarking#

[ ]:
import numpy as np
import concurrent.futures
import time

reloaded_neuron_model = tf.keras.models.load_model('./distilbert-neuron-b16')
print("Reloaded model running on neuron:", reloaded_neuron_model(example_input))

num_threads = 4
num_inferences = 1000

latency_list = []
def inference_with_latency_calculation(example_input):
    global latency_list
    start = time.time()
    result = reloaded_neuron_model(example_input)
    end = time.time()
    latency_list.append((end-start) * 1000)
    return result

start = time.time()
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
    futures = []
    for i in range(num_inferences):
        futures.append(executor.submit(inference_with_latency_calculation, example_input))
    for future in concurrent.futures.as_completed(futures):
        get_result = future.result()
end = time.time()

total_time = end - start
throughput = (num_inferences * batch_size)/total_time

print(f"Throughput was {throughput} samples per second.")
print(f"Latency p50 was {np.percentile(latency_list, 50)} ms")
print(f"Latency p90 was {np.percentile(latency_list, 90)} ms")
print(f"Latency p95 was {np.percentile(latency_list, 95)} ms")
print(f"Latency p99 was {np.percentile(latency_list, 99)} ms")
assert throughput >= 1930.0