Compiling and Deploying ResNet50 on Trn1 or Inf2#

Introduction#

In this tutorial we will compile and deploy a TorchVision ResNet50 model for accelerated inference on Neuron. To get started with Jupyter Notebook on the Neuron instance you launched, please use this guide.

This tutorial uses the resnet50 model, which is commonly used for image classification tasks.

This tutorial has the following main sections:

  1. Install dependencies

  2. Compile the ResNet model

  3. Run inference on Neuron and compare results to CPU

  4. Benchmark the model using multicore inference

  5. Finding the optimal batch size

This Jupyter notebook should be run on a Trn1 instance (trn1.2xlarge or larger) or an Inf2 instance (inf2.xlarge or larger).

Install Dependencies#

The code in this tutorial is written for Jupyter Notebooks. To use Jupyter Notebook on the Neuron instance, you can use this guide.

This tutorial requires the following pip packages:

  • torch-neuronx

  • neuronx-cc

  • torchvision

  • Pillow

Most of these packages are installed when you configure your environment using the Trn1 setup guide. The remaining dependency can be installed here:

[ ]:
!pip install Pillow
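
If torch-neuronx, neuronx-cc, or torchvision are missing from your environment (for example, if you are not working inside the preconfigured virtual environment from the setup guide), they can typically be installed from the Neuron pip repository. The command below is a sketch; pin the package versions recommended by the setup guide for your Neuron release:

[ ]:
# Optional: only needed if the Neuron packages are not already installed.
# The extra index URL is the public AWS Neuron pip repository.
!pip install torch-neuronx neuronx-cc torchvision --extra-index-url https://pip.repos.neuron.amazonaws.com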

Compile the model into an AWS Neuron optimized TorchScript#

In the following section, we load the model, get a sample input, run inference on CPU, compile the model for Neuron using torch_neuronx.trace(), and save the optimized model as TorchScript.

torch_neuronx.trace() expects a tensor or tuple of tensor inputs to use for tracing, so we convert the input image into a tensor using the get_image function.

The result of the trace stage is a static executable in which the operations to run at inference time are fixed during compilation. As a result, the compiled Neuron model must be executed with tensors of exactly the same shape as those provided at compilation time; if the shapes differ, an error occurs (a short illustration follows the compilation cell below).

In the following section, we assume that we will receive an image shape of [1, 3, 224, 224] at inference time.

[ ]:
import os
import urllib.request
from PIL import Image

import torch
import torch_neuronx
from torchvision import models
from torchvision.transforms import functional


def get_image(batch_size=1, image_shape=(224, 224)):
    # Get an example input
    filename = "000000039769.jpg"
    if not os.path.exists(filename):
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        urllib.request.urlretrieve(url, filename)
    image = Image.open(filename).convert('RGB')
    image = functional.resize(image, image_shape)
    image = functional.to_tensor(image)
    image = torch.unsqueeze(image, 0)
    image = torch.repeat_interleave(image, batch_size, 0)
    return (image, )


# Create the model
model = models.resnet50(pretrained=True)
model.eval()

# Get an example input
image = get_image()

# Run inference on CPU
output_cpu = model(*image)

# Compile the model
model_neuron = torch_neuronx.trace(model, image)

# Save the TorchScript for inference deployment
filename = 'model.pt'
torch.jit.save(model_neuron, filename)

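Because the compiled model is a static executable, it only accepts inputs with the same shape as the example used at compilation time. The optional cell below is a small sketch (not part of the main tutorial flow) that illustrates this: passing a batch of 2 to a model compiled for a batch of 1 triggers an error. The exact error message depends on your torch-neuronx version.

[ ]:
# Optional: demonstrate the fixed-shape constraint of the compiled model.
# The model was traced with an input of shape [1, 3, 224, 224], so a batch of 2 should fail.
wrong_shape_input = torch.rand(2, 3, 224, 224)
try:
    model_neuron(wrong_shape_input)
except Exception as error:
    print(f"Inference with a mismatched input shape failed as expected: {error}")
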
Run inference and compare results#

In this section we load the compiled model, run inference on Neuron, and compare the CPU and Neuron outputs using the ImageNet classes.

[ ]:
import json

# Load the TorchScript compiled model
model_neuron = torch.jit.load(filename)

# Run inference using the Neuron model
output_neuron = model_neuron(*image)

# Compare the results
print(f"CPU tensor:    {output_cpu[0][0:10]}")
print(f"Neuron tensor: {output_neuron[0][0:10]}")

# Download and read the ImageNet classes
urllib.request.urlretrieve(
    "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json",
    "imagenet_class_index.json"
)
with open("imagenet_class_index.json", "r") as file:
    class_id = json.load(file)
    id2label = [class_id[str(i)][1] for i in range(len(class_id))]

# Lookup and print the top-5 labels
top5_cpu = output_cpu[0].sort()[1][-5:]
top5_neuron = output_neuron[0].sort()[1][-5:]
top5_labels_cpu = [id2label[idx] for idx in top5_cpu]
top5_labels_neuron = [id2label[idx] for idx in top5_neuron]
print(f"CPU top-5 labels:    {top5_labels_cpu}")
print(f"Neuron top-5 labels: {top5_labels_neuron}")

Benchmarking#

In this section we benchmark the performance of the ResNet model on Neuron. By default, a model compiled with torch_neuronx always executes on a single NeuronCore. When multiple models are loaded, the default behavior of the Neuron runtime is to distribute them evenly across all available NeuronCores, placing each model on the NeuronCore that currently has the fewest models loaded. In the following section, we will torch.jit.load multiple instances of the model, each of which should be placed on its own NeuronCore. It is not useful to load more copies of a model than there are NeuronCores on the instance, since an individual NeuronCore can only execute one model at a time.

To ensure that we are maximizing hardware utilization, we must run inferences using multiple threads in parallel. It is nearly always recommended to use some form of threading/multiprocessing together with model replication, since even the smallest Neuron EC2 instance has 2 NeuronCores available. An application with no threading can achieve at most 1 / num_neuron_cores hardware utilization, which becomes especially problematic on large instances.

One way to view the hardware utilization is by executing the neuron-top application in the terminal while the benchmark is executing. If the monitor shows >90% utilization on all NeuronCores, this is a good indication that the hardware is being utilized effectively.

In this example we load two models, which utilizes all NeuronCores (2) on a trn1.2xlarge or inf2.xlarge instance. Additional models can be loaded and run in parallel on larger Trn1 or Inf2 instance sizes to increase throughput.
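
Before sizing the benchmark, you can confirm how many Neuron devices and NeuronCores the instance exposes. One way to do this (assuming the aws-neuronx-tools system package is installed, as in the standard setup) is to run the neuron-ls utility from a terminal or directly from the notebook:

[ ]:
# List the Neuron devices and NeuronCores available on this instance
# (neuron-ls is provided by the aws-neuronx-tools system package).
!neuron-ls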

We define a benchmarking function that loads two optimized ResNet models onto two separate NeuronCores, runs multithreaded inference, and calculates the corresponding latency and throughput.

[ ]:
import time
import concurrent.futures
import numpy as np


def benchmark(filename, example, n_models=2, n_threads=2, batches_per_thread=1000):
    """
    Record performance statistics for a serialized model and its input example.

    Arguments:
        filename: The serialized torchscript model to load for benchmarking.
        example: An example model input.
        n_models: The number of models to load.
        n_threads: The number of simultaneous threads to execute inferences on.
        batches_per_thread: The number of example batches to run per thread.

    Returns:
        A dictionary of performance statistics.
    """

    # Load models
    models = [torch.jit.load(filename) for _ in range(n_models)]

    # Warmup
    for _ in range(8):
        for model in models:
            model(*example)

    latencies = []

    # Thread task
    def task(model):
        for _ in range(batches_per_thread):
            start = time.time()
            model(*example)
            finish = time.time()
            latencies.append((finish - start) * 1000)

    # Submit tasks
    begin = time.time()
    with concurrent.futures.ThreadPoolExecutor(max_workers=n_threads) as pool:
        for i in range(n_threads):
            pool.submit(task, models[i % len(models)])
    end = time.time()

    # Compute metrics
    boundaries = [50, 95, 99]
    percentiles = {}

    for boundary in boundaries:
        name = f'latency_p{boundary}'
        percentiles[name] = np.percentile(latencies, boundary)
    duration = end - begin
    # The batch size is the leading dimension of the first example input tensor
    batch_size = example[0].shape[0]
    inferences = len(latencies) * batch_size
    throughput = inferences / duration

    # Metrics
    metrics = {
        'filename': str(filename),
        'batch_size': batch_size,
        'batches': len(latencies),
        'inferences': inferences,
        'threads': n_threads,
        'models': n_models,
        'duration': duration,
        'throughput': throughput,
        **percentiles,
    }

    display(metrics)
    return metrics


def display(metrics):
    """
    Display the metrics produced by the `benchmark` function.

    Args:
        metrics: A dictionary of performance statistics.
    """
    pad = max(map(len, metrics)) + 1
    for key, value in metrics.items():

        parts = key.split('_')
        parts = list(map(str.title, parts))
        title = ' '.join(parts) + ":"

        if isinstance(value, float):
            value = f'{value:0.3f}'

        print(f'{title :<{pad}} {value}')


# Benchmark ResNet on Neuron
benchmark(filename, image)

Finding the optimal batch size#

Batch size has a direct impact on model performance. The NeuronCore architecture is optimized to maximize throughput with relatively small batch sizes. This means that a Neuron compiled model can outperform a GPU model, even when running single-digit batch sizes.

As a general best practice, we recommend optimizing your model’s throughput by compiling the model with a small batch size and gradually increasing it to find the peak throughput on Neuron. To minimize latency, using batch size = 1 will nearly always be optimal. This batch size configuration is typically used for on-demand inference applications. To maximize throughput, usually 1 < batch_size < 10 is optimal. A configuration which uses a larger batch size is generally ideal for batched on-demand inference or offline batch processing.

In the following section, we compile ResNet for multiple batch size inputs. We then run inference on each batch size and benchmark the performance. Notice that latency increases consistently as the batch size increases. Throughput increases as well, up until a certain point where the input size becomes too large to be efficient.

[ ]:
# Compile ResNet for different batch sizes
for batch_size in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:
    model = models.resnet50(pretrained=True)
    model.eval()
    example = get_image(batch_size=batch_size)
    model_neuron = torch_neuronx.trace(model, example)
    filename = f'model_batch_size_{batch_size}.pt'
    torch.jit.save(model_neuron, filename)
[ ]:
# Benchmark ResNet for different batch sizes
for batch_size in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:
    print('-'*50)
    example = get_image(batch_size=batch_size)
    filename = f'model_batch_size_{batch_size}.pt'
    benchmark(filename, example)
    print()
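
Since benchmark also returns its metrics dictionary, the loop above can be adapted to collect the results and report the batch size with the highest measured throughput. The cell below is a sketch of that idea; it re-runs the benchmark for each serialized model compiled earlier:

[ ]:
# Sketch: collect the metrics per batch size and report the peak-throughput configuration.
results = {}
for batch_size in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:
    example = get_image(batch_size=batch_size)
    results[batch_size] = benchmark(f'model_batch_size_{batch_size}.pt', example)

best_batch_size = max(results, key=lambda size: results[size]['throughput'])
print(f"Peak throughput at batch size {best_batch_size}: "
      f"{results[best_batch_size]['throughput']:.1f} inferences/second")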