Compiling and Deploying HuggingFace Pretrained BERT#

Introduction#

In this tutorial we will compile and deploy a BERT-base model from HuggingFace 🤗 Transformers for Inferentia. The full list of HuggingFace's pretrained BERT models can be found in the BERT section of this page: https://huggingface.co/transformers/pretrained_models.html.

This Jupyter notebook should be run on an inf1.6xlarge instance or larger. It is the compilation step, not inference itself, that requires an inf1.6xlarge. For simplicity we run the whole tutorial on an inf1.6xlarge, but in a real-life scenario the compilation should be done on a compute instance and the deployment on a smaller inf1 instance to save costs.

Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the PyTorch Installation Guide. You can select the kernel from the "Kernel -> Change Kernel" option at the top of this Jupyter notebook page.
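
As a quick sanity check (an optional, illustrative cell, not part of the original setup), you can confirm which interpreter the kernel uses and that the Neuron packages are visible to it:

[ ]:
# Optional sanity check: show the kernel's interpreter and the installed Neuron packages
import sys
print(sys.executable)
!{sys.executable} -m pip list | grep -E "torch-neuron|neuron-cc|transformers"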

Install Dependencies:#

This tutorial requires the following pip packages:

  • torch-neuron

  • neuron-cc[tensorflow]

  • transformers

Most of these packages will already be installed when configuring your environment using the Neuron PyTorch setup guide. The remaining dependencies are installed here.
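
If torch-neuron and neuron-cc[tensorflow] are not already present in this environment, they can be installed from the Neuron pip repository, as the setup guide does. The cell below is a hedged example of that install; the exact package versions depend on your Neuron SDK release.

[ ]:
# Example only: install the Neuron packages from the Neuron pip repository
# (normally already handled by the Neuron PyTorch setup guide)
!pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com
!pip install torch-neuron "neuron-cc[tensorflow]"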

[ ]:
# Suppress tokenizer warnings to make errors easier to detect
%env TOKENIZERS_PARALLELISM=True
!pip install --upgrade "transformers==4.6.0"

Compile the model into an AWS Neuron optimized TorchScript#

[ ]:
import tensorflow  # to work around a protobuf version conflict issue
import torch
import torch.neuron
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import transformers
import os
import warnings

# Setting up NeuronCore groups for inf1.6xlarge with 16 cores
num_cores = 16 # This value should be 4 on inf1.xlarge and inf1.2xlarge
os.environ['NEURON_RT_NUM_CORES'] = str(num_cores)

# Build tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc", return_dict=False)

# Setup some example inputs
sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"

max_length=128
paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt")

# Run the original PyTorch model on the compilation example
paraphrase_classification_logits = model(**paraphrase)[0]

# Convert example inputs to a format that is compatible with TorchScript tracing
example_inputs_paraphrase = paraphrase['input_ids'], paraphrase['attention_mask'], paraphrase['token_type_ids']
example_inputs_not_paraphrase = not_paraphrase['input_ids'], not_paraphrase['attention_mask'], not_paraphrase['token_type_ids']

# Run torch.neuron.trace to generate a TorchScript that is optimized by AWS Neuron
model_neuron = torch.neuron.trace(model, example_inputs_paraphrase)

# Verify the TorchScript works on both example inputs
paraphrase_classification_logits_neuron = model_neuron(*example_inputs_paraphrase)
not_paraphrase_classification_logits_neuron = model_neuron(*example_inputs_not_paraphrase)

# Save the TorchScript for later use
model_neuron.save('bert_neuron.pt')

You may inspect model_neuron.graph to see which parts run on CPU versus on the accelerator. All native aten operators in the graph will run on CPU.

[ ]:
print(model_neuron.graph)

Deploy the AWS Neuron optimized TorchScript#

To deploy the AWS Neuron optimized TorchScript, you may choose to load the saved TorchScript from disk and skip the slow compilation.

[ ]:
# Load TorchScript back
model_neuron = torch.jit.load('bert_neuron.pt')
# Verify the TorchScript works on both example inputs
paraphrase_classification_logits_neuron = model_neuron(*example_inputs_paraphrase)
not_paraphrase_classification_logits_neuron = model_neuron(*example_inputs_not_paraphrase)
classes = ['not paraphrase', 'paraphrase']
paraphrase_prediction = paraphrase_classification_logits_neuron[0][0].argmax().item()
not_paraphrase_prediction = not_paraphrase_classification_logits_neuron[0][0].argmax().item()
print('BERT says that "{}" and "{}" are {}'.format(sequence_0, sequence_2, classes[paraphrase_prediction]))
print('BERT says that "{}" and "{}" are {}'.format(sequence_0, sequence_1, classes[not_paraphrase_prediction]))

Now let's run the model in parallel on multiple cores. First, define some helper functions for padding batches to a fixed size and counting correct predictions:

[ ]:
def get_input_with_padding(batch, batch_size, max_length):
    ## Reformulate the batch into three batch tensors - default batch size batches the outer dimension
    encoded = batch['encoded']
    inputs = torch.squeeze(encoded['input_ids'], 1)
    attention = torch.squeeze(encoded['attention_mask'], 1)
    token_type = torch.squeeze(encoded['token_type_ids'], 1)
    quality = list(map(int, batch['quality']))

    if inputs.size()[0] != batch_size:
        print("Input size = {} - padding".format(inputs.size()))
        remainder = batch_size - inputs.size()[0]
        zeros = torch.zeros( [remainder, max_length], dtype=torch.long )
        inputs = torch.cat( [inputs, zeros] )
        attention = torch.cat( [attention, zeros] )
        token_type = torch.cat( [token_type, zeros] )

    assert(inputs.size()[0] == batch_size and inputs.size()[1] == max_length)
    assert(attention.size()[0] == batch_size and attention.size()[1] == max_length)
    assert(token_type.size()[0] == batch_size and token_type.size()[1] == max_length)

    return (inputs, attention, token_type), quality

def count(output, quality):
    assert output.size(0) >= len(quality)
    correct_count = 0
    count = len(quality)

    batch_predictions = [ row.argmax().item() for row in output ]

    for a, b in zip(batch_predictions, quality):
        if int(a)==int(b):
            correct_count += 1

    return correct_count, count

Data parallel inference#

In the cell below, we use a data parallel approach for inference: we load multiple copies of the model, each onto its own NeuronCore, and run them in parallel. In the implementation below, we launch 16 model copies, thereby utilizing all 16 NeuronCores on an inf1.6xlarge.

Note: if you now want to decrease num_cores from the value used in the cells above, please restart the notebook and run the !sudo rmmod neuron; sudo modprobe neuron step in cell 2 to clear the NeuronCores.

Since we can run more than one model concurrently, the throughput of the system goes up. To achieve the maximum gain in throughput, we need to feed the models efficiently so that they stay busy at all times. In the setup below, this is done with a producer-consumer model: we maintain a common Python queue shared across all the models, which allows data to be fed to them continuously.
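
The NeuronSimpleDataParallel class used in the next cell (provided with this tutorial in parallel.py) encapsulates this pattern. As a rough illustration only (not the actual parallel.py implementation, and not a cell to run as part of this tutorial), the producer-consumer idea looks like this: each worker thread loads its own copy of the compiled model and pulls work items from a shared queue.

# Illustrative sketch of the producer-consumer pattern; assumes 'bert_neuron.pt'
# from the compilation step above. Do not run alongside the 16-core cells below.
import queue
import threading
import torch

NUM_WORKERS = 4                  # one worker (and one model copy) per NeuronCore
work_queue = queue.Queue()

def worker():
    # With NEURON_RT_NUM_CORES set, each loaded copy of the compiled model is
    # placed on its own NeuronCore, so one load per worker gives one model per core.
    model = torch.jit.load('bert_neuron.pt')
    while True:
        item = work_queue.get()
        if item is None:         # sentinel value shuts the worker down
            break
        inputs, callback = item
        output = model(*inputs)
        if callback is not None:
            callback(output)

threads = [threading.Thread(target=worker) for _ in range(NUM_WORKERS)]
for t in threads:
    t.start()

# Producer side: put (inputs, callback) work items on the shared queue, e.g.
#     work_queue.put((example_inputs_paraphrase, print))
# then shut down by sending one sentinel per worker and joining:
#     for _ in range(NUM_WORKERS):
#         work_queue.put(None)
#     for t in threads:
#         t.join()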

[ ]:
from parallel import NeuronSimpleDataParallel
from bert_benchmark_utils import BertTestDataset, BertResults
import time
import functools

max_length = 128
num_cores = 16
batch_size = 1

tsv_file="glue_mrpc_dev.tsv"

data_set = BertTestDataset( tsv_file=tsv_file, tokenizer=tokenizer, max_length=max_length )
data_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size, shuffle=True)

#Result aggregation class (code in bert_benchmark_utils.py)
results = BertResults(batch_size, num_cores)
def result_handler(output, result_id, start, end, input_dict):
    correct_count, inference_count = count(output[0], input_dict.pop(result_id))
    elapsed = end - start
    results.add_result(correct_count, inference_count, [elapsed], [end], [start])

parallel_neuron_model = NeuronSimpleDataParallel('bert_neuron.pt', num_cores)

#Starting the inference threads
parallel_neuron_model.start_continuous_inference()

# Warm up the cores
z = torch.zeros( [batch_size, max_length], dtype=torch.long )
batch = (z, z, z)
for _ in range(num_cores*4):
    parallel_neuron_model.infer(batch, -1, None)

input_dict = {}
input_id = 0
for _ in range(30):
    for batch in data_loader:
        batch, quality = get_input_with_padding(batch, batch_size, max_length)
        input_dict[input_id] = quality
        callback_fn = functools.partial(result_handler, input_dict=input_dict)
        parallel_neuron_model.infer(batch, input_id, callback_fn)
        input_id+=1

# Stop inference
parallel_neuron_model.stop()


with open("benchmark.txt", "w") as f:
    results.report(f, window_size=1)

with open("benchmark.txt", "r") as f:
    for line in f:
        print(line)

Now recompile with a larger batch size of six sentence pairs

[ ]:
batch_size = 6

example_inputs_paraphrase = (
    torch.cat([paraphrase['input_ids']] * batch_size,0),
    torch.cat([paraphrase['attention_mask']] * batch_size,0),
    torch.cat([paraphrase['token_type_ids']] * batch_size,0)
)

# Run torch.neuron.trace to generate a TorchScript that is optimized by AWS Neuron
model_neuron_batch = torch.neuron.trace(model, example_inputs_paraphrase)

## Save the batched model
model_neuron_batch.save('bert_neuron_b{}.pt'.format(batch_size))

Rerun inference with batch size 6

[ ]:
from parallel import NeuronSimpleDataParallel
from bert_benchmark_utils import BertTestDataset, BertResults
import time
import functools

max_length = 128
num_cores = 16
batch_size = 6

data_set = BertTestDataset( tsv_file=tsv_file, tokenizer=tokenizer, max_length=max_length )
data_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size, shuffle=True)

#Result aggregation class (code in bert_benchmark_utils.py)
results = BertResults(batch_size, num_cores)
def result_handler(output, result_id, start, end, input_dict):
    correct_count, inference_count = count(output[0], input_dict.pop(result_id))
    elapsed = end - start
    results.add_result(correct_count, inference_count, [elapsed], [end], [start])

parallel_neuron_model = NeuronSimpleDataParallel('bert_neuron_b{}.pt'.format(batch_size), num_cores)

#Starting the inference threads
parallel_neuron_model.start_continuous_inference()

# Adding to the input queue to warm all cores
z = torch.zeros( [batch_size, max_length], dtype=torch.long )
batch = (z, z, z)
for _ in range(num_cores*4):
    parallel_neuron_model.infer(batch, -1, None)

input_dict = {}
input_id = 0
for _ in range(30):
    for batch in data_loader:
        batch, quality = get_input_with_padding(batch, batch_size, max_length)
        input_dict[input_id] = quality
        callback_fn = functools.partial(result_handler, input_dict=input_dict)
        parallel_neuron_model.infer(batch, input_id, callback_fn)
        input_id+=1

# Stop inference
parallel_neuron_model.stop()

with open("benchmark_b{}.txt".format(batch_size), "w") as f:
    results.report(f, window_size=1)

with open("benchmark_b{}.txt".format(batch_size), "r") as f:
    for line in f:
        print(line)