This document is relevant for: Inf2, Trn1, Trn1n

Data Parallel Inference on torch_neuronx

Introduction

This guide introduces torch_neuronx.DataParallel(), a Python API that implements data parallelism on ScriptModule models created by the PyTorch NeuronX Tracing API for Inference. The following sections explain how data parallelism can improve the performance of inference workloads on Inferentia, including how torch_neuronx.DataParallel() uses dynamic batching to run inference on inputs with variable batch sizes. The guide also gives an overview of the torch_neuronx.DataParallel() module and provides a few example data parallel applications.

Data parallel inference

Data parallelism is a form of parallelization across multiple devices or cores, referred to as nodes. Each node holds a copy of the same model and parameters, and the input data is distributed across the nodes. Because the nodes process different portions of the data at the same time, data parallelism reduces the total execution time of large batch size inputs compared to sequential execution. Data parallelism works best for smaller models in latency-sensitive applications that have large batch size requirements.
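The idea can be sketched in plain Python with no Neuron dependency. In this minimal sketch, `model` is a hypothetical stand-in function and each node is simulated by a separate loop iteration (real nodes would run concurrently). It only shows that splitting a batch into shards, running the same model on every shard, and concatenating the results in order gives the same output as one sequential pass.

```python
from typing import Callable, List


def model(sample: float) -> float:
    # Hypothetical stand-in for a compiled model: identical "weights" on every node.
    return 2.0 * sample + 1.0


def shard(batch: List[float], num_nodes: int) -> List[List[float]]:
    # Split the batch into contiguous shards, one per node (sizes differ by at most 1).
    base, extra = divmod(len(batch), num_nodes)
    shards, start = [], 0
    for i in range(num_nodes):
        size = base + (1 if i < extra else 0)
        shards.append(batch[start:start + size])
        start += size
    return shards


def data_parallel(fn: Callable[[float], float], batch: List[float], num_nodes: int) -> List[float]:
    # Each node runs the same model on its own shard; outputs are concatenated in order.
    outputs: List[float] = []
    for node_shard in shard(batch, num_nodes):
        outputs.extend(fn(x) for x in node_shard)
    return outputs


batch = [float(i) for i in range(5)]
print(shard(batch, 2))  # [[0.0, 1.0, 2.0], [3.0, 4.0]]
print(data_parallel(model, batch, 2) == [model(x) for x in batch])  # True
```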

torch_neuronx.DataParallel

To fully leverage the Inferentia hardware, we want to use all available NeuronCores. The inf2.xlarge and inf2.8xlarge instances have two NeuronCores each, the inf2.24xlarge has 12 NeuronCores, and the inf2.48xlarge has 24 NeuronCores. For maximum performance on Inferentia hardware, we can use torch_neuronx.DataParallel() to utilize all of them.
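As a small worked example, the NeuronCore counts above can be kept in a lookup table and used to see how a batch spreads over an instance. The helper `split_across_cores` is hypothetical and not part of torch_neuronx; it simply assumes an even split over every core.

```python
from typing import Tuple

# NeuronCore counts per Inf2 instance size, taken from the paragraph above.
NEURONCORES_PER_INSTANCE = {
    "inf2.xlarge": 2,
    "inf2.8xlarge": 2,
    "inf2.24xlarge": 12,
    "inf2.48xlarge": 24,
}


def split_across_cores(instance_type: str, batch_size: int) -> Tuple[int, int]:
    # Hypothetical helper: (samples per NeuronCore, leftover samples) when a
    # batch is spread evenly over every NeuronCore of the instance.
    return divmod(batch_size, NEURONCORES_PER_INSTANCE[instance_type])


print(split_across_cores("inf2.xlarge", 8))     # (4, 0)
print(split_across_cores("inf2.24xlarge", 30))  # (2, 6)
```

A non-zero second value means some NeuronCores receive one more sample than others.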

torch_neuronx.DataParallel() implements data parallelism at the module level by replicating the Neuron model on all available NeuronCores and distributing data across the different cores for parallelized inference. It is analogous to torch.nn.DataParallel in PyTorch. torch_neuronx.DataParallel() requires PyTorch >= 1.8.

The following sections provide an overview of some of the features of torch_neuronx.DataParallel() that enable maximum performance on Inferentia.

NeuronCore selection

By default, DataParallel tries to use all NeuronCores allocated to the current process to fully saturate the Inferentia hardware for maximum performance. For best efficiency, make the size of the batch dimension divisible by the number of NeuronCores. This ensures that no NeuronCore is left idle during parallel inference and the Inferentia hardware is fully utilized.
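The effect of divisibility can be estimated with a simplified model. This sketch assumes samples are spread as evenly as possible and every sample takes the same time, so wall-clock time is set by the busiest core; it does not reproduce DataParallel's actual scheduling.

```python
import math


def core_utilization(batch_size: int, num_cores: int) -> float:
    # Simplified model: the busiest core determines total time.
    busiest_core_load = math.ceil(batch_size / num_cores)
    # Fraction of total core-time spent on useful work (1.0 means no idle cores).
    return batch_size / (num_cores * busiest_core_load)


print(core_utilization(4, 2))  # 1.0
print(core_utilization(5, 2))  # 0.8333333333333334
print(core_utilization(6, 4))  # 0.75
```

With 5 samples on 2 cores, one core finishes its work and waits while the other processes its extra sample.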

In some applications, it is advantageous to use only a subset of the available NeuronCores for DataParallel inference. DataParallel has a device_ids argument that accepts a list of ints or 'nc:#' strings specifying the NeuronCores to use for parallelization. See Specifying NeuronCores for an example of how to use the device_ids argument.
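The two accepted forms are interchangeable: `0` and `'nc:0'` name the same NeuronCore. The hypothetical helper below only illustrates that equivalence by normalizing both forms to integers; it is an assumption for illustration, not how torch_neuronx parses device_ids internally.

```python
from typing import List, Union


def normalize_device_ids(device_ids: List[Union[int, str]]) -> List[int]:
    # Hypothetical helper: accept ints (0) or 'nc:#' strings ('nc:0') and return ints.
    normalized = []
    for dev in device_ids:
        if isinstance(dev, int):
            normalized.append(dev)
        elif isinstance(dev, str) and dev.startswith("nc:"):
            normalized.append(int(dev[len("nc:"):]))
        else:
            raise ValueError(f"unsupported device id: {dev!r}")
    return normalized


print(normalize_device_ids(["nc:0", "nc:1"]))  # [0, 1]
print(normalize_device_ids([0, 1]))            # [0, 1]
```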

Batch dim

DataParallel accepts a dim argument that denotes the batch dimension used to split the input data for distributed inference. If the dim argument is not specified, DataParallel splits the inputs on dim = 0. For applications with a non-zero batch dim, the dim argument can be used to specify the inference-time input batch dimension. See DataParallel with dim != 0 for an example of data parallel inference on inputs with batch dim = 2.
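To see what splitting on a non-zero dim means for tensor shapes, the shape-only sketch below cuts a shape into pieces along `dim`. It assumes chunking behaves like torch.split with an integer split size (equal chunks, last one possibly smaller); it does not call DataParallel, and the helper name is hypothetical.

```python
from typing import List, Sequence, Tuple


def chunk_shapes(shape: Sequence[int], dim: int, split_size: int) -> List[Tuple[int, ...]]:
    # Shapes of the chunks obtained by cutting `shape` into pieces of
    # `split_size` along `dim`; only that dimension changes.
    shapes = []
    for start in range(0, shape[dim], split_size):
        chunk = list(shape)
        chunk[dim] = min(split_size, shape[dim] - start)
        shapes.append(tuple(chunk))
    return shapes


# Batch dim 2: a [1, 3, 16, 8] input split for two NeuronCores.
print(chunk_shapes([1, 3, 16, 8], dim=2, split_size=8))  # [(1, 3, 8, 8), (1, 3, 8, 8)]
```

Each chunk has shape [1, 3, 8, 8], which matches an example input traced with batch size 8 along dim 2.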

Dynamic batching

Batch size has a direct impact on model performance. The Inferentia chip is optimized to run with small batch sizes, so a Neuron-compiled model can outperform a GPU model even when running single-digit batch sizes.

As a general best practice, we recommend optimizing your model’s throughput by compiling the model with a small batch size and gradually increasing it to find the peak throughput on Inferentia.

Dynamic batching is a feature that allows you to use tensor batch sizes that the Neuron model was not originally compiled against. This is necessary because the underlying Inferentia hardware will always execute inferences with the batch size used during compilation. Fixed batch size execution allows tuning the input batch size for optimal performance. For example, batch size 1 may be best suited for an ultra-low latency on-demand inference application, while batch size > 1 can be used to maximize throughput for offline inferencing. Dynamic batching is implemented by slicing large input tensors into chunks that match the batch size used during the torch_neuronx.trace() compilation call.
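The slicing step can be sketched without Neuron hardware. Here `fixed_batch_model` is a hypothetical stand-in that, like a compiled model, only accepts exactly its compile-time batch size. How a short final chunk is handled is an assumption of this sketch: it pads by repeating the last item and discards the padded outputs, which may differ from what torch_neuronx actually does.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def run_dynamic_batch(
    compiled_fn: Callable[[List[T]], List[R]],
    inputs: List[T],
    compiled_batch_size: int,
) -> List[R]:
    # Slice the input into chunks that match the compile-time batch size.
    outputs: List[R] = []
    for start in range(0, len(inputs), compiled_batch_size):
        chunk = inputs[start:start + compiled_batch_size]
        valid = len(chunk)
        # Assumption: pad a short final chunk by repeating its last item,
        # then keep only the outputs for the real (non-padded) items.
        chunk = chunk + [chunk[-1]] * (compiled_batch_size - valid)
        outputs.extend(compiled_fn(chunk)[:valid])
    return outputs


def fixed_batch_model(batch: List[int]) -> List[int]:
    # Stand-in for a model compiled with batch size 4: other sizes are rejected.
    assert len(batch) == 4, "compiled batch size is 4"
    return [x * x for x in batch]


print(run_dynamic_batch(fixed_batch_model, list(range(10)), 4))
# [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```

A batch of 10 triggers three executions of the batch-size-4 model, which is why small compile-time batch sizes can mean many round trips to the hardware.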

The torch_neuronx.DataParallel() class automatically enables dynamic batching on eligible models. This allows us to run inference in applications that have inputs with a variable batch size without needing to recompile the model. See Dynamic batching for an example of how DataParallel can be used to run inference on inputs with a dynamic batch size without needing to recompile the model.

Dynamic batching using small batch sizes can result in sub-optimal throughput because it involves slicing tensors into chunks and iteratively sending data to the hardware. Using a larger batch size at compilation time can use the Inferentia hardware more efficiently in order to maximize throughput. You can test the tradeoff between individual request latency and total throughput by fine-tuning the input batch size.

Dynamic batching in the DataParallel module can be disabled by calling its disable_dynamic_batching() method:

>>> model_parallel = torch_neuronx.DataParallel(model_neuron)
>>> model_parallel.disable_dynamic_batching()

If dynamic batching is disabled, the compile-time batch size must equal the inference-time batch size divided by the number of NeuronCores. The DataParallel with dim != 0 and Dynamic batching disabled examples show DataParallel inference with dynamic batching disabled.
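The rule can be written as a one-line check. The helper `static_batch_ok` is hypothetical; it compares with multiplication instead of division to avoid fractional results, and the sample values come from the examples in this guide.

```python
def static_batch_ok(compile_batch_size: int, inference_batch_size: int, num_cores: int) -> bool:
    # With dynamic batching disabled, every NeuronCore must receive exactly the
    # compile-time batch size, so the inference batch must split evenly.
    return inference_batch_size == compile_batch_size * num_cores


print(static_batch_ok(1, 4, 2))   # False: 4 / 2 = 2, not 1
print(static_batch_ok(1, 2, 2))   # True:  2 / 2 = 1
print(static_batch_ok(8, 16, 2))  # True:  the dim = 2 example (16 / 2 = 8)
```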

Performance optimizations

The DataParallel module has a num_workers attribute that specifies the number of worker threads used for multithreaded inference. By default, num_workers = 2 * number of NeuronCores. This value can be fine-tuned to optimize DataParallel performance.
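As a rough model of what a pool of worker threads does, the sketch below dispatches input chunks to a standard-library ThreadPoolExecutor sized with the default described above. It is an illustration under that assumption, not torch_neuronx's internal implementation; on the real module the value to tune is the num_workers attribute.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


def default_num_workers(num_cores: int) -> int:
    # Default described above: two worker threads per NeuronCore.
    return 2 * num_cores


def run_chunks(fn: Callable[[List[int]], List[int]], chunks: List[List[int]], num_workers: int) -> List[int]:
    # executor.map returns results in input order, so chunk outputs can be
    # concatenated back into the original batch order.
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        results = list(executor.map(fn, chunks))
    outputs: List[int] = []
    for chunk_out in results:
        outputs.extend(chunk_out)
    return outputs


workers = default_num_workers(2)
print(workers)  # 4
print(run_chunks(lambda c: [x + 100 for x in c], [[0, 1], [2, 3], [4]], workers))
# [100, 101, 102, 103, 104]
```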

DataParallel has a split_size attribute that sets the size of the input chunks distributed to each NeuronCore. By default, split_size = max(1, input.shape[dim] // number of NeuronCores). This value can be changed so that the inference-time chunk size matches the compile-time batch size.
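The default formula is easy to evaluate by hand. The sketch below computes it and, assuming chunks are cut with a fixed size and a smaller final chunk (as torch.split does), lists the resulting chunk sizes. Both helper names are hypothetical.

```python
from typing import List


def default_split_size(batch_size: int, num_cores: int) -> int:
    # Default described above: max(1, input.shape[dim] // number of NeuronCores).
    return max(1, batch_size // num_cores)


def chunk_sizes(batch_size: int, split_size: int) -> List[int]:
    # Sizes of the chunks produced by cutting the batch into pieces of split_size.
    return [min(split_size, batch_size - start) for start in range(0, batch_size, split_size)]


for batch_size in (1, 4, 5):
    split = default_split_size(batch_size, num_cores=2)
    print(batch_size, split, chunk_sizes(batch_size, split))
# 1 1 [1]
# 4 2 [2, 2]
# 5 2 [2, 2, 1]
```

With a model compiled at batch size 1, none of these default chunk sizes except the first matches the compile-time batch size, which is the kind of mismatch tuning split_size is meant to address.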

Examples

The following sections provide example usages of the torch_neuronx.DataParallel() module.

Default usage

In its default mode, DataParallel replicates the model on all available NeuronCores in the current process and splits the inputs on dim=0.

import torch
import torch_neuronx
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch_neuronx.trace(model, image)

# Create the DataParallel module
model_parallel = torch_neuronx.DataParallel(model_neuron)

# Create a batched input
batch_size = 5
image_batched = torch.rand([batch_size, 3, 224, 224])

# Run inference with a batched input
output = model_parallel(image_batched)

Specifying NeuronCores

The following example uses the device_ids argument to use the first two NeuronCores for DataParallel inference.

import torch
import torch_neuronx
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch_neuronx.trace(model, image)

# Create the DataParallel module, run on the first two NeuronCores
# Equivalent to model_parallel = torch_neuronx.DataParallel(model_neuron, device_ids=[0, 1])
model_parallel = torch_neuronx.DataParallel(model_neuron, device_ids=['nc:0', 'nc:1'])

# Create a batched input
batch_size = 5
image_batched = torch.rand([batch_size, 3, 224, 224])

# Run inference with a batched input
output = model_parallel(image_batched)

DataParallel with dim != 0

In this example we run DataParallel inference using two NeuronCores and dim = 2. Because dim != 0, dynamic batching is not enabled. Consequently, the inference-time size of the batch dimension (dim 2) must be two times its compile-time size.

import torch
import torch_neuronx

# Create an example model
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x) + 1

model = Model()
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 8, 8])
model_neuron = torch_neuronx.trace(model, image)

# Create the DataParallel module using 2 NeuronCores and dim = 2
model_parallel = torch_neuronx.DataParallel(model_neuron, device_ids=[0, 1], dim=2)

# Create a batched input
# Note that image_batched.shape[dim] / len(device_ids) == image.shape[dim]
batch_size = 2 * 8
image_batched = torch.rand([1, 3, batch_size, 8])

# Run inference with a batched input
output = model_parallel(image_batched)

Dynamic batching

In the following example, we use the torch_neuronx.DataParallel() module to run inference using several different batch sizes without recompiling the Neuron model.

import torch
import torch_neuronx
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch_neuronx.trace(model, image)

# Create the DataParallel module
model_parallel = torch_neuronx.DataParallel(model_neuron)

# Create batched inputs and run inference on the same model
batch_sizes = [2, 3, 4, 5, 6]
for batch_size in batch_sizes:
    image_batched = torch.rand([batch_size, 3, 224, 224])

    # Run inference with a batched input
    output = model_parallel(image_batched)

Dynamic batching disabled

In the following example, we use torch_neuronx.DataParallel.disable_dynamic_batching() to disable dynamic batching. The example shows one batch size that fails and one that works when dynamic batching is disabled.

import torch
import torch_neuronx
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch_neuronx.trace(model, image)

# Create the DataParallel module and use 2 NeuronCores
model_parallel = torch_neuronx.DataParallel(model_neuron, device_ids=[0, 1], dim=0)

# Disable dynamic batching
model_parallel.disable_dynamic_batching()

# Create a batched input (this won't work)
batch_size = 4
image_batched = torch.rand([batch_size, 3, 224, 224])

# This will fail because dynamic batching is disabled and
# image_batched.shape[dim] / len(device_ids) != image.shape[dim]
# output = model_parallel(image_batched)

# Create a batched input (this will work)
batch_size = 2
image_batched = torch.rand([batch_size, 3, 224, 224])

# This will work because
# image_batched.shape[dim] / len(device_ids) == image.shape[dim]
output = model_parallel(image_batched)
