This document is relevant for: Inf2, Trn1, Trn1n

PyTorch NeuronX DataParallel API

The torch_neuronx.DataParallel() Python API implements data parallelism on ScriptModule models created by the PyTorch NeuronX Tracing API for Inference. It is analogous to torch.nn.DataParallel in PyTorch. The Data Parallel Inference on torch_neuronx application note provides an overview of how torch_neuronx.DataParallel() can be used to improve the performance of inference workloads on Inferentia.

torch_neuronx.DataParallel(model, device_ids=None, dim=0, set_dynamic_batching=True)

Applies data parallelism by replicating the model on available NeuronCores and distributing data across the different NeuronCores for parallelized inference.

By default, DataParallel uses all NeuronCores allocated to the current process. If dim is not specified, DataParallel splits the input along dim=0.

When dim=0, DataParallel automatically enables dynamic batching on eligible models. Dynamic batching can be disabled by calling torch_neuronx.DataParallel.disable_dynamic_batching() or by passing set_dynamic_batching=False when creating the DataParallel object. Without dynamic batching, the compile-time batch size must equal the inference-time batch size divided by the number of NeuronCores in use. Specifically, the following must hold when dynamic batching is disabled: input.shape[dim] / len(device_ids) == compilation_input.shape[dim].
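The shape rule for disabled dynamic batching can be checked before calling the model. The sketch below uses a hypothetical helper, static_batch_ok, that is not part of torch_neuronx; it works on plain shape tuples and compares with integer multiplication so that a non-integer quotient is never mistaken for a match.

```python
def static_batch_ok(inference_shape, compile_shape, num_cores, dim=0):
    """Hypothetical helper (not a torch_neuronx API).

    Checks the rule that must hold when dynamic batching is disabled:
        inference_shape[dim] / num_cores == compile_shape[dim]
    """
    if num_cores <= 0:
        raise ValueError("num_cores must be positive")
    # Equivalent to the division form, but avoids floating-point comparison.
    return inference_shape[dim] == compile_shape[dim] * num_cores


compile_shape = (1, 3, 224, 224)  # shape used when tracing
print(static_batch_ok((2, 3, 224, 224), compile_shape, num_cores=2))  # True
print(static_batch_ok((4, 3, 224, 224), compile_shape, num_cores=2))  # False
```

The same check applies to other dimensions: for a model traced with shape (1, 3, 8, 8) and run on two NeuronCores with dim=2, the inference input needs size 16 along dim 2.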

torch_neuronx.DataParallel() requires PyTorch >= 1.8.

Required Arguments

Parameters

model (ScriptModule) – Model created by the PyTorch NeuronX Tracing API for Inference to be parallelized.

Optional Arguments

Parameters
  • device_ids (list) – List of int or 'nc:#' that specify the NeuronCores to use for parallelization (default: all NeuronCores). Refer to the device_ids note for a description of how device_ids indexing works.

  • dim (int) – Dimension along which the input tensor is scattered across NeuronCores (default dim=0).

  • set_dynamic_batching (bool) – Whether to enable dynamic batching.
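Because device_ids accepts either integers or 'nc:#' strings, it can help to see how the two forms relate. The following is a minimal sketch with a hypothetical helper, normalize_device_ids, that maps both forms to zero-based, per-process indices and treats None as "all NeuronCores". Rejecting out-of-range indices with ValueError is an assumption of this sketch; the library's own error behavior may differ.

```python
def normalize_device_ids(device_ids, num_visible_cores):
    """Hypothetical helper (not a torch_neuronx API).

    Converts a device_ids list of ints or 'nc:#' strings into zero-based,
    per-process NeuronCore indices. None selects every visible NeuronCore.
    """
    if device_ids is None:
        return list(range(num_visible_cores))
    indices = []
    for dev in device_ids:
        if isinstance(dev, str):
            prefix, sep, num = dev.partition(":")
            if prefix != "nc" or not sep or not num.isdigit():
                raise ValueError(f"expected 'nc:#', got {dev!r}")
            idx = int(num)
        else:
            idx = int(dev)
        # Indices are relative to the NeuronCores this process can see.
        if not 0 <= idx < num_visible_cores:
            raise ValueError(f"NeuronCore {idx} is not visible to this process")
        indices.append(idx)
    return indices


print(normalize_device_ids(["nc:0", "nc:1"], 24))  # [0, 1]
print(normalize_device_ids([0, 1], 24))            # [0, 1]
```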

Attributes

Parameters
  • num_workers (int) – Number of worker threads used for multithreaded inference (default: 2 * number of NeuronCores).

  • split_size (int) – Size of the input chunks (default: max(1, input.shape[dim] // number of NeuronCores)).
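The documented defaults for num_workers and split_size can be computed directly. The sketch below reproduces those two formulas. The chunking function, chunk_sizes, is hypothetical: it assumes the input is cut into consecutive chunks of split_size along dim, with a smaller final chunk (the convention torch.split follows). This is an illustration, not a statement of how torch_neuronx splits inputs internally.

```python
def default_num_workers(num_cores):
    # Documented default: 2 * number of NeuronCores.
    return 2 * num_cores


def default_split_size(batch_len, num_cores):
    # Documented default: max(1, input.shape[dim] // number of NeuronCores).
    return max(1, batch_len // num_cores)


def chunk_sizes(batch_len, split_size):
    # Assumption: consecutive chunks of split_size; the last may be smaller.
    if split_size < 1:
        raise ValueError("split_size must be at least 1")
    sizes = []
    remaining = batch_len
    while remaining > 0:
        sizes.append(min(split_size, remaining))
        remaining -= sizes[-1]
    return sizes


# A batch of 5 on 2 NeuronCores
s = default_split_size(5, 2)
print(s, chunk_sizes(5, s))    # 2 [2, 2, 1]
print(default_num_workers(2))  # 4
```

Under this assumption, a batch of 5 on two NeuronCores produces three chunks, more than there are NeuronCores, and none of them match a compile-time batch size of 1. Mismatches like this are why dynamic batching matters when dim=0.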

torch_neuronx.DataParallel.disable_dynamic_batching()

Disables automatic dynamic batching on the DataParallel module. See Dynamic batching disabled for an example of how DataParallel can be used with dynamic batching disabled. Use as follows:

>>> model_parallel = torch_neuronx.DataParallel(model_neuron)
>>> model_parallel.disable_dynamic_batching()

Note

device_ids uses per-process NeuronCore granularity and zero-based indexing. Per-process granularity means that each Python process “sees” its own view of the world: device_ids only “sees” the NeuronCores that are allocated to the current process. Zero-based indexing means that each Python process indexes its allocated NeuronCores starting at 0, regardless of the “global” index of those NeuronCores. Zero-based indexing makes it possible to redeploy the exact same code unchanged in different processes. This behavior is analogous to the device_ids argument of the PyTorch DataParallel module.

As an example, assume DataParallel is run on an inf2.48xlarge, which contains 12 Inferentia2 chips, each with two NeuronCores:

  • If NEURON_RT_VISIBLE_CORES is not set, a single process can access all 24 NeuronCores. Thus specifying device_ids=["nc:0"] will correspond to chip0:core0 and device_ids=["nc:13"] will correspond to chip6:core1.

  • However, suppose two processes are launched: process 1 with NEURON_RT_VISIBLE_CORES=0-11 and process 2 with NEURON_RT_VISIBLE_CORES=12-23. Then device_ids=["nc:13"] cannot be specified in either process. chip6:core1 is accessible only from process 2, where it is specified with device_ids=["nc:1"]. Likewise, device_ids=["nc:0"] corresponds to chip0:core0 in process 1 and to chip6:core0 in process 2.
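The index arithmetic in the inf2.48xlarge example can be reproduced in a few lines. This is a minimal sketch built on two hypothetical helpers. parse_visible_cores handles only a single index or a single start-end range of NEURON_RT_VISIBLE_CORES; the real variable may accept other forms. local_to_chip_core assumes two NeuronCores per chip and consecutive global numbering, which matches the mapping used in the example above.

```python
def parse_visible_cores(spec):
    """Hypothetical helper: parse a single index ('5') or a single range
    ('12-23') into a list of global NeuronCore indices."""
    if "-" in spec:
        start, end = (int(p) for p in spec.split("-"))
        return list(range(start, end + 1))
    return [int(spec)]


def local_to_chip_core(local_idx, visible_cores, cores_per_chip=2):
    """Hypothetical helper: map a zero-based, per-process NeuronCore index
    to a (chip, core) pair, assuming consecutive global numbering."""
    if not 0 <= local_idx < len(visible_cores):
        raise IndexError(f"nc:{local_idx} is not visible to this process")
    global_idx = visible_cores[local_idx]
    return divmod(global_idx, cores_per_chip)


all_cores = list(range(24))            # NEURON_RT_VISIBLE_CORES not set
process_2 = parse_visible_cores("12-23")

print(local_to_chip_core(13, all_cores))  # (6, 1) -> chip6:core1
print(local_to_chip_core(1, process_2))   # (6, 1) -> chip6:core1
print(local_to_chip_core(0, process_2))   # (6, 0) -> chip6:core0
```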

Examples

The following sections provide example usages of the torch_neuronx.DataParallel() module.

Default usage

By default, DataParallel replicates the model on all available NeuronCores in the current process and splits the inputs on dim=0.

import torch
import torch_neuronx
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch_neuronx.trace(model, image)

# Create the DataParallel module
model_parallel = torch_neuronx.DataParallel(model_neuron)

# Create a batched input
batch_size = 5
image_batched = torch.rand([batch_size, 3, 224, 224])

# Run inference with a batched input
output = model_parallel(image_batched)

Specifying NeuronCores

The following example uses the device_ids argument to run DataParallel inference on the first two NeuronCores.

import torch
import torch_neuronx
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch_neuronx.trace(model, image)

# Create the DataParallel module, run on the first two NeuronCores
# Equivalent to model_parallel = torch_neuronx.DataParallel(model_neuron, device_ids=[0, 1])
model_parallel = torch_neuronx.DataParallel(model_neuron, device_ids=['nc:0', 'nc:1'])

# Create a batched input
batch_size = 5
image_batched = torch.rand([batch_size, 3, 224, 224])

# Run inference with a batched input
output = model_parallel(image_batched)

DataParallel with dim != 0

In this example, we run DataParallel inference on two NeuronCores with dim=2. Because dim != 0, dynamic batching is not enabled. Consequently, at inference time the input's size along dim 2 must be twice its size at compile time (the compile-time size multiplied by the number of NeuronCores).

import torch
import torch_neuronx

# Create an example model
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x) + 1

model = Model()
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 8, 8])
model_neuron = torch_neuronx.trace(model, image)

# Create the DataParallel module using 2 NeuronCores and dim = 2
model_parallel = torch_neuronx.DataParallel(model_neuron, device_ids=[0, 1], dim=2)

# Create a batched input
# Note that image_batched.shape[dim] / len(device_ids) == image.shape[dim]
batch_size = 2 * 8
image_batched = torch.rand([1, 3, batch_size, 8])

# Run inference with a batched input
output = model_parallel(image_batched)
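The shape arithmetic of this example can be worked through without Neuron hardware. The sketch below assumes that each NeuronCore receives an equal slice along dim 2 and that the per-core outputs are concatenated back along the same dim. It uses the standard Conv2d output-size formula for the 3x3 convolution without padding. It also shows that splitting along a spatial dimension is not equivalent to running the model on the whole input, because convolution windows that straddle the split boundary are never computed.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    # Standard Conv2d output-size formula (dilation=1).
    return (size + 2 * padding - kernel) // stride + 1


compile_shape = [1, 3, 8, 8]
inference_shape = [1, 3, 16, 8]
dim, num_cores = 2, 2

# Each NeuronCore receives an equal slice along dim 2.
chunk = list(inference_shape)
chunk[dim] //= num_cores  # [1, 3, 8, 8], matches compile_shape

# Per-core output of Conv2d(3, 3, 3); the "+ 1" does not change the shape.
chunk_out = [chunk[0], 3, conv_out(chunk[2]), conv_out(chunk[3])]  # [1, 3, 6, 6]

# Assumption: outputs are concatenated back along the same dim.
gathered = list(chunk_out)
gathered[dim] *= num_cores  # [1, 3, 12, 6]

# For comparison: the original model applied to the full, unsplit input.
full_out = [1, 3, conv_out(16), conv_out(8)]  # [1, 3, 14, 6]
print(gathered, full_out)
```

The difference between 12 and 14 rows comes from the two window positions that would cross the boundary between the halves. Whether that is acceptable depends on the model, so splitting on a non-batch dimension should be checked against the model's semantics.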

Dynamic batching

In the following example, we use the torch_neuronx.DataParallel() module to run inference using several different batch sizes without recompiling the Neuron model.

import torch
import torch_neuronx
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch_neuronx.trace(model, image)

# Create the DataParallel module
model_parallel = torch_neuronx.DataParallel(model_neuron)

# Create batched inputs and run inference on the same model
batch_sizes = [2, 3, 4, 5, 6]
for batch_size in batch_sizes:
    image_batched = torch.rand([batch_size, 3, 224, 224])

    # Run inference with a batched input
    output = model_parallel(image_batched)
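To build intuition for why one compiled model can serve several batch sizes, the following sketch shows split-and-pad dynamic batching, a common general technique. It is not a description of how torch_neuronx implements dynamic batching internally. A stand-in "model" that only accepts a fixed batch size is called on fixed-size pieces of the input. The last piece is padded by repeating its final element, and the outputs for padded rows are discarded. The names run_with_padding and make_fixed_model are hypothetical. The compiled batch size of 4 was chosen to make padding visible; it differs from the batch size of 1 used in the example above.

```python
def make_fixed_model(compiled_batch):
    """Hypothetical stand-in for a model compiled for one batch size."""
    def model(batch):
        if len(batch) != compiled_batch:
            raise ValueError(f"expected batch {compiled_batch}, got {len(batch)}")
        return [x * 2 for x in batch]
    return model


def run_with_padding(fixed_batch_fn, inputs, compiled_batch):
    """Split-and-pad sketch: run any number of inputs on a fixed-batch model."""
    outputs = []
    for start in range(0, len(inputs), compiled_batch):
        piece = inputs[start:start + compiled_batch]
        n_real = len(piece)
        # Pad the final piece by repeating its last element.
        piece = piece + [piece[-1]] * (compiled_batch - n_real)
        result = fixed_batch_fn(piece)
        outputs.extend(result[:n_real])  # drop outputs for padding rows
    return outputs


model = make_fixed_model(compiled_batch=4)
for batch_size in [2, 3, 4, 5, 6]:
    out = run_with_padding(model, list(range(batch_size)), compiled_batch=4)
    assert out == [x * 2 for x in range(batch_size)]
```

Padding wastes some compute on the final piece. In exchange, no recompilation is needed when the batch size changes.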

Dynamic batching disabled

In the following example, we use torch_neuronx.DataParallel.disable_dynamic_batching() to disable dynamic batching. The example shows one batch size that fails and one batch size that works when dynamic batching is disabled.

import torch
import torch_neuronx
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch_neuronx.trace(model, image)

# Create the DataParallel module and use 2 NeuronCores
model_parallel = torch_neuronx.DataParallel(model_neuron, device_ids=[0, 1], dim=0)

# Disable dynamic batching
model_parallel.disable_dynamic_batching()

# Create a batched input (this won't work)
batch_size = 4
image_batched = torch.rand([batch_size, 3, 224, 224])

# This will fail because dynamic batching is disabled and
# image_batched.shape[dim] / len(device_ids) != image.shape[dim]
# output = model_parallel(image_batched)

# Create a batched input (this will work)
batch_size = 2
image_batched = torch.rand([batch_size, 3, 224, 224])

# This will work because
# image_batched.shape[dim] / len(device_ids) == image.shape[dim]
output = model_parallel(image_batched)
