Compiling and Deploying Pretrained HuggingFace Pipelines distilBERT with Tensorflow2 Neuron#
Introduction#
In this tutorial you will compile and deploy a distilBERT version of the HuggingFace 🤗 Transformers BERT model for Inferentia. The full list of HuggingFace’s pretrained BERT models can be found in the BERT section of this page: https://huggingface.co/transformers/pretrained_models.html. You can also read about HuggingFace’s pipeline feature here: https://huggingface.co/transformers/main_classes/pipelines.html
This Jupyter notebook should be run on an inf1.6xlarge instance or larger. In a real-life scenario, however, the compilation should be done on a compute instance and the deployment on an inf1 instance to save costs.
Setting up your environment:#
To run this tutorial, please make sure you deactivate any existing TensorFlow conda environments you are already using. Install TensorFlow 2.x by following the instructions at the TensorFlow Tutorial Setup Guide.
After following the Setup Guide, you need to change your kernel to Python (Neuron TensorFlow 2)
by clicking Kernel -> Change Kernel -> Python (Neuron TensorFlow 2)
Now you can install the TensorFlow Neuron 2.x, HuggingFace transformers, and HuggingFace datasets dependencies.
[ ]:
!pip install --upgrade "transformers==4.1.0"
!pip install ipywidgets
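TensorFlow Neuron 2.x itself should already be available in this kernel after following the Setup Guide. If it is not, it can typically be installed from the Neuron pip repository; the command below is only a typical form, and the exact package extras and version pins for your release are listed in the Setup Guide.
[ ]:
!pip install tensorflow-neuron --extra-index-url=https://pip.repos.neuron.amazonaws.com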
[ ]:
from transformers import pipeline
import tensorflow as tf
import tensorflow.neuron as tfn
Compile the model into an AWS Neuron Optimized Model#
[ ]:
#Create the huggingface pipeline for sentiment analysis
#this model tries to determine whether the input text has a positive
#or a negative sentiment.
model_name = 'distilbert-base-uncased-finetuned-sst-2-english'
pipe = pipeline('sentiment-analysis', model=model_name, framework='tf')
#pipelines are extremely easy to use as they do all the tokenization,
#inference and output interpretation for you.
pipe(['I love pipelines, they are very easy to use!', 'this string makes it batch size two'])
As you’ve seen above, Huggingface’s pipeline feature is a great wrapper for running inference on their models. It takes care of tokenizing the string inputs, feeds the tokenized input to the model, and finally interprets the model’s outputs and formats them in a human-readable way. Our goal will be to compile the underlying model inside the pipeline as well as make some edits to the tokenizer. The reason you need to edit the tokenizer is to make sure that you have a standard sequence length (in this case 128), as neuron only accepts static input shapes.
[ ]:
neuron_pipe = pipeline('sentiment-analysis', model=model_name, framework='tf')
#the first step is to modify the underlying tokenizer to create a static
#input shape as inferentia does not work with dynamic input shapes
original_tokenizer = pipe.tokenizer
#you intercept the function call to the original tokenizer
#and inject your own code to modify the arguments
def wrapper_function(*args, **kwargs):
    kwargs['padding'] = 'max_length'
    #this is the key line here to set a static input shape
    #so that all inputs are padded to a length of 128
    kwargs['max_length'] = 128
    kwargs['truncation'] = True
    kwargs['return_tensors'] = 'tf'
    return original_tokenizer(*args, **kwargs)

#insert the wrapper function as the new tokenizer and reattach
#some attribute information that was lost when the original
#tokenizer was replaced by the wrapper function
neuron_pipe.tokenizer = wrapper_function
neuron_pipe.tokenizer.decode = original_tokenizer.decode
neuron_pipe.tokenizer.mask_token_id = original_tokenizer.mask_token_id
neuron_pipe.tokenizer.pad_token_id = original_tokenizer.pad_token_id
neuron_pipe.tokenizer.convert_ids_to_tokens = original_tokenizer.convert_ids_to_tokens
#Our example data!
string_inputs = [
'I love to eat pizza!',
'I am sorry. I really want to like it, but I just can not stand sushi.',
'I really do not want to type out 128 strings to create batch 128 data.',
'Ah! Multiplying this list by 32 would be a great solution!',
]
string_inputs = string_inputs * 32
example_inputs = neuron_pipe.tokenizer(string_inputs)
#compile the model by calling tfn.trace, passing in the underlying model
#and the example inputs generated by our updated tokenizer
neuron_model = tfn.trace(pipe.model, example_inputs)
#now you can insert the neuron_model in place of the cpu model,
#so you have a huggingface pipeline that uses an underlying neuron model!
neuron_pipe.model = neuron_model
neuron_pipe.model.config = pipe.model.config
Why use batch size 128?#
You’ll notice that in the above example we passed two tensors of shape 128 (the batch size) x 128 (the sequence length) in the call tfn.trace(pipe.model, example_inputs). The example_inputs argument is important to tfn.trace because it tells the neuron model what to expect (remember that a neuron model needs static input shapes, so example_inputs defines that static input shape). A smaller batch size would also compile, but a large batch size helps ensure that the neuron hardware is fed enough data to be as performant as possible.
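If you want to double check the static shape that the model was traced with, you can inspect the tensors produced by the modified tokenizer:
[ ]:
#optional check: both tensors should have shape (128, 128),
#i.e. (batch size, sequence length)
print(example_inputs['input_ids'].shape)
print(example_inputs['attention_mask'].shape)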
What if my model isn’t a Huggingface pipeline?#
Not to worry! There is no requirement that your model be Huggingface pipeline compatible. The Huggingface pipeline is just a wrapper for an underlying TensorFlow model (in our case pipe.model). As long as you have a TensorFlow 2.x model, you can compile it for neuron by calling tfn.trace(your_model, example_inputs). Processing the input and output of your own model is up to you! Take a look at the example below to see what happens when we call the model without the Huggingface pipeline wrapper as opposed to with it.
[ ]:
#directly call the model
print(neuron_model(example_inputs))
#with the model inserted to the wrapper
print(neuron_pipe(string_inputs))
#Look at the difference between string_inputs
#and example_inputs
print(example_inputs)
print(string_inputs)
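To emphasize that tfn.trace is not tied to Huggingface at all, here is a minimal sketch that compiles a plain tf.keras model instead. The toy model and its example input are made up purely for illustration; any TensorFlow 2.x model you already have can be compiled the same way.
[ ]:
#a toy tf.keras model, used only to illustrate the tfn.trace call
toy_model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(16,)),
    tf.keras.layers.Dense(2),
])
#an example input with a static shape, as neuron requires
toy_example_input = tf.random.uniform([8, 16])
#compile the toy model for neuron and run one inference
toy_neuron_model = tfn.trace(toy_model, toy_example_input)
print(toy_neuron_model(toy_example_input))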
Save your neuron model to disk and avoid recompilation.#
To avoid recompiling the model before every deployment, you can save the neuron model by calling neuron_model.save(model_dir). This save method prefers to work on flat input/output lists and does not work on dictionary input/output, which is what the Huggingface distilBERT model expects as input. You can work around this by writing a simple wrapper that takes in a list instead of a dictionary, compiling the wrapped model, and saving it for later use.
[ ]:
class TFBertForSequenceClassificationFlatIO(tf.keras.Model):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def call(self, inputs):
        input_ids, attention_mask = inputs
        output = self.model({'input_ids': input_ids, 'attention_mask': attention_mask})
        return output['logits']
#wrap the original model from HuggingFace, now our model accepts a list as input
model_wrapped = TFBertForSequenceClassificationFlatIO(pipe.model)
#turn the dictionary input into list input
example_inputs_list = [example_inputs['input_ids'], example_inputs['attention_mask']]
#compile the wrapped model and save it to disk
model_wrapped_traced = tfn.trace(model_wrapped, example_inputs_list)
model_wrapped_traced.save('./distilbert_b128')
[ ]:
!ls #you should now be able to see the model saved as the folder distilbert_b128
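If you are curious about what was exported, the standard saved_model_cli tool that ships with TensorFlow can display the signatures of the saved model. This step is optional:
[ ]:
!saved_model_cli show --dir ./distilbert_b128 --all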
Load the model from disk#
Now you can reload the model by calling tf.keras.models.load_model(model_directory), where model_directory is the directory you saved the model to. This model is already compiled and can run inference on neuron, but if you want it to work with our Huggingface pipeline, you have to wrap it again so that it accepts dictionary input.
[ ]:
class TFBertForSequenceClassificationDictIO(tf.keras.Model):
    def __init__(self, model_wrapped):
        super().__init__()
        self.model_wrapped = model_wrapped
        self.aws_neuron_function = model_wrapped.aws_neuron_function

    def call(self, inputs):
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        logits = self.model_wrapped([input_ids, attention_mask])
        return [logits]
reloaded_model = tf.keras.models.load_model('./distilbert_b128')
rewrapped_model = TFBertForSequenceClassificationDictIO(model_wrapped_traced)
#now you can insert the rewrapped model back into our pipeline
neuron_pipe.model = rewrapped_model
neuron_pipe.model.config = pipe.model.config
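As an optional sanity check, you can run one batch through the reassembled pipeline and look at the first couple of results before benchmarking:
[ ]:
#quick check that the rewrapped model still produces pipeline-style output
print(neuron_pipe(string_inputs)[:2])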
Benchmarking the neuron model#
Now you can do some simple benchmarking of the neuron model. If you are running this tutorial on an inf1.6xlarge, as suggested, you should tell neuron to use all 16 Neuron Cores to get maximum throughput. By default, TensorFlow Neuron will use only one Inferentia chip, which has 4 Neuron Cores; an inf1.6xlarge has 4 Inferentia chips. To tell Neuron to run on all available cores, you can set the environment variable NEURONCORE_GROUP_SIZES
or launch multiple processes that query the same model.
To read more about this, refer to the Neuron Core groups section of our documentation. Run a warmup inference on the neuron model before benchmarking, as the first inference call also loads the model onto inferentia.
[ ]:
import warnings
warnings.warn("NEURONCORE_GROUP_SIZES is being deprecated, if your application is using NEURONCORE_GROUP_SIZES please \
see https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/deprecation.html#announcing-end-of-support-for-neuroncore-group-sizes \
for more details.", DeprecationWarning)
%env NEURONCORE_GROUP_SIZES='16x1'
import time
#warmup inference
neuron_pipe(string_inputs)
#benchmark the batch 128 neuron model
neuron_b128_times = []
for i in range(1000):
    start = time.time()
    outputs = neuron_pipe(string_inputs)
    end = time.time()
    neuron_b128_times.append(end - start)
neuron_b128_times = sorted(neuron_b128_times)
print(f"Average throughput for batch 128 neuron model is {128/(sum(neuron_b128_times)/len(neuron_b128_times))} sentences/s.")
print(f"Peak throughput for batch 128 neuron model is {128/min(neuron_b128_times)} sentences/s.")
print()
print(f"50th percentile latency for batch 128 neuron model is {neuron_b128_times[int(1000*.5)] * 1000} ms.")
print(f"90th percentile latency for batch 128 neuron model is {neuron_b128_times[int(1000*.9)] * 1000} ms.")
print(f"95th percentile latency for bacth 128 neuron model is {neuron_b128_times[int(1000*.95)] * 1000} ms.")
print(f"99th percentile latency for batch 128 neuron model is {neuron_b128_times[int(1000*.99)] * 1000} ms.")
print()
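The loop above issues requests from a single Python thread. As mentioned earlier, another way to keep more Neuron Cores busy is to overlap requests, for example from multiple processes or threads. The sketch below is only illustrative: it assumes the pipeline can safely be called from a few concurrent Python threads, and actual scaling depends on your Neuron runtime version and core-group configuration.
[ ]:
from concurrent import futures

#illustrative only: issue several batches concurrently from a small thread pool
num_workers = 4
num_requests = 32
with futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
    start = time.time()
    results = list(executor.map(neuron_pipe, [string_inputs] * num_requests))
    elapsed = time.time() - start
print(f"Concurrent throughput estimate: {num_requests * 128 / elapsed} sentences/s.")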