Using NeuronCore Pipeline with PyTorch#

In this tutorial you compile a pretrained BERT base model from HuggingFace 🤗 Transformers using the NeuronCore Pipeline feature of the AWS Neuron SDK. You benchmark the latency of the pipeline parallel deployment and compare it with the usual data parallel (multi-worker) deployment.

This tutorial is intended to run on an inf1.6xlarge instance, running the latest AWS Deep Learning AMI (DLAMI). The inf1.6xlarge instance size has 4 AWS Inferentia chips, for a total of 16 NeuronCores.
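If you want to confirm the hardware visible on the instance, the neuron-ls tool lists the Inferentia devices. This check is optional and assumes the Neuron tools are on your PATH (they are typically installed with the DLAMI):

[ ]:
# Optional: list the Inferentia devices visible on this instance
!neuron-ls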

Verify that this Jupyter notebook is running the Python or Conda kernel environment that was set up according to the PyTorch Installation Guide. You can select the kernel from the “Kernel -> Change Kernel” option at the top of this Jupyter notebook page.

Note: Do not execute this tutorial using “Run -> Run all cells” option.

Install Dependencies:#

This tutorial requires the following pip packages:

  • torch-neuron

  • neuron-cc[tensorflow]

  • transformers

Most of these packages will be installed when configuring your environment using the Neuron PyTorch setup guide. The additional HuggingFace 🤗 Transformers dependency must be installed here.

[ ]:
# Suppresses tokenizer warnings, making errors easier to detect
%env TOKENIZERS_PARALLELISM=True
!pip install --upgrade "transformers==4.6.0"
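Optionally, you can confirm that the environment imports the expected packages before moving on. This quick sanity check is not part of the original tutorial flow:

[ ]:
# Optional sanity check: verify that the Neuron-enabled PyTorch stack and transformers import correctly
import torch
import torch_neuron
import transformers
print(f"torch: {torch.__version__}")
print(f"transformers: {transformers.__version__}")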

Compiling a BERT base model for a single NeuronCore#

To run a HuggingFace BertModel on Inferentia, you only need to add a single extra line of code to the usual 🤗 Transformers PyTorch implementation, after importing the torch_neuron framework.

Add the argument return_dict=False to the BERT transformers model so it can be traced with TorchScript. TorchScript is a way to create serializable and optimizable models from PyTorch code.

Enable padding to a maximum sequence length of 128 to test the model’s performance with a realistic payload size. You can adapt this sequence length to your application’s requirements.

You can adapt the original example from the BertModel forward pass docstring as shown in the following cell:

[ ]:
import torch
import torch_neuron
from transformers import BertTokenizer, BertModel

from joblib import Parallel, delayed
import numpy as np
from tqdm import tqdm

import os
import time


tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased',return_dict=False)

inputs = tokenizer("Hello, my dog is cute",return_tensors="pt",max_length=128,padding='max_length',truncation=True)

The one extra line required is the call to the torch.neuron.trace() method. This call compiles the model and returns a TorchScript module whose forward method runs on Inferentia, which you can use to run inference.

The compiled graph can be saved using the torch.jit.save function and restored using the torch.jit.load function for inference on Inf1 instances. During inference, the previously compiled artifacts are loaded into the Neuron Runtime for execution.

[ ]:
neuron_model = torch.neuron.trace(model,
                                  example_inputs = (inputs['input_ids'],inputs['attention_mask']),
                                  verbose=1)
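As a reference for the save/load round trip mentioned above, here is a minimal sketch. The file name is arbitrary and used only for illustration; the tutorial saves its compiled models in dedicated cells later on:

[ ]:
# Optional sketch: persist the compiled TorchScript graph and restore it for inference.
# The file name below is hypothetical and used only for illustration.
torch.jit.save(neuron_model, 'bert-base-uncased-neuron-example.pt')

# In a fresh process on an Inf1 instance, import torch_neuron first so the
# operators required to execute the restored graph on the Neuron Runtime are registered.
restored_model = torch.jit.load('bert-base-uncased-neuron-example.pt')
outputs = restored_model(inputs['input_ids'], inputs['attention_mask'])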

Running the BERT base model on a single NeuronCore#

With the model already available in memory, you can time a single inference call and check its latency. The first call loads the model into Inferentia, so a large “wall time” is expected when you first run the next cell; running the cell a second time will show the actual inference latency:

[ ]:
%%time
# The following line tests inference and should be executed on Inf1 instance family.
outputs = neuron_model(*(inputs['input_ids'],inputs['attention_mask']))

You can also check the throughput of the single model running on a single NeuronCore.

The sequential inference test (for loop) does not reflect the full performance achievable on an instance with multiple NeuronCores. To improve hardware utilization, you can run parallel inference requests over multiple model workers, which you’ll test in the Data Parallel Bonus Section below.

[ ]:
%%time
for _ in tqdm(range(100)):
    outputs = neuron_model(*(inputs['input_ids'],inputs['attention_mask']))

Save the compiled model for later use:

[ ]:
neuron_model.save('bert-base-uncased-neuron.pt')

Compiling a BERT base model for 16 NeuronCores#

Our next step is to compile the same model for all 16 NeuronCores available in the inf1.6xlarge and check the performance difference when running pipeline parallel inferences.

[ ]:
import torch
import torch_neuron
from transformers import BertTokenizer, BertModel

from joblib import Parallel, delayed
import numpy as np
from tqdm import tqdm

import os
import time


tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased',return_dict=False)

inputs = tokenizer("Hello, my dog is cute",return_tensors="pt",max_length=128,padding='max_length',truncation=True)

To enable pipeline mode during compilation, you only need to add the compiler flag --neuroncore-pipeline-cores and set the desired number of cores. The cell below sets the neuroncore_pipeline_cores value, which you can adjust to the number of NeuronCores available on the instance: an inf1.6xlarge has 16 NeuronCores across 4 Inferentia chips.

[ ]:
# Number of Cores in the Pipeline Mode
neuroncore_pipeline_cores = 16 # This value should be 4 on an inf1.xlarge

# Compiling with --neuroncore-pipeline-cores=16
neuron_pipeline_model = torch.neuron.trace(model,
                                           example_inputs = (inputs['input_ids'],inputs['attention_mask']),
                                           verbose=1,
                                           compiler_args = ['--neuroncore-pipeline-cores', str(neuroncore_pipeline_cores)]
                                          )

Running the BERT base model on 16 NeuronCores#

Next, time one execution and check the latency of a single inference call over 16 cores. The first call loads the model into Inferentia, so a large “wall time” is expected when you first run the next cell; running the cell a second time will show the actual inference latency:

[ ]:
%%time
# The following line tests inference and should be executed on Inf1 instance family.
outputs = neuron_pipeline_model(*(inputs['input_ids'],inputs['attention_mask']))

Also check the throughput of the single model running over 16 NeuronCores.

The sequential inference test (for loop) does not measure all the performance one can achieve with pipeline mode. Because the inference runs in a streaming fashion, at least 15 cores sit idle waiting for a new call while the remaining stages process the current one, which results in low NeuronCore utilization. To improve hardware utilization you need parallel inference requests, which you’ll test in the next section.

[ ]:
for _ in tqdm(range(100)):
    outputs = neuron_pipeline_model(*(inputs['input_ids'],inputs['attention_mask']))

Load Testing the Pipeline Parallel Mode#

To put the 16-NeuronCore group to the test, a client has to send concurrent requests to the model. In this notebook setup you achieve that by creating a thread pool with joblib.Parallel, with each worker in the pool running one inference call.

You can define a new method called inference_latency() to measure the amount of time each inference call takes.

[ ]:
def inference_latency(model,*inputs):
    """
    inference_latency is a simple method to return the latency of a model inference.

        Parameters:
            model: torch model object loaded using torch.jit.load
            inputs: model() args

        Returns:
            latency in seconds
    """
    start = time.time()
    _ = model(*inputs)
    return time.time() - start
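For example, a single timed call against the pipeline model already in memory looks like this (a quick sanity check; the returned latency is in seconds):

[ ]:
# Single timed call using the helper defined above
single_latency = inference_latency(neuron_pipeline_model, inputs['input_ids'], inputs['attention_mask'])
print(f'Single call latency: {single_latency * 1000:.1f} ms')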

Use tqdm to measure the total throughput of your experiment, with a nice side effect: a cool progress bar! The total throughput is expected to be high, so set your experiment range to a large number; here, 30,000 inferences.

To calculate the latency statistics over the returned list of 30,000 latencies, use the numpy.quantile() method.

[ ]:
t = tqdm(range(30000), position=0, leave=True)
latency = Parallel(n_jobs=12,prefer="threads")(delayed(inference_latency)(neuron_pipeline_model,*(inputs['input_ids'],inputs['attention_mask'])) for i in t)

p50 = np.quantile(latency[-10000:],0.50) * 1000
p95 = np.quantile(latency[-10000:],0.95) * 1000
p99 = np.quantile(latency[-10000:],0.99) * 1000
avg_throughput = t.total/t.format_dict['elapsed']
print(f'Avg Throughput: {avg_throughput:.1f}')
print(f'50th Percentile Latency: {p50:.1f} ms')
print(f'95th Percentile Latency: {p95:.1f} ms')
print(f'99th Percentile Latency: {p99:.1f} ms')

Save the compiled model for later use:

[ ]:
# Save the TorchScript graph
neuron_pipeline_model.save('bert-base-uncased-neuron-pipeline.pt')

Bonus Section - Load Testing Data Parallel Mode#

[ ]:
import torch
import torch_neuron
from transformers import BertTokenizer

from joblib import Parallel, delayed
import numpy as np
from tqdm import tqdm

import os
import time

def inference_latency(model,*inputs):
    """
    inference_latency is a simple method to return the latency of a model inference.

        Parameters:
            model: torch model object loaded using torch.jit.load
            inputs: model() args

        Returns:
            latency in seconds
    """
    start = time.time()
    _ = model(*inputs)
    return time.time() - start

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

inputs = tokenizer("Hello, my dog is cute",return_tensors="pt",max_length=128,padding='max_length',truncation=True)

You use the 'NEURON_RT_NUM_CORES' environment variable to define how many NeuronCores to use. Set the environment variable to the number of individual workers you want to test in parallel.

torch_neuron will load one model per NeuronCore group until it runs out of cores. At that point, if the Python process continues to spawn more model objects using torch.jit.load, torch_neuron will start stacking more than one model per core, until the Inferentia chip memory is full.

Inferentia is able to run inference over all the loaded models, but only one at a time. The Neuron Runtime takes care of dynamically switching the model context as requests come in, so no extra worker process management is required. Use one model per NeuronCore to achieve maximum performance.

The following cell creates a list with as many models as NeuronCore groups and executes a single dummy inference to load the models into Inferentia.

[ ]:
import warnings
# Number of data parallel workers
number_of_workers=16 # This number should be 4 on an inf1.xlarge

# Setting up a data parallel group
os.environ['NEURON_RT_NUM_CORES'] = str(number_of_workers)

# Loading 'number_of_workers' amount of models in Python memory
model_list = [torch.jit.load('bert-base-uncased-neuron.pt') for _ in range(number_of_workers)]

# Dummy inference to load models to Inferentia
_ = [mod(*(inputs['input_ids'],inputs['attention_mask'])) for mod in model_list]

Adapt the call to joblib.Parallel() to iterate over a concatenated version of model_list, so that the calls are distributed ‘round-robin’ across the model workers.

[ ]:
t = tqdm(model_list*1500,position=0, leave=True)
latency = Parallel(n_jobs=number_of_workers,prefer="threads")(delayed(inference_latency)(mod,*(inputs['input_ids'],inputs['attention_mask'])) for mod in t)

p50 = np.quantile(latency[-10000:],0.50) * 1000
p95 = np.quantile(latency[-10000:],0.95) * 1000
p99 = np.quantile(latency[-10000:],0.99) * 1000
avg_throughput = t.total/t.format_dict['elapsed']
print(f'Avg Throughput: {avg_throughput:.1f}')
print(f'50th Percentile Latency: {p50:.1f} ms')
print(f'95th Percentile Latency: {p95:.1f} ms')
print(f'99th Percentile Latency: {p99:.1f} ms')

For this model, despite the larger number of workers, the per-worker latency increases when running one model per core, which in turn reduces the total throughput compared with the pipeline parallel results above.

This behavior may not hold if the model memory footprint or the input payload size changes, e.g. with a batch size greater than 1. We encourage you to experiment with both the data parallel and pipeline parallel modes to optimize your application’s performance.
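For example, to experiment with a larger payload you could tokenize a small batch and re-trace the model with that batched example input. The sketch below is illustrative only: the batch size and sentences are arbitrary, and it assumes the BertModel instance from the earlier compilation cells is still in memory:

[ ]:
# Hypothetical experiment: compile for batch size 4 instead of 1.
# Neuron-compiled graphs have fixed input shapes, so re-trace with a batched example input.
batch_sentences = ["Hello, my dog is cute"] * 4
batch_inputs = tokenizer(batch_sentences, return_tensors="pt", max_length=128,
                         padding='max_length', truncation=True)
neuron_batch_model = torch.neuron.trace(model,
                                        example_inputs=(batch_inputs['input_ids'], batch_inputs['attention_mask']))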