Data Parallel HuggingFace Pretrained BERT with Weight Sharing (Deduplication)#

Introduction#

In this tutorial we will compile and deploy a BERT-base model from HuggingFace 🤗 Transformers for Inferentia, with an additional demonstration of the Weight Sharing (Deduplication) feature.

To use the Weight Sharing (Deduplication) feature, you must set the Neuron Runtime environment variable NEURON_RT_MULTI_INSTANCE_SHARED_WEIGHTS to “TRUE” together with the core placement API (torch_neuron.experimental.neuron_cores_context()).
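
Concretely, the two pieces fit together as in the short sketch below. This is only an illustrative snippet (it assumes a model already compiled and saved as bert_neuron.pt and 16 available NeuronCores); the full data parallel cell later in this tutorial performs the same steps.

import os
import torch
import torch.neuron as torch_neuron

# Enable the Weight Sharing (Deduplication) feature in the Neuron Runtime
# (must be set before the model is loaded)
os.environ['NEURON_RT_MULTI_INSTANCE_SHARED_WEIGHTS'] = 'TRUE'

# Load the model through the core placement API so it is placed onto
# a specific range of NeuronCores (here: cores 0-15)
with torch_neuron.experimental.neuron_cores_context(start_nc=0, nc_count=16):
    model = torch.jit.load('bert_neuron.pt')  # assumes the compiled model from this tutorial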

This Jupyter notebook should be run on an inf1.6xlarge instance or larger. Only the compilation part of this tutorial requires an inf1.6xlarge, not the inference itself. For simplicity we run this tutorial on an inf1.6xlarge, but in a real-life scenario the compilation should be done on a compute instance and the deployment on an inf1 instance to save costs.

Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the PyTorch Installation Guide. You can select the kernel from the “Kernel -> Change Kernel” option at the top of this Jupyter notebook page.

Install Dependencies:#

This tutorial requires the following pip packages:

  • torch-neuron

  • neuron-cc[tensorflow]

  • transformers

Most of these packages will be installed when configuring your environment using the Neuron PyTorch setup guide. The additional dependencies must be installed here.

[1]:
%env TOKENIZERS_PARALLELISM=True #Suppresses tokenizer warnings, making errors easier to detect
!pip install --upgrade "transformers==4.6.0"

Compile the model into an AWS Neuron optimized TorchScript#

This step compiles the model into an AWS Neuron optimized TorchScript and saves it in the file bert_neuron.pt. This step is the same as in the pretrained BERT tutorial without the Weight Sharing feature. We use batch size 1 for simplicity.

[1]:
import tensorflow  # to workaround a protobuf version conflict issue
import torch
import torch.neuron
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import transformers
import os
import warnings


# Build tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased-finetuned-mrpc")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased-finetuned-mrpc", return_dict=False)

# Setup some example inputs
sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"

max_length=128
paraphrase = tokenizer.encode_plus(sequence_0, sequence_2, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt")
not_paraphrase = tokenizer.encode_plus(sequence_0, sequence_1, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt")

# Run the original PyTorch model on the compilation example
paraphrase_classification_logits = model(**paraphrase)[0]

# Convert example inputs to a format that is compatible with TorchScript tracing
example_inputs_paraphrase = paraphrase['input_ids'], paraphrase['attention_mask'], paraphrase['token_type_ids']
example_inputs_not_paraphrase = not_paraphrase['input_ids'], not_paraphrase['attention_mask'], not_paraphrase['token_type_ids']

# Run torch.neuron.trace to generate a TorchScript that is optimized by AWS Neuron
model_neuron = torch.neuron.trace(model, example_inputs_paraphrase)

# Verify the TorchScript works on both example inputs
paraphrase_classification_logits_neuron = model_neuron(*example_inputs_paraphrase)
not_paraphrase_classification_logits_neuron = model_neuron(*example_inputs_not_paraphrase)

# Save the TorchScript for later use
model_neuron.save('bert_neuron.pt')

Deploy the AWS Neuron optimized TorchScript#

To deploy the AWS Neuron optimized TorchScript, you may choose to load the saved TorchScript from disk and skip the slow compilation. This step is the same as in the pretrained BERT tutorial without the Weight Sharing feature.

[2]:
# Load TorchScript back
model_neuron = torch.jit.load('bert_neuron.pt')
# Verify the TorchScript works on both example inputs
paraphrase_classification_logits_neuron = model_neuron(*example_inputs_paraphrase)
not_paraphrase_classification_logits_neuron = model_neuron(*example_inputs_not_paraphrase)
classes = ['not paraphrase', 'paraphrase']
paraphrase_prediction = paraphrase_classification_logits_neuron[0][0].argmax().item()
not_paraphrase_prediction = not_paraphrase_classification_logits_neuron[0][0].argmax().item()
print('BERT says that "{}" and "{}" are {}'.format(sequence_0, sequence_2, classes[paraphrase_prediction]))
print('BERT says that "{}" and "{}" are {}'.format(sequence_0, sequence_1, classes[not_paraphrase_prediction]))

We define two helper functions: one to pad the input batch to the full batch size and one to count correct results.

[3]:
def get_input_with_padding(batch, batch_size, max_length):
    ## Reformulate the batch into three batch tensors - default batch size batches the outer dimension
    encoded = batch['encoded']
    inputs = torch.squeeze(encoded['input_ids'], 1)
    attention = torch.squeeze(encoded['attention_mask'], 1)
    token_type = torch.squeeze(encoded['token_type_ids'], 1)
    quality = list(map(int, batch['quality']))

    if inputs.size()[0] != batch_size:
        print("Input size = {} - padding".format(inputs.size()))
        remainder = batch_size - inputs.size()[0]
        zeros = torch.zeros( [remainder, max_length], dtype=torch.long )
        inputs = torch.cat( [inputs, zeros] )
        attention = torch.cat( [attention, zeros] )
        token_type = torch.cat( [token_type, zeros] )

    assert(inputs.size()[0] == batch_size and inputs.size()[1] == max_length)
    assert(attention.size()[0] == batch_size and attention.size()[1] == max_length)
    assert(token_type.size()[0] == batch_size and token_type.size()[1] == max_length)

    return (inputs, attention, token_type), quality

def count(output, quality):
    assert output.size(0) >= len(quality)
    correct_count = 0
    count = len(quality)

    batch_predictions = [ row.argmax().item() for row in output ]

    for a, b in zip(batch_predictions, quality):
        if int(a)==int(b):
            correct_count += 1

    return correct_count, count
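
As a quick sanity check (not part of the benchmark flow), the count() helper can be exercised on a small hand-made batch of logits; the values below are made up purely for illustration.

example_logits = torch.tensor([[0.1, 0.9],   # argmax -> class 1
                               [0.8, 0.2]])  # argmax -> class 0
example_quality = [1, 1]                     # ground-truth labels
print(count(example_logits, example_quality))  # (1, 2): one correct out of two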

Data parallel inference#

In the cell below, we use the data parallel approach for inference: we load multiple copies of the model and run them in parallel. Each model is loaded onto a single NeuronCore via the core placement API (torch_neuron.experimental.neuron_cores_context()). We also set the Neuron Runtime environment variable NEURON_RT_MULTI_INSTANCE_SHARED_WEIGHTS to “TRUE”, as required to use the Weight Sharing feature.

In the implementation below, we launch 16 models, thereby utilizing all 16 NeuronCores on an inf1.6xlarge.

Note: If you want to decrease num_cores in the cells below, please restart the notebook and run !sudo rmmod neuron; sudo modprobe neuron to clear the NeuronCores (see the sketch after this note).
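
For reference, the reset mentioned in the note above can be run from a notebook cell as sketched below (this assumes sudo access on the instance; it reloads the Neuron driver and clears all NeuronCores).

# Run only after restarting the notebook kernel
!sudo rmmod neuron; sudo modprobe neuron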

Since we can run more than one model concurrently, the throughput of the system goes up. To achieve the maximum gain in throughput, we need to feed the models efficiently so as to keep them busy at all times. In the setup below, we use parallel threads to feed data continuously to the models.

When running the cell below, you can monitor Inferentia device activity by running neuron-top in another terminal. You will see that “Device Used Memory” is 1.6GB in total, and that the model instance loaded onto NeuronDevice 0 NeuronCore 0 uses the most device memory (272MB) while the other model instances, loaded onto the other NeuronCores, use less device memory (92MB each). This shows the effect of using Weight Sharing: the device memory usage is lower. If you change NEURON_RT_MULTI_INSTANCE_SHARED_WEIGHTS to “FALSE”, you will see that “Device Used Memory” is 3.2GB, and that the model instances loaded onto NeuronDevice 0 NeuronCores 0 and 1 use the most device memory (360MB each) while the other model instances now use 180MB each.

[5]:
from bert_benchmark_utils import BertTestDataset, BertResults
import time
import functools
import os
import torch.neuron as torch_neuron
from concurrent import futures

# Setting up NeuronCore groups for inf1.6xlarge with 16 cores
num_cores = 16 # This value should be 4 on inf1.xlarge and inf1.2xlarge
os.environ['NEURON_RT_NUM_CORES'] = str(num_cores)
os.environ['NEURON_RT_MULTI_INSTANCE_SHARED_WEIGHTS'] = 'TRUE'
#os.environ['NEURON_RT_MULTI_INSTANCE_SHARED_WEIGHTS'] = 'FALSE'

max_length = 128
num_cores = 16
batch_size = 1

tsv_file="glue_mrpc_dev.tsv"

data_set = BertTestDataset( tsv_file=tsv_file, tokenizer=tokenizer, max_length=max_length )
data_loader = torch.utils.data.DataLoader(data_set, batch_size=batch_size, shuffle=True)

#Result aggregation class (code in bert_benchmark_utils.py)
results = BertResults(batch_size, num_cores)
def result_handler(output, result_id, start, end, input_dict):
    correct_count, inference_count = count(output[0], input_dict.pop(result_id))
    elapsed = end - start
    results.add_result(correct_count, inference_count, [elapsed], [end], [start])

with torch_neuron.experimental.neuron_cores_context(start_nc=0, nc_count=num_cores):
    model = torch.jit.load('bert_neuron.pt')

# Warm up the cores
z = torch.zeros( [batch_size, max_length], dtype=torch.long )
batch = (z, z, z)
for _ in range(num_cores*4):
    model(*batch)

# Prepare the input data
batch_list = []
for batch in data_loader:
    batch, quality = get_input_with_padding(batch, batch_size, max_length)
    batch_list.append((batch, quality))

# One thread running a model on one core
def one_thread(feed_data, quality):
    start = time.time()
    result = model(*feed_data)
    end = time.time()
    return result[0], quality, start, end

# Launch more threads than models/cores to keep them busy
processes = []
with futures.ThreadPoolExecutor(max_workers=num_cores*2) as executor:
    # extra loops to help you see activities in neuron-top
    for _ in range(10):
        for input_id, (batch, quality) in enumerate(batch_list):
            processes.append(executor.submit(one_thread, batch, quality))

results = BertResults(batch_size, num_cores)
for future in futures.as_completed(processes):
    (output, quality, start, end) = future.result()
    correct_count, inference_count = count(output, quality)
    results.add_result(correct_count, inference_count, [end - start], [end], [start])

with open("benchmark.txt", "w") as f:
    results.report(f, window_size=1)

with open("benchmark.txt", "r") as f:
    for line in f:
        print(line)

[ ]: