This document is relevant for: Inf1
torch.neuron.DataParallel API
The torch.neuron.DataParallel() Python API implements data parallelism on ScriptModule models created by the PyTorch-Neuron trace Python API. This function is analogous to DataParallel in PyTorch.
The Data Parallel Inference on Torch Neuron application note provides an overview of how torch.neuron.DataParallel() can be used to improve the performance of inference workloads on Inferentia.
- torch.neuron.DataParallel(model, device_ids=None, dim=0)
Applies data parallelism by replicating the model on available NeuronCores and distributing data across the different NeuronCores for parallelized inference.
By default, DataParallel uses all available NeuronCores allocated to the current process for parallelism. If dim is not specified, DataParallel applies parallelism on dim=0.
DataParallel automatically enables dynamic batching on eligible models if dim=0. Dynamic batching can be disabled using torch.neuron.DataParallel.disable_dynamic_batching(). If dynamic batching is not enabled, the batch size at compilation time must equal the batch size at inference time divided by the number of NeuronCores being used. Specifically, the following must be true when dynamic batching is disabled: input.shape[dim] / len(device_ids) == compilation_input.shape[dim]. DataParallel will issue a warning if dynamic batching cannot be enabled.
DataParallel will try to load all of a model's NEFFs onto a single NeuronCore, but only if all of the NEFFs fit on a single NeuronCore. DataParallel does not currently support models that have been compiled with NeuronCore Pipeline.
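When dynamic batching is disabled, the shape rule above can be checked before the model is called. The following is a minimal pure-Python sketch. The helper names static_batch_ok and required_shape are hypothetical and are not part of torch-neuron. The sketch checks only the stated rule along dim and does not validate the other dimensions.

```python
from typing import List, Sequence


def static_batch_ok(input_shape: Sequence[int],
                    compile_shape: Sequence[int],
                    num_cores: int,
                    dim: int = 0) -> bool:
    """Hypothetical helper: does input_shape satisfy
    input.shape[dim] / len(device_ids) == compilation_input.shape[dim]?"""
    # Integer form of the rule: avoids float division and rejects
    # sizes that do not divide evenly across the NeuronCores.
    return input_shape[dim] == compile_shape[dim] * num_cores


def required_shape(compile_shape: Sequence[int],
                   num_cores: int,
                   dim: int = 0) -> List[int]:
    """Hypothetical helper: the inference-time shape implied by the rule."""
    shape = list(compile_shape)
    shape[dim] *= num_cores
    return shape


# Shapes taken from the examples on this page.
print(static_batch_ok([8, 3, 224, 224], [1, 3, 224, 224], num_cores=4))  # False
print(static_batch_ok([4, 3, 224, 224], [1, 3, 224, 224], num_cores=4))  # True
print(required_shape([1, 3, 8, 8], num_cores=4, dim=2))  # [1, 3, 32, 8]
```

The check is written as a multiplication rather than a division, so a batch that does not split evenly across the NeuronCores is rejected outright.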
torch.neuron.DataParallel() requires PyTorch >= 1.8.
Required Arguments
- Parameters:
model (ScriptModule) – Model created by the PyTorch-Neuron trace python API to be parallelized.
Optional Arguments
- Parameters:
device_ids (list) – List of int or 'nc:#' that specify the NeuronCores to use for parallelization (default: all NeuronCores). Refer to the device_ids note for a description of how device_ids indexing works.
dim (int) – Dimension along which the input tensor is scattered across NeuronCores (default: dim=0).
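The two accepted forms of a device_ids entry are interchangeable. The Specifying NeuronCores example notes that device_ids=['nc:0', 'nc:1', 'nc:2'] is equivalent to device_ids=[0, 1, 2]. The hypothetical normalize_device_ids below is for illustration only: it converts either form into zero-based integer indices. It is a sketch, not how torch-neuron parses the argument internally, and its rejection of negative indices is an assumption.

```python
from typing import List, Sequence, Union

DeviceId = Union[int, str]


def normalize_device_ids(device_ids: Sequence[DeviceId]) -> List[int]:
    """Hypothetical helper: map entries such as 0 or 'nc:0' to zero-based ints."""
    indices = []
    for dev in device_ids:
        if isinstance(dev, int):
            index = dev
        elif isinstance(dev, str) and dev.startswith("nc:"):
            # int() raises ValueError if the suffix is not a number
            index = int(dev[len("nc:"):])
        else:
            raise ValueError(f"unsupported device id: {dev!r}")
        if index < 0:  # assumption: negative indices are not meaningful
            raise ValueError(f"NeuronCore index must be >= 0, got {index}")
        indices.append(index)
    return indices


# 'nc:#' strings and plain ints describe the same NeuronCores.
print(normalize_device_ids(["nc:0", "nc:1", "nc:2"]))  # [0, 1, 2]
print(normalize_device_ids([0, 1, 2]))                 # [0, 1, 2]
```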
Attributes
- torch.neuron.DataParallel.disable_dynamic_batching()
Disables automatic dynamic batching on the DataParallel module. See Dynamic batching disabled for an example of how DataParallel can be used with dynamic batching disabled. Use as follows:
>>> model_parallel = torch.neuron.DataParallel(model_neuron)
>>> model_parallel.disable_dynamic_batching()
Note
device_ids uses per-process NeuronCore granularity and zero-based indexing. Per-process granularity means that each Python process “sees” its own view of the world: device_ids only “sees” the NeuronCores that are allocated to the current process. Zero-based indexing means that each Python process indexes its allocated NeuronCores starting at 0, regardless of the “global” index of the NeuronCores. Zero-based indexing makes it possible to redeploy exactly the same code unchanged in different processes. This behavior is analogous to the device_ids argument in the PyTorch DataParallel function.
As an example, assume DataParallel is run on an inf1.6xlarge, which contains four Inferentia chips, each of which contains four NeuronCores:
- If NEURON_RT_VISIBLE_CORES is not set, a single process can access all 16 NeuronCores. Thus, specifying device_ids=["nc:0"] corresponds to chip0:core0 and device_ids=["nc:14"] corresponds to chip3:core2.
- However, suppose two processes are launched, where process 1 has NEURON_RT_VISIBLE_CORES=0-6 and process 2 has NEURON_RT_VISIBLE_CORES=7-15. Then device_ids=["nc:14"] cannot be specified in either process. Instead, chip3:core2 can only be accessed in process 2, where it is specified with device_ids=["nc:7"]. Furthermore, device_ids=["nc:0"] corresponds to chip0:core0 in process 1 and to chip1:core3 in process 2.
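The index arithmetic in this example can be reproduced with a small sketch. It makes several assumptions: the inf1.6xlarge layout described above (16 NeuronCores, four per chip), global NeuronCore indices numbered chip by chip, and a NEURON_RT_VISIBLE_CORES value that is either unset, a single index, or one contiguous range such as "7-15". The functions visible_cores and local_to_chip_core are hypothetical and only model the mapping; they do not query the Neuron runtime.

```python
CORES_PER_CHIP = 4   # inf1.6xlarge: four NeuronCores per Inferentia chip
TOTAL_CORES = 16     # inf1.6xlarge: four chips x four NeuronCores


def visible_cores(spec=None, total_cores=TOTAL_CORES):
    """Global NeuronCore indices visible to a process (hypothetical helper).

    spec mimics a NEURON_RT_VISIBLE_CORES value; only a single index ("5")
    or one contiguous range ("7-15") is handled here (an assumption).
    None means the variable is unset, so every core is visible.
    """
    if spec is None:
        return list(range(total_cores))
    start, _, end = spec.partition("-")
    if not end:
        return [int(start)]
    return list(range(int(start), int(end) + 1))


def local_to_chip_core(local_index, spec=None, cores_per_chip=CORES_PER_CHIP):
    """Translate a zero-based device_ids index ("nc:<local_index>") to chipX:coreY."""
    cores = visible_cores(spec)
    if not 0 <= local_index < len(cores):
        raise IndexError(f"nc:{local_index} is not visible in this process")
    # Assumes global indices are numbered chip by chip: 0-3 on chip0, 4-7 on chip1, ...
    chip, core = divmod(cores[local_index], cores_per_chip)
    return f"chip{chip}:core{core}"


print(local_to_chip_core(14))           # chip3:core2 (variable unset)
print(local_to_chip_core(7, "7-15"))    # chip3:core2 (process 2)
print(local_to_chip_core(0, "0-6"))     # chip0:core0 (process 1)
print(local_to_chip_core(0, "7-15"))    # chip1:core3 (process 2)
```

Because each process offsets its local index into its own list of visible cores, the same device_ids=["nc:0"] resolves to a different physical NeuronCore in each process. That is what allows identical code to be redeployed unchanged.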
Examples
The following sections provide example usages of the torch.neuron.DataParallel() module.
Default usage
The default DataParallel use mode replicates the model on all available NeuronCores in the current process. The inputs are split on dim=0.
import torch
import torch_neuron
from torchvision import models
# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()
# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch.neuron.trace(model, image)
# Create the DataParallel module
model_parallel = torch.neuron.DataParallel(model_neuron)
# Create a batched input
batch_size = 5
image_batched = torch.rand([batch_size, 3, 224, 224])
# Run inference with a batched input
output = model_parallel(image_batched)
Specifying NeuronCores
The following example uses the device_ids argument to run DataParallel inference on the first three NeuronCores.
import torch
import torch_neuron
from torchvision import models
# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()
# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch.neuron.trace(model, image)
# Create the DataParallel module, run on the first three NeuronCores
# Equivalent to model_parallel = torch.neuron.DataParallel(model_neuron, device_ids=[0, 1, 2])
model_parallel = torch.neuron.DataParallel(model_neuron, device_ids=['nc:0', 'nc:1', 'nc:2'])
# Create a batched input
batch_size = 5
image_batched = torch.rand([batch_size, 3, 224, 224])
# Run inference with a batched input
output = model_parallel(image_batched)
DataParallel with dim != 0
In this example, we run DataParallel inference using four NeuronCores and dim = 2. Because dim != 0, dynamic batching is not enabled. Consequently, the inference-time batch size along dim (here, the size of dimension 2) must be four times the compile-time size of that dimension. DataParallel will generate a warning that dynamic batching is disabled because dim != 0.
import torch
import torch_neuron
# Create an example model
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x) + 1
model = Model()
model.eval()
# Compile with an example input
image = torch.rand([1, 3, 8, 8])
model_neuron = torch.neuron.trace(model, image)
# Create the DataParallel module using 4 NeuronCores and dim = 2
model_parallel = torch.neuron.DataParallel(model_neuron, device_ids=[0, 1, 2, 3], dim=2)
# Create a batched input
# Note that image_batched.shape[dim] / len(device_ids) == image.shape[dim]
batch_size = 4 * 8
image_batched = torch.rand([1, 3, batch_size, 8])
# Run inference with a batched input
output = model_parallel(image_batched)
Dynamic batching
In the following example, we use the torch.neuron.DataParallel() module to run inference using several different batch sizes without recompiling the Neuron model.
import torch
import torch_neuron
from torchvision import models
# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()
# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch.neuron.trace(model, image)
# Create the DataParallel module
model_parallel = torch.neuron.DataParallel(model_neuron)
# Create batched inputs and run inference on the same model
batch_sizes = [2, 3, 4, 5, 6]
for batch_size in batch_sizes:
    image_batched = torch.rand([batch_size, 3, 224, 224])
    # Run inference with a batched input
    output = model_parallel(image_batched)
Dynamic batching disabled
In the following example, we use torch.neuron.DataParallel.disable_dynamic_batching() to disable dynamic batching. We provide an example of a batch size that will not work when dynamic batching is disabled, as well as an example of a batch size that does work.
import torch
import torch_neuron
from torchvision import models
# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()
# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch.neuron.trace(model, image)
# Create the DataParallel module and use 4 NeuronCores
model_parallel = torch.neuron.DataParallel(model_neuron, device_ids=[0, 1, 2, 3], dim=0)
# Disable dynamic batching
model_parallel.disable_dynamic_batching()
# Create a batched input (this won't work)
batch_size = 8
image_batched = torch.rand([batch_size, 3, 224, 224])
# This will fail because dynamic batching is disabled and
# image_batched.shape[dim] / len(device_ids) != image.shape[dim]
# output = model_parallel(image_batched)
# Create a batched input (this will work)
batch_size = 4
image_batched = torch.rand([batch_size, 3, 224, 224])
# This will work because
# image_batched.shape[dim] / len(device_ids) == image.shape[dim]
output = model_parallel(image_batched)
Full tutorial with torch.neuron.DataParallel
For an end-to-end tutorial that uses DataParallel, see the PyTorch Resnet Tutorial.