This document is relevant for: Inf1

PyTorch-Neuron trace Python API

The PyTorch-Neuron trace Python API provides a method to generate PyTorch models for execution on Inferentia, which can be serialized as TorchScript. It is analogous to the torch.jit.trace() function in PyTorch.

torch_neuron.trace(model, example_inputs, **kwargs)

The torch_neuron.trace() method sends operations to the Neuron-Compiler (neuron-cc) for compilation and embeds compiled artifacts in a TorchScript graph.

Compilation can be done on any EC2 machine with sufficient memory and compute resources. c5.4xlarge or larger is recommended.

Options can be passed to the Neuron compiler through the compiler_args keyword argument of this function. See the Neuron compiler CLI Reference Guide (neuron-cc) for more information about compiler options.
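As a rough illustration of how compiler options might be assembled, the sketch below flattens a dictionary of options into the list of strings that compiler_args expects. The helper name to_compiler_args is hypothetical and is not part of torch_neuron; the only flag used, --neuroncore-pipeline-cores, is the one mentioned on this page.

```python
def to_compiler_args(options):
    """Flatten {flag: value} into the list[str] form used by compiler_args.

    Hypothetical helper for illustration only; not part of torch_neuron.
    A value of None marks a flag that takes no argument.
    """
    args = []
    for flag, value in options.items():
        args.append(flag)
        if value is not None:
            args.append(str(value))  # compiler arguments are passed as strings
    return args


compiler_args = to_compiler_args({'--neuroncore-pipeline-cores': 4})
print(compiler_args)

# The list would then be handed to the trace call, for example:
# model_neuron = torch_neuron.trace(model, example_inputs, compiler_args=compiler_args)
```

Keeping options in a dictionary makes it harder to forget that numeric values must be converted to strings before they reach the compiler.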

This function partitions nodes into operations that are supported by Neuron and operations that are not. Operations that are not supported by Neuron are run on CPU. Graph partitioning can be controlled by the subgraph_builder_function, minimum_segment_size, and fallback parameters (see below). By default, all supported operations are compiled and run on Neuron.

The compiled graph can be saved using the torch.jit.save() function and restored using the torch.jit.load() function for inference on Inf1 instances. During inference, the previously compiled artifacts are loaded into the Neuron Runtime for execution.

Required Arguments

Parameters
  • model (Module, callable) – The function that will be run with the example_inputs arguments. The arguments and return types must be compatible with torch.jit.trace(). When a Module is passed to torch_neuron.trace(), only the forward() method is run and traced.

  • example_inputs (tuple) – A tuple of example inputs that will be passed to the model while tracing. The resulting trace can be run with inputs of different types and shapes assuming the traced operations support those types and shapes. This parameter may also be a single torch.Tensor in which case it is automatically wrapped in a tuple.

Optional Keyword Arguments

Keyword Arguments
  • compiler_args (list[str]) – List of strings representing neuron-cc compiler arguments. Note that these arguments apply to all subgraphs generated by allowlist partitioning. For example, use compiler_args=['--neuroncore-pipeline-cores', '4'] to set the number of NeuronCores per subgraph to 4. See the Neuron compiler CLI Reference Guide (neuron-cc) for more information about compiler options.

  • compiler_timeout (int) – Timeout in seconds to wait for neuron-cc to complete. Exceeding this timeout raises a subprocess.TimeoutExpired exception.

  • compiler_workdir (str) – Work directory used by neuron-cc. Useful for debugging and/or inspecting neuron-cc logs/IRs.

  • subgraph_builder_function (callable) – A function which is evaluated on each node during graph partitioning. This takes in a torch graph operator node and returns a bool value of whether it should be included in the fused Neuron graph or not. By default the partitioner selects all operators which are supported by Neuron.

  • minimum_segment_size (int) – A parameter used during partitioning. This specifies the minimum number of graph nodes that should be compiled into a Neuron graph (default=2). If a segment contains fewer nodes than this, its operations will run on CPU.

  • single_fusion_ratio_threshold (float) – A parameter used during partitioning. During partitioning, if a single partition contains a fraction of operations greater than this threshold, only one graph partition will be compiled (default=0.6). This is used to avoid compiling many small Neuron graphs. To force compilation of all graphs to Neuron (even when they are very small), a value of 1.0 can be used.

  • fallback (bool) – A function parameter to turn off graph partitioning. Indicates whether to attempt to fall back to CPU operations if an operation is not supported by Neuron. By default this is True. If this is set to False and an operation is not supported by Neuron, this will fail compilation and raise an AttributeError.

  • dynamic_batch_size (bool) – A flag to allow Neuron graphs to consume variable sized batches of data. Dynamic sizing is restricted to the 0th dimension of a tensor.

  • optimizations (list) – A list of Optimization passes to apply to the model.

  • separate_weights (bool) – A flag to enable compilation of models with over 1.9GB of constant parameters. By default this flag is False. If this is set to True and the compiler version is not new enough to support the flag, this will raise a NotImplementedError.

  • **kwargs

    All other keyword arguments will be forwarded directly to torch.jit.trace(). This supports flags like strict=False in order to allow dictionary outputs.
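The interaction between minimum_segment_size and single_fusion_ratio_threshold can be easier to follow with a toy model. The sketch below is not the actual torch_neuron partitioner: the function select_segments is hypothetical, and it assumes that the "fraction of operations" is measured against the total number of nodes across all candidate segments, which this page does not specify.

```python
def select_segments(segment_sizes, minimum_segment_size=2,
                    single_fusion_ratio_threshold=0.6):
    """Return indices of the segments a toy partitioner would compile.

    segment_sizes: node counts of candidate Neuron-supported segments.
    Toy model only; assumes the ratio is taken over all candidate nodes.
    """
    # Segments below the minimum size are left on CPU.
    kept = [i for i, n in enumerate(segment_sizes) if n >= minimum_segment_size]
    if not kept:
        return []
    total = sum(segment_sizes)
    largest = max(kept, key=lambda i: segment_sizes[i])
    # If one segment holds more than the threshold fraction, compile only it.
    if segment_sizes[largest] / total > single_fusion_ratio_threshold:
        return [largest]
    return kept


print(select_segments([10, 3, 1]))                                     # [0]
print(select_segments([10, 3, 1], single_fusion_ratio_threshold=1.0))  # [0, 1]
print(select_segments([5, 5, 1]))                                      # [0, 1]
```

In this model a threshold of 1.0 can never be exceeded, which matches the statement that 1.0 forces every sufficiently large segment to be compiled.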

Returns

The traced ScriptModule with embedded compiled neuron sub-graphs. Operations in this module will run on Neuron unless they are not supported by Neuron or manually partitioned to run on CPU.

Note that in torch<1.8, this returns a ScriptFunction if the input is a function.

Return type

ScriptModule, ScriptFunction

class torch_neuron.Optimization

A set of optimization passes that can be applied to the model.

FLOAT32_TO_FLOAT16

A post-processing pass that converts all torch.float32 tensors to torch.float16 tensors. The advantage to this optimization pass is that input/output tensors will be type cast. This reduces the amount of data that will be copied to and from Inferentia hardware. The resulting traced model will accept both torch.float32 and torch.float16 inputs where the model used torch.float32 inputs during tracing. It is only beneficial to enable this optimization if the throughput of a model is highly dependent upon data transfer speed. This optimization is not recommended if the final application will use torch.float32 inputs since the torch.float16 type cast will occur on CPU during inference.
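To give a sense of the data-transfer saving described above, the following back-of-the-envelope sketch compares the size of one input tensor in the two types. It uses only the standard library and the [1, 3, 224, 224] input shape from the examples on this page; it measures raw payload size and says nothing about actual throughput.

```python
import math
import struct

shape = (1, 3, 224, 224)      # example input shape used on this page
elements = math.prod(shape)   # 150528 values

bytes_fp32 = elements * struct.calcsize('f')  # 4 bytes per float32 value
bytes_fp16 = elements * struct.calcsize('e')  # 2 bytes per float16 value

print(bytes_fp32, bytes_fp16)  # 602112 301056
```

The pass itself would be requested through the optimizations keyword argument, for instance optimizations=[torch_neuron.Optimization.FLOAT32_TO_FLOAT16].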

Example Usage

Function Compilation

import torch
import torch_neuron

def foo(x, y):
    return 2 * x + y

# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.neuron.trace(foo, (torch.rand(3), torch.rand(3)))

# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment
torch.jit.save(traced_foo, 'foo.pt')
traced_foo = torch.jit.load('foo.pt')

Module Compilation

import torch
import torch_neuron
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x) + 1

n = Net()
n.eval()

inputs = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct a `ScriptModule` with
# a single `forward` method
neuron_forward = torch.neuron.trace(n.forward, inputs)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
neuron_net = torch.neuron.trace(n, inputs)

Pre-Trained Model Compilation

The following example uses the compilation Python API, with default compilation arguments, on a pretrained torch.nn.Module:

import torch
import torch_neuron
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch.neuron.trace(model, image)

Compiling models with torch.jit.trace kwargs

This example uses the strict=False flag to compile a model with dictionary outputs. Similarly, any other keyword argument of torch.jit.trace() can be passed directly to torch_neuron.trace() so that it is passed to the underlying trace call.

import torch
import torch_neuron
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return {'conv': self.conv(x) + 1}

model = Model()
model.eval()

inputs = torch.rand(1, 1, 3, 3)

# Use the strict=False kwarg to compile a model with dictionary outputs.
# The model output format does not change.
model_neuron = torch.neuron.trace(model, inputs, strict=False)

Dynamic Batching

This example uses the optional dynamic_batch_size option to support variable-sized batches at inference time.

import torch
import torch_neuron
from torchvision import models

# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input of batch size 1
image = torch.rand([1, 3, 224, 224])
model_neuron = torch.neuron.trace(model, image, dynamic_batch_size=True)

# Execute with a batch of 7 images
batch = torch.rand([7, 3, 224, 224])
results = model_neuron(batch)

Manual Partitioning

The following example uses the optional subgraph_builder_function parameter to ensure that only a specific convolution layer is compiled to Neuron. The remaining operations are executed on CPU.

import torch
import torch_neuron
import torch.nn as nn

class ExampleConvolutionLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x) + 1

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = ExampleConvolutionLayer()

    def forward(self, x):
        return self.layer(x) * 100

def subgraph_builder_function(node) -> bool:
    """Select if the node will be included in the Neuron graph"""

    # Node names are tuples of Module names.
    if 'ExampleConvolutionLayer' in node.name:
        return True

    # Ignore all operations not in the example convolution layer
    return False

model = Model()
model.eval()

inputs = torch.rand(1, 1, 3, 3)

# Log output shows that `aten::_convolution` and `aten::add` are compiled
# but `aten::mul` is not. This will seamlessly switch between Neuron/CPU
# execution in a single graph.
neuron_model = torch_neuron.trace(
    model,
    inputs,
    subgraph_builder_function=subgraph_builder_function
)
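Because the partitioning predicate is an ordinary Python callable, its logic can be checked before running a full compilation. The sketch below exercises the same membership test against stand-in objects. The SimpleNamespace nodes are hypothetical substitutes for real torch graph operator nodes, and their name values assume the tuple-of-module-names form mentioned in the comment above.

```python
from types import SimpleNamespace


def subgraph_builder_function(node) -> bool:
    """Same predicate as in the manual partitioning example."""
    return 'ExampleConvolutionLayer' in node.name


# Hypothetical stand-ins; real nodes are torch graph operator nodes.
conv_node = SimpleNamespace(name=('Model', 'ExampleConvolutionLayer'))
mul_node = SimpleNamespace(name=('Model',))

print(subgraph_builder_function(conv_node))  # True
print(subgraph_builder_function(mul_node))   # False
```

A check like this only validates the predicate's own logic; which nodes are actually compiled is still best confirmed from the log output of the trace call.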

Separate Weights

This example uses the optional separate_weights option to support compilation of models with more than 1.9GB of constant parameters.

import torch
import torch_neuron
from torchvision import models

# Load the model
model = models.resnet50(pretrained=True)
model.eval()

# Compile with an example input
image = torch.rand([1, 3, 224, 224])
# The model's output format does not change
model_neuron = torch.neuron.trace(model, image, separate_weights=True)
