This document is relevant for: Inf2, Trn1, Trn1n
Data Parallel Inference on torch_neuronx#
Introduction#
This guide introduces torch_neuronx.DataParallel(), a Python API that implements data parallelism on ScriptModule models created by the PyTorch NeuronX Tracing API for Inference.
The following sections explain how data parallelism can improve the performance of inference workloads on Inferentia, including how torch_neuronx.DataParallel() uses dynamic batching to run inference on variable input sizes. They also provide an overview of the torch_neuronx.DataParallel() module and a few example data parallel applications.
Data parallel inference#
Data Parallelism is a form of parallelization across multiple devices or cores, referred to as nodes. Each node contains the same model and parameters, but data is distributed across the different nodes. By distributing the data across multiple nodes, data parallelism reduces the total execution time of large batch size inputs compared to sequential execution. Data parallelism works best for smaller models in latency sensitive applications that have large batch size requirements.
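As a conceptual sketch only (not the Neuron implementation), the snippet below splits a batch into per-node chunks; in data parallel execution, each node would run an identical copy of the model on its chunk concurrently. The node count and tensor shapes are arbitrary.
import torch

# Conceptual sketch: split a batch into one chunk per node.
# Each node holds an identical copy of the model and processes its chunk.
num_nodes = 4                          # hypothetical number of devices/cores
batch = torch.rand([8, 3, 224, 224])   # batch dimension is dim 0

chunks = torch.chunk(batch, num_nodes, dim=0)
for i, chunk in enumerate(chunks):
    print(f"node {i} receives a chunk of shape {tuple(chunk.shape)}")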
torch_neuronx.DataParallel#
To fully leverage the Inferentia hardware, we want to use all available NeuronCores. The inf2.xlarge and inf2.8xlarge instances each have two NeuronCores, the inf2.24xlarge has 12 NeuronCores, and the inf2.48xlarge has 24 NeuronCores. For maximum performance on Inferentia hardware, we can use torch_neuronx.DataParallel() to utilize all available NeuronCores.
torch_neuronx.DataParallel() implements data parallelism at the module level by replicating the Neuron model on all available NeuronCores and distributing data across the different cores for parallelized inference. This function is analogous to DataParallel in PyTorch. torch_neuronx.DataParallel() requires PyTorch >= 1.8.
The following sections provide an overview of some of the features of torch_neuronx.DataParallel() that enable maximum performance on Inferentia.
NeuronCore selection#
By default, DataParallel will try to use all NeuronCores allocated to the current process to fully saturate the Inferentia hardware for maximum performance. It is more efficient to make the batch dimension divisible by the number of NeuronCores. This will ensure that NeuronCores are not left idle during parallel inference and the Inferentia hardware is fully utilized.
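As a small sketch of this guideline (the core count below is an assumption for illustration only), choosing an inference batch size that is a multiple of the number of NeuronCores in use ensures the input splits evenly across the cores:
import torch

# Assume two NeuronCores are in use (for example, an inf2.xlarge)
num_neuron_cores = 2

# A batch size that is a multiple of the core count splits evenly,
# so no NeuronCore is left idle during parallel inference
batch_size = 4 * num_neuron_cores
image_batched = torch.rand([batch_size, 3, 224, 224])
assert image_batched.shape[0] % num_neuron_cores == 0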
In some applications, it is advantageous to use a subset of the available NeuronCores for DataParallel inference. DataParallel has a device_ids argument that accepts a list of int or 'nc:#' strings specifying the NeuronCores to use for parallelization. See Specifying NeuronCores for an example of how to use the device_ids argument.
Batch dim#
DataParallel accepts a dim argument that denotes the batch dimension used to split the input data for distributed inference. If the dim argument is not specified, DataParallel splits the inputs on dim = 0 by default. For applications with a non-zero batch dim, the dim argument can be used to specify the inference-time input batch dimension. DataParallel with dim != 0 provides an example of data parallel inference on inputs with batch dim = 2.
Dynamic batching#
Batch size has a direct impact on model performance. The Inferentia chip is optimized to run with small batch sizes. This means that a Neuron compiled model can outperform a GPU model, even when running at single-digit batch sizes.
As a general best practice, we recommend optimizing your model’s throughput by compiling the model with a small batch size and gradually increasing it to find the peak throughput on Inferentia.
Dynamic batching is a feature that allows you to use tensor batch sizes that the Neuron model was not originally compiled against. This is necessary because the underlying Inferentia hardware always executes inferences with the batch size used during compilation. Fixed batch size execution allows tuning the input batch size for optimal performance. For example, batch size 1 may be best suited for an ultra-low latency on-demand inference application, while batch size > 1 can be used to maximize throughput for offline inferencing.
Dynamic batching is implemented by slicing large input tensors into chunks that match the batch size used during the torch_neuronx.trace() compilation call.
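The sketch below illustrates this slicing idea; it is not the actual DataParallel internals, and a plain Conv2d stands in for a traced Neuron model:
import torch

# Conceptual sketch of dynamic batching: slice a large input into chunks that
# match the compile-time batch size, run each chunk at that fixed batch size,
# then concatenate the results.
compile_batch_size = 1                  # batch size used when tracing the model
model = torch.nn.Conv2d(3, 3, 3)        # stand-in for a traced Neuron model
image_batched = torch.rand([5, 3, 8, 8])

chunks = torch.split(image_batched, compile_batch_size, dim=0)
output = torch.cat([model(chunk) for chunk in chunks], dim=0)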
The torch_neuronx.DataParallel() class automatically enables dynamic batching on eligible models. This allows us to run inference in applications that have inputs with a variable batch size without needing to recompile the model. See Dynamic batching for an example of how DataParallel can be used to run inference on inputs with a dynamic batch size without needing to recompile the model.
Dynamic batching using small batch sizes can result in sub-optimal throughput because it involves slicing tensors into chunks and iteratively sending data to the hardware. Using a larger batch size at compilation time can use the Inferentia hardware more efficiently and thereby maximize throughput. You can test the tradeoff between individual request latency and total throughput by fine-tuning the input batch size.
Dynamic batching in the DataParallel module can be disabled using the disable_dynamic_batching() function as follows:
>>> model_parallel = torch_neuronx.DataParallel(model_neuron)
>>> model_parallel.disable_dynamic_batching()
If dynamic batching is disabled, the compile-time batch size must be equal to the inference-time batch size divided by the number of NeuronCores. For example, a model compiled with batch size 1 and replicated on two NeuronCores must receive inputs with an inference-time batch size of 2. The DataParallel with dim != 0 and Dynamic batching disabled sections provide examples of running DataParallel inference with dynamic batching disabled.
Performance optimizations#
The DataParallel module has a num_workers attribute that can be used to specify the number of worker threads used for multithreaded inference. By default, num_workers = 2 * number of NeuronCores. This value can be fine-tuned to optimize DataParallel performance.
DataParallel also has a split_size attribute that dictates the size of the input chunks that are distributed to each NeuronCore. By default, split_size = max(1, input.shape[dim] // number of NeuronCores). This value can be modified to optimally match the inference input chunk size with the compile-time batch size.
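As a brief sketch of tuning these knobs (the values below are illustrative, and setting them by plain attribute assignment on the DataParallel instance is an assumption here, not taken from the API reference):
import torch
import torch_neuronx
from torchvision import models

# Trace the model as in the examples below
model = models.resnet50(pretrained=True)
model.eval()
image = torch.rand([1, 3, 224, 224])
model_neuron = torch_neuronx.trace(model, image)

# Illustrative tuning of the DataParallel performance attributes
model_parallel = torch_neuronx.DataParallel(model_neuron)
model_parallel.num_workers = 4    # override the default of 2 * number of NeuronCores
model_parallel.split_size = 1     # match the compile-time batch size of the traced model

output = model_parallel(torch.rand([8, 3, 224, 224]))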
Examples#
The following sections provide example usages of the torch_neuronx.DataParallel() module.
Default usage#
The default DataParallel usage replicates the model on all NeuronCores available in the current process and splits the inputs on dim = 0.
import torch
import torch_neuronx
from torchvision import models
# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()
# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch_neuronx.trace(model, image)
# Create the DataParallel module
model_parallel = torch_neuronx.DataParallel(model_neuron)
# Create a batched input
batch_size = 5
image_batched = torch.rand([batch_size, 3, 224, 224])
# Run inference with a batched input
output = model_parallel(image_batched)
Specifying NeuronCores#
The following example uses the device_ids argument to run DataParallel inference on the first two NeuronCores.
import torch
import torch_neuronx
from torchvision import models
# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()
# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch_neuronx.trace(model, image)
# Create the DataParallel module, run on the first two NeuronCores
# Equivalent to model_parallel = torch_neuronx.DataParallel(model_neuron, device_ids=[0, 1])
model_parallel = torch_neuronx.DataParallel(model_neuron, device_ids=['nc:0', 'nc:1'])
# Create a batched input
batch_size = 5
image_batched = torch.rand([batch_size, 3, 224, 224])
# Run inference with a batched input
output = model_parallel(image_batched)
DataParallel with dim != 0#
In this example we run DataParallel inference using two NeuronCores and dim = 2. Because dim != 0, dynamic batching is not enabled. Consequently, the DataParallel inference-time batch size must be two times the compile-time batch size.
import torch
import torch_neuronx
# Create an example model
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x) + 1
model = Model()
model.eval()
# Compile with an example input
image = torch.rand([1, 3, 8, 8])
model_neuron = torch_neuronx.trace(model, image)
# Create the DataParallel module using 2 NeuronCores and dim = 2
model_parallel = torch_neuronx.DataParallel(model_neuron, device_ids=[0, 1], dim=2)
# Create a batched input
# Note that image_batched.shape[dim] / len(device_ids) == image.shape[dim]
batch_size = 2 * 8
image_batched = torch.rand([1, 3, batch_size, 8])
# Run inference with a batched input
output = model_parallel(image_batched)
Dynamic batching#
In the following example, we use the torch_neuronx.DataParallel() module to run inference using several different batch sizes without recompiling the Neuron model.
import torch
import torch_neuronx
from torchvision import models
# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()
# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch_neuronx.trace(model, image)
# Create the DataParallel module
model_parallel = torch_neuronx.DataParallel(model_neuron)
# Create batched inputs and run inference on the same model
batch_sizes = [2, 3, 4, 5, 6]
for batch_size in batch_sizes:
    image_batched = torch.rand([batch_size, 3, 224, 224])

    # Run inference with a batched input
    output = model_parallel(image_batched)
Dynamic batching disabled#
In the following example, we use torch_neuronx.DataParallel.disable_dynamic_batching() to disable dynamic batching. We provide an example of a batch size that will not work when dynamic batching is disabled, as well as an example of a batch size that does work when dynamic batching is disabled.
import torch
import torch_neuronx
from torchvision import models
# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()
# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch_neuronx.trace(model, image)
# Create the DataParallel module and use 2 NeuronCores
model_parallel = torch_neuronx.DataParallel(model_neuron, device_ids=[0, 1], dim=0)
# Disable dynamic batching
model_parallel.disable_dynamic_batching()
# Create a batched input (this won't work)
batch_size = 4
image_batched = torch.rand([batch_size, 3, 224, 224])
# This will fail because dynamic batching is disabled and
# image_batched.shape[dim] / len(device_ids) != image.shape[dim]
# output = model_parallel(image_batched)
# Create a batched input (this will work)
batch_size = 2
image_batched = torch.rand([batch_size, 3, 224, 224])
# This will work because
# image_batched.shape[dim] / len(device_ids) == image.shape[dim]
output = model_parallel(image_batched)