Compiling and Deploying HuggingFace Pretrained BERT on Trn1 or Inf2#

Introduction#

In this tutorial, we compile a HuggingFace 🤗 Transformers BERT model and deploy it for accelerated inference directly on Trn1/Inf2 instances. If you are looking to deploy this model through SageMaker on an Inf2 instance, please visit the SageMaker samples repository.

This tutorial uses the bert-base-cased-finetuned-mrpc model. This model has 12 layers, 768 hidden dimensions, 12 attention heads, and 110M total parameters. The final layer is a binary classification head that has been trained on the Microsoft Research Paraphrase Corpus (MRPC). The model takes two sentences as input and outputs whether or not they are paraphrases of each other.

This tutorial has the following main sections:

  1. Install dependencies

  2. Compile the BERT model

  3. Run inference on Neuron and compare results to CPU

  4. Benchmark the model using multicore inference

  5. Finding the optimal batch size

This Jupyter notebook should be run on a Trn1 instance (trn1.2xlarge or larger) or an Inf2 instance (inf2.xlarge or larger).

Install dependencies#

The code in this tutorial is written for Jupyter Notebooks. To set up Jupyter Notebook on the Neuron instance, you can follow this guide.

This tutorial requires the following pip packages:

  • torch-neuronx

  • neuronx-cc

  • transformers

Most of these packages are installed when you configure your environment using the Trn1/Inf2 setup guide. The remaining dependencies are installed below:

[ ]:
# Suppresses tokenizer warnings, making errors easier to detect
%env TOKENIZERS_PARALLELISM=True
!pip install --upgrade transformers

Compile the model into an AWS Neuron optimized TorchScript#

In this section, we load the BERT model and tokenizer, prepare a sample input, run inference on CPU, compile the model for Neuron using torch_neuronx.trace(), and save the optimized model as TorchScript.

torch_neuronx.trace() expects a tensor or tuple of tensors as its example input for tracing, so we use the encode function below to convert the tokenizer's dictionary output into a tuple of tensors.

The result of the trace stage is a static executable whose operations are determined at compilation time. This means that the resulting Neuron model must be executed with tensors of exactly the same shape as those provided at compilation time. If the model is given a tensor at inference time whose shape does not match the compile-time shape, an error will occur.

For language models, the shape of the tokenizer tensors can vary based on the length of the input sentence. We can satisfy the Neuron requirement of a fixed input shape by padding all varying input tensors to a specified length. In a deployment scenario, the padding size should be chosen based on the maximum token length expected for the application.

In the following section, we assume that we will receive at most 128 tokens at inference time. We pad our example inputs using padding='max_length', and, to avoid errors caused by inputs longer than max_length=128, we always tokenize with truncation=True.

[ ]:
import torch
import torch_neuronx
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import transformers


def encode(tokenizer, *inputs, max_length=128, batch_size=1):
    tokens = tokenizer.encode_plus(
        *inputs,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )
    return (
        torch.repeat_interleave(tokens['input_ids'], batch_size, 0),
        torch.repeat_interleave(tokens['attention_mask'], batch_size, 0),
        torch.repeat_interleave(tokens['token_type_ids'], batch_size, 0),
    )


# Create the tokenizer and model
name = "bert-base-cased-finetuned-mrpc"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name, torchscript=True)

# Set up some example inputs
sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in Manhattan"

paraphrase = encode(tokenizer, sequence_0, sequence_2)
not_paraphrase = encode(tokenizer, sequence_0, sequence_1)

# Run the original PyTorch BERT model on CPU
cpu_paraphrase_logits = model(*paraphrase)[0]
cpu_not_paraphrase_logits = model(*not_paraphrase)[0]

# Compile the model for Neuron
model_neuron = torch_neuronx.trace(model, paraphrase)

# Save the TorchScript for inference deployment
filename = 'model.pt'
torch.jit.save(model_neuron, filename)
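
As a quick sanity check, we can confirm that each tensor produced by encode has the fixed shape the compiled model now expects, here (batch_size=1, max_length=128). Any other input shape at inference time would trigger the shape-mismatch error described above.

[ ]:
# Sanity check: every traced input shares the fixed (batch_size, max_length)
# shape that the compiled model expects -- here (1, 128).
for field, tensor in zip(('input_ids', 'attention_mask', 'token_type_ids'), paraphrase):
    print(field, tuple(tensor.shape))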

Run inference and compare results#

In this section we load the compiled model, run inference on Neuron, and compare the CPU and Neuron outputs.

NOTE: Although this tutorial section uses one NeuronCore (and the next section uses two NeuronCores), by default each Jupyter notebook Python process will attempt to take ownership of all NeuronCores visible on the instance. For multi-process applications where each process should only use a subset of the NeuronCores, you can set NEURON_RT_NUM_CORES=N or NEURON_RT_VISIBLE_CORES=<list of NeuronCore IDs> when starting the Jupyter notebook, as described in NeuronCore Allocation and Model Placement for Inference.
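
As a sketch of the first option, the variable can also be set from Python, provided this happens before the first Neuron model is loaded in the process (the runtime reads it at initialization); exporting it in the shell before launching Jupyter, as described above, works equally well.

[ ]:
# Illustrative sketch: limit this process to a subset of NeuronCores. The
# Neuron runtime reads these variables when it initializes, so set them before
# the first Neuron model is loaded in this process.
import os
os.environ["NEURON_RT_NUM_CORES"] = "2"          # own at most 2 NeuronCores
# os.environ["NEURON_RT_VISIBLE_CORES"] = "0,1"  # or pin specific core IDs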

[ ]:
# Load the TorchScript compiled model
model_neuron = torch.jit.load(filename)

# Verify the TorchScript works on both example inputs
neuron_paraphrase_logits = model_neuron(*paraphrase)[0]
neuron_not_paraphrase_logits = model_neuron(*not_paraphrase)[0]

# Compare the results
print('CPU paraphrase logits:        ', cpu_paraphrase_logits.detach().numpy())
print('Neuron paraphrase logits:     ', neuron_paraphrase_logits.detach().numpy())
print('CPU not-paraphrase logits:    ', cpu_not_paraphrase_logits.detach().numpy())
print('Neuron not-paraphrase logits: ', neuron_not_paraphrase_logits.detach().numpy())
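
Beyond comparing the printed logits by eye, a short sketch of a numerical check and of turning the logits into class predictions might look like the following (the label mapping, index 1 meaning "is a paraphrase", follows the standard MRPC convention and is an assumption here):

[ ]:
# Quantify the CPU/Neuron difference and interpret the logits as predictions.
# The label mapping (index 1 == "paraphrase") is assumed from the standard
# MRPC convention.
print('Max absolute difference:',
      (cpu_paraphrase_logits - neuron_paraphrase_logits).abs().max().item())

classes = ['not paraphrase', 'paraphrase']
print('sequence_0 vs sequence_2:', classes[neuron_paraphrase_logits.argmax(dim=-1).item()])
print('sequence_0 vs sequence_1:', classes[neuron_not_paraphrase_logits.argmax(dim=-1).item()])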

Benchmarking#

In this section we benchmark the performance of the BERT model on Neuron. By default, models compiled with torch_neuronx always execute on a single NeuronCore. When loading multiple models, the default behavior of the Neuron runtime is to distribute them evenly across all available NeuronCores: each model is placed on the NeuronCore that currently has the fewest models loaded. In the following section, we will torch.jit.load multiple instances of the model, each of which will be placed on its own NeuronCore. It is not useful to load more copies of a model than the number of NeuronCores on the instance, since an individual NeuronCore can only execute one model at a time.

To ensure that we are maximizing hardware utilization, we must run inferences from multiple threads in parallel. Some form of threading/multiprocessing combined with model replication is nearly always recommended, since even the smallest Neuron EC2 instance has 2 NeuronCores available. An application with no threading can achieve at most 1 / num_neuron_cores hardware utilization, which becomes especially problematic on large instances.

One way to view the hardware utilization is by executing the neuron-top application in the terminal while the benchmark is executing. If the monitor shows >90% utilization on all NeuronCores, this is a good indication that the hardware is being utilized effectively.

In this example we load two models, which utilizes all NeuronCores (2) on a trn1.2xlarge or inf2.xlarge instance. Additional models can be loaded and run in parallel on larger Trn1 or Inf2 instance sizes to increase throughput.

We define a benchmarking function that loads two optimized BERT models onto two separate NeuronCores, runs multithreaded inference, and calculates the corresponding latency and throughput.

[ ]:
import time
import concurrent.futures
import numpy as np


def benchmark(filename, example, n_models=2, n_threads=2, batches_per_thread=1000):
    """
    Record performance statistics for a serialized model and its input example.

    Arguments:
        filename: The serialized torchscript model to load for benchmarking.
        example: An example model input.
        n_models: The number of models to load.
        n_threads: The number of simultaneous threads to execute inferences on.
        batches_per_thread: The number of example batches to run per thread.

    Returns:
        A dictionary of performance statistics.
    """

    # Load models
    models = [torch.jit.load(filename) for _ in range(n_models)]

    # Warmup
    for _ in range(8):
        for model in models:
            model(*example)

    latencies = []

    # Thread task
    def task(model):
        for _ in range(batches_per_thread):
            start = time.time()
            model(*example)
            finish = time.time()
            latencies.append((finish - start) * 1000)

    # Submit tasks
    begin = time.time()
    with concurrent.futures.ThreadPoolExecutor(max_workers=n_threads) as pool:
        for i in range(n_threads):
            pool.submit(task, models[i % len(models)])
    end = time.time()

    # Compute metrics
    boundaries = [50, 95, 99]
    percentiles = {}

    for boundary in boundaries:
        name = f'latency_p{boundary}'
        percentiles[name] = np.percentile(latencies, boundary)
    duration = end - begin
    # All example tensors share the same batch dimension; read it from the first
    batch_size = example[0].shape[0]
    inferences = len(latencies) * batch_size
    throughput = inferences / duration

    # Metrics
    metrics = {
        'filename': str(filename),
        'batch_size': batch_size,
        'batches': len(latencies),
        'inferences': inferences,
        'threads': n_threads,
        'models': n_models,
        'duration': duration,
        'throughput': throughput,
        **percentiles,
    }

    display(metrics)
    return metrics


def display(metrics):
    """
    Display the metrics produced by the `benchmark` function.

    Args:
        metrics: A dictionary of performance statistics.
    """
    pad = max(map(len, metrics)) + 1
    for key, value in metrics.items():

        parts = key.split('_')
        parts = list(map(str.title, parts))
        title = ' '.join(parts) + ":"

        if isinstance(value, float):
            value = f'{value:0.3f}'

        print(f'{title :<{pad}} {value}')


# Benchmark BERT on Neuron
benchmark(filename, paraphrase)
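
On larger Trn1 or Inf2 instances, the same function can exercise additional NeuronCores by raising n_models and n_threads; a sketch is shown below (adjust the counts to the NeuronCores actually available on your instance, which neuron-ls reports).

[ ]:
# Sketch for larger instances: replicate the model across more NeuronCores and
# drive them with a matching number of threads. Set the counts to the number
# of NeuronCores on your instance (see `neuron-ls`). Commented out so it is
# not run accidentally on a 2-core instance.
# benchmark(filename, paraphrase, n_models=24, n_threads=24)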

Finding the optimal batch size#

Batch size has a direct impact on model performance. The NeuronCore architecture is optimized to maximize throughput at relatively small batch sizes. This means that a Neuron-compiled model can outperform a GPU model even when running single-digit batch sizes.

As a general best practice, we recommend optimizing your model's throughput by compiling with a small batch size and gradually increasing it to find the peak throughput on Neuron. To minimize latency, a batch size of 1 is nearly always optimal; this configuration is typically used for on-demand inference applications. To maximize throughput, a batch size in the range 1 < batch_size < 10 is usually optimal; larger batch sizes are generally ideal for batched on-demand inference or offline batch processing.

In the following section, we compile BERT for several batch sizes, then run inference at each batch size and benchmark the performance. Notice that latency increases consistently as the batch size increases. Throughput increases as well, up to a point, after which the input size becomes too large to be processed efficiently.

[ ]:
# Compile BERT for different batch sizes
for batch_size in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSequenceClassification.from_pretrained(name, torchscript=True)
    example = encode(tokenizer, sequence_0, sequence_2, batch_size=batch_size)
    model_neuron = torch_neuronx.trace(model, example)
    filename = f'model_batch_size_{batch_size}.pt'
    torch.jit.save(model_neuron, filename)
[ ]:
# Benchmark BERT for different batch sizes
for batch_size in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:
    print('-'*50)
    example = encode(tokenizer, sequence_0, sequence_2, batch_size=batch_size)
    filename = f'model_batch_size_{batch_size}.pt'
    benchmark(filename, example)
    print()
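
Since benchmark returns its metrics dictionary, a short sketch for programmatically selecting the batch size with the highest measured throughput (note that this re-runs each benchmark) could look like:

[ ]:
# Sketch: re-run the benchmark for each batch size, collect the returned
# metrics, and report the batch size with the highest throughput.
results = {}
for batch_size in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:
    example = encode(tokenizer, sequence_0, sequence_2, batch_size=batch_size)
    results[batch_size] = benchmark(f'model_batch_size_{batch_size}.pt', example)

best = max(results, key=lambda b: results[b]['throughput'])
print(f"\nPeak throughput at batch size {best}: "
      f"{results[best]['throughput']:0.1f} inferences / second")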