ResNet50 model for Inferentia#

Introduction:#

In this tutorial we will compile and deploy a ResNet50 model for inference on Inferentia.

This Jupyter notebook should be run on an inf1.6xlarge instance. Only the inference part of this tutorial requires an inf1 instance; the compilation stage does not. For simplicity we run this tutorial on an inf1.6xlarge, but in real-life scenarios compilation should be done on a compute instance and deployment on an inf1 instance to save costs.

In this tutorial we provide three main sections:

  1. Compile the ResNet50 model and infer with a batch size of 1

  2. Run the same compiled model on multiple NeuronCores using torch.neuron.DataParallel and dynamic batching

  3. Compile the ResNet50 model with a batch size of 5 and run it on multiple NeuronCores using torch.neuron.DataParallel for optimal performance on Inferentia

Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the PyTorch Installation Guide. You can select the kernel from the “Kernel -> Change Kernel” option at the top of this Jupyter notebook page.

Install Dependencies:#

This tutorial requires the following pip packages:

  • torch>=1.8

  • torch-neuron

  • torchvision

  • neuron-cc[tensorflow]

These will be installed by default when configuring your environment using the Neuron PyTorch setup guide.
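
If you need to install them manually, a command along the following lines should work from within the notebook. This is a sketch: the package index URL is the Neuron pip repository documented in the setup guide at the time of writing, so confirm it there before running.

[ ]:
# Install the required packages from the Neuron pip repository
# (repository URL as documented in the Neuron setup guide; verify before use)
%pip install torch-neuron "neuron-cc[tensorflow]" torchvision --extra-index-url=https://pip.repos.neuron.amazonaws.com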

Compile model for Neuron#

The following cell compiles the ResNet50 model for Inferentia. Compilation will take a few minutes. At the end of the cell's execution, the compiled model is saved as resnet50_neuron.pt in your local directory.

[ ]:
import torch
from torchvision import models, transforms, datasets
import torch_neuron

# Create an example input for compilation
image = torch.zeros([1, 3, 224, 224], dtype=torch.float32)

# Load a pretrained ResNet50 model
model = models.resnet50(pretrained=True)

# Tell the model we are using it for evaluation (not training)
model.eval()

# Analyze the model - this will show operator support and operator count
torch.neuron.analyze_model(model, example_inputs=[image])

# Compile the model using torch.neuron.trace to create a Neuron model
# that is optimized for the Inferentia hardware
model_neuron = torch.neuron.trace(model, example_inputs=[image])

# The output of the compilation step will report the percentage of operators that
# are compiled to Neuron, for example:
#
# INFO:Neuron:The neuron partitioner created 1 sub-graphs
# INFO:Neuron:Neuron successfully compiled 1 sub-graphs, Total fused subgraphs = 1, Percent of model sub-graphs successfully compiled = 100.0%
#
# We will also be warned if there are operators that are not placed on the Inferentia hardware

# Save the compiled model
model_neuron.save("resnet50_neuron.pt")

Run inference on Inferentia#

We can use the compiled Neuron model to run inference on Inferentia.

In the following example, we preprocess a sample image and run inference with both the CPU model and the Neuron model. We then compare the predicted labels from the two models to verify that they are the same.

Important: Do not perform inference with a Neuron traced model on a non-Neuron supported instance, as the results will not be calculated properly.

Define a preprocessing function#

We define a basic image preprocessing function that loads a sample image and labels, normalizes and batches the image, and transforms the image into a tensor for inference using the compiled Neuron model.

[ ]:
import json
import os
from urllib import request

# Create an image directory containing a sample image of a small kitten
os.makedirs("./torch_neuron_test/images", exist_ok=True)
request.urlretrieve("https://raw.githubusercontent.com/awslabs/mxnet-model-server/master/docs/images/kitten_small.jpg",
                    "./torch_neuron_test/images/kitten_small.jpg")

# Fetch labels to output the top classifications
request.urlretrieve("https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json","imagenet_class_index.json")
idx2label = []

# Read the labels and create a list to hold them for classification
with open("imagenet_class_index.json", "r") as read_file:
    class_idx = json.load(read_file)
    idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
[ ]:
import numpy as np

def preprocess(batch_size=1, num_neuron_cores=1):
    # Define a normalization function using the ImageNet mean and standard deviation
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])

    # Resize the sample image to [1, 3, 224, 224], normalize it, and turn it into a tensor
    eval_dataset = datasets.ImageFolder(
        os.path.dirname("./torch_neuron_test/"),
        transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        normalize,
        ])
    )
    image, _ = eval_dataset[0]
    image = torch.tensor(image.numpy()[np.newaxis, ...])

    # Create a "batched" image with enough images to go on each of the available NeuronCores
    # batch_size is the per-core batch size
    # num_neuron_cores is the number of NeuronCores being used
    batch_image = image
    for i in range(batch_size * num_neuron_cores - 1):
        batch_image = torch.cat([batch_image, image], 0)

    return batch_image

Run inference using the Neuron model#

We import the necessary Python modules, load the torch-neuron compiled model, and run inference on Inferentia.

By default, the Neuron model will run on a single NeuronCore. In the next section, we will see how to run the Neuron model on multiple NeuronCores to fully saturate our hardware for optimal performance on Inferentia.

[ ]:
import torch
from torchvision import models, transforms, datasets
import torch_neuron

# Get a sample image
image = preprocess()

# Run inference using the CPU model
output_cpu = model(image)

# Load the compiled Neuron model
model_neuron = torch.jit.load('resnet50_neuron.pt')

# Run inference using the Neuron model
output_neuron = model_neuron(image)

# Verify that the CPU and Neuron predictions are the same by comparing
# the top-5 results
top5_cpu = output_cpu[0].sort()[1][-5:]
top5_neuron = output_neuron[0].sort()[1][-5:]

# Lookup and print the top-5 labels
top5_labels_cpu = [idx2label[idx] for idx in top5_cpu]
top5_labels_neuron = [idx2label[idx] for idx in top5_neuron]
print("CPU top-5 labels: {}".format(top5_labels_cpu))
print("Neuron top-5 labels: {}".format(top5_labels_neuron))

Run Inference using torch.neuron.DataParallel#

To fully leverage the Inferentia hardware we want to use all available NeuronCores. An inf1.xlarge or inf1.2xlarge has four NeuronCores, an inf1.6xlarge has 16 NeuronCores, and an inf1.24xlarge has 64 NeuronCores. For maximum performance on Inferentia hardware, we can use torch.neuron.DataParallel to utilize all available NeuronCores.

torch.neuron.DataParallel implements data parallelism at the module level by duplicating the Neuron model on all available NeuronCores and distributing data across the different cores for parallelized inference.
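
If you want to replicate the model on only a subset of cores rather than all of them, the torch.neuron.DataParallel constructor also takes a device_ids argument. The snippet below is a sketch with illustrative core indices; check the torch.neuron.DataParallel API reference for the exact accepted forms.

[ ]:
# Sketch only: replicate the compiled Neuron model on a specific subset of
# NeuronCores (core indices 0 and 1 here are purely illustrative)
model_neuron_subset = torch.neuron.DataParallel(model_neuron, device_ids=[0, 1], dim=0)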

In the following section, we will run inference using the torch.neuron.DataParallel module to fully saturate the Inferentia hardware. We benchmark the model to collect throughput and latency statistics.

Note: torch.neuron.DataParallel is new with Neuron 1.16.0. Please ensure you are using the latest Neuron package to run the following sections.
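
If you are unsure which Neuron packages and versions are installed in the current kernel, a quick check using standard pip tooling (nothing Neuron-specific) is:

[ ]:
# List installed Neuron-related packages and their versions
!pip list | grep -i neuron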

Define a benchmarking function#

We create a function that handles benchmarking the Neuron model to collect throughput and latency metrics.

[ ]:
from time import time

def benchmark(model, image):
    print('Input image shape is {}'.format(list(image.shape)))

    # The first inference loads the model so exclude it from timing
    results = model(image)

    # Collect throughput and latency metrics
    latency = []
    throughput = []

    # Run inference for 100 iterations and calculate metrics
    num_infers = 100
    for _ in range(num_infers):
        delta_start = time()
        results = model(image)
        delta = time() - delta_start
        latency.append(delta)
        throughput.append(image.size(0)/delta)

    # Calculate and print the model throughput and latency
    print("Avg. Throughput: {:.0f}, Max Throughput: {:.0f}".format(np.mean(throughput), np.max(throughput)))
    print("Latency P50: {:.0f}".format(np.percentile(latency, 50)*1000.0))
    print("Latency P90: {:.0f}".format(np.percentile(latency, 90)*1000.0))
    print("Latency P95: {:.0f}".format(np.percentile(latency, 95)*1000.0))
    print("Latency P99: {:.0f}\n".format(np.percentile(latency, 99)*1000.0))

Run Inference using torch.neuron.DataParallel#

We create the torch.neuron.DataParallel module using the compiled Neuron model, get a sample image, and benchmark the parallelized model on Neuron.

[ ]:
# Create a torch.neuron.DataParallel module using the compiled Neuron model
# By default, torch.neuron.DataParallel will use four cores on an inf1.xlarge
# or inf1.2xlarge, 16 cores on an inf1.6xlarge, and 64 cores on an inf1.24xlarge
model_neuron_parallel = torch.neuron.DataParallel(model_neuron)

# Get sample image with batch size=1 per NeuronCore
batch_size = 1

# For an inf1.xlarge or inf1.2xlarge, set num_neuron_cores = 4
num_neuron_cores = 16

image = preprocess(batch_size=batch_size, num_neuron_cores=num_neuron_cores)

# Benchmark the model
benchmark(model_neuron_parallel, image)

Run inference with dynamic batch sizes#

Batch size has a direct impact on model performance. The Inferentia chip is optimized to run with small batch sizes, which means a Neuron-compiled model can outperform a GPU model even at single-digit batch sizes.

As a general best practice, we recommend optimizing your model’s throughput by compiling the model with a small batch size and gradually increasing it to find the peak throughput on Inferentia.

Dynamic batching is a feature that allows you to use tensor batch sizes that the Neuron model was not originally compiled against. It is needed because the underlying Inferentia hardware always executes inferences with the batch size used during compilation. Fixed-batch-size execution lets you tune the input batch size for optimal performance: for example, batch size 1 may be best suited for an ultra-low-latency, on-demand inference application, while a batch size greater than 1 can be used to maximize throughput for offline inferencing. Dynamic batching is implemented by slicing large input tensors into chunks that match the batch size used during the torch.neuron.trace compilation call.

The torch.neuron.DataParallel class automatically enables dynamic batching on eligible models. This allows us to run inference in applications that have inputs with a variable batch size without needing to recompile the model.
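
As an aside, if your application always sends inputs whose batch size matches the compiled batch size, dynamic batching can be turned off so inputs are passed straight through without slicing. The following is a sketch that assumes the disable_dynamic_batching() helper described in the torch.neuron.DataParallel API reference:

[ ]:
# Sketch only: turn off automatic dynamic batching when input batch sizes
# always match the batch size the model was compiled with
# (disable_dynamic_batching() is assumed from the DataParallel API reference)
model_neuron_fixed = torch.neuron.DataParallel(model_neuron)
model_neuron_fixed.disable_dynamic_batching()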

In the following example, we use the same torch.neuron.DataParallel module to run inference using several different batch sizes. Notice that latency increases consistently as the batch size increases. Throughput increases as well, up until a certain point where the input size becomes too large to be efficient.

[ ]:
# using the same DataParallel model_neuron_parallel model, we can run
# inference on inputs with a variable batch size without recompiling
batch_sizes = [2, 3, 4, 5, 6, 7]
for batch_size in batch_sizes:
    print('Batch size: {}'.format(batch_size))
    image = preprocess(batch_size=batch_size, num_neuron_cores=num_neuron_cores)

    # Benchmark the model for each input batch size
    benchmark(model_neuron_parallel, image)

Compile and Infer with different batch sizes on multiple NeuronCores#

Dynamic batching using small batch sizes can result in sub-optimal throughput because it involves slicing tensors into chunks and iteratively sending data to the hardware. Using a larger batch size at compilation time can use the Inferentia hardware more efficiently in order to maximize throughput. You can test the tradeoff between individual request latency and total throughput by fine-tuning the input batch size.

In the following example, we recompile our model using a batch size of 5 and run the model using torch.neuron.DataParallel to fully saturate our Inferentia hardware for optimal performance.

[ ]:
# Create an input with batch size 5 for compilation
batch_size = 5
image = torch.zeros([batch_size, 3, 224, 224], dtype=torch.float32)

# Recompile the ResNet50 model for inference with batch size 5
model_neuron = torch.neuron.trace(model, example_inputs=[image])

# Export to saved model
model_neuron.save("resnet50_neuron_b{}.pt".format(batch_size))

Run inference with a batch size of 5, using the Neuron model compiled for a batch size of 5.

[ ]:
batch_size = 5

# Load compiled Neuron model
model_neuron = torch.jit.load("resnet50_neuron_b{}.pt".format(batch_size))

# Create DataParallel model
model_neuron_parallel = torch.neuron.DataParallel(model_neuron)

# Get sample image with batch size=5
image = preprocess(batch_size=batch_size, num_neuron_cores=num_neuron_cores)

# Benchmark the model
benchmark(model_neuron_parallel, image)

You can experiment with different batch size values to see what gives the best overall throughput on Inferentia.
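
As a rough sketch of such an experiment, assuming the model, preprocess, and benchmark definitions from the cells above are still in scope, you could compile and benchmark a few candidate batch sizes in a loop:

[ ]:
# Sketch: sweep a few compilation batch sizes and benchmark each one.
# Assumes `model`, `preprocess`, `benchmark`, and `num_neuron_cores` are
# still defined from the cells above; the candidate sizes are illustrative.
for bs in [1, 5, 10]:
    example = torch.zeros([bs, 3, 224, 224], dtype=torch.float32)
    model_bs = torch.neuron.trace(model, example_inputs=[example])
    model_bs_parallel = torch.neuron.DataParallel(model_bs)

    image = preprocess(batch_size=bs, num_neuron_cores=num_neuron_cores)
    print('Compiled batch size: {}'.format(bs))
    benchmark(model_bs_parallel, image)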