This document is relevant for: Inf1
PyTorch-Neuron trace Python API#
The PyTorch-Neuron trace Python API provides a method to generate
PyTorch models for execution on Inferentia, which can be serialized as
TorchScript. It is analogous to the torch.jit.trace() function in PyTorch.
- torch_neuron.trace(model, example_inputs, **kwargs)#
The torch_neuron.trace() method sends operations to the Neuron Compiler (neuron-cc) for compilation and embeds the compiled artifacts in a TorchScript graph.
Compilation can be done on any EC2 machine with sufficient memory and compute resources; a c5.4xlarge or larger is recommended.
Options can be passed to the Neuron compiler via this compile function. See Neuron compiler CLI Reference Guide (neuron-cc) for more information about compiler options.
This function partitions nodes into operations that are supported by Neuron and operations which are not. Operations which are not supported by Neuron are run on CPU. Graph partitioning can be controlled by the subgraph_builder_function, minimum_segment_size, and fallback parameters (see below). By default, all supported operations are compiled and run on Neuron.
The compiled graph can be saved using the torch.jit.save() function and restored using the torch.jit.load() function for inference on Inf1 instances. During inference, the previously compiled artifacts will be loaded into the Neuron Runtime for execution.
Required Arguments
- Parameters:
  - model (Module, callable) – The function that will be run with the example_inputs arguments. The arguments and return types must be compatible with torch.jit.trace(). When a Module is passed to torch_neuron.trace(), only the forward() method is run and traced.
  - example_inputs (tuple) – A tuple of example inputs that will be passed to the model while tracing. The resulting trace can be run with inputs of different types and shapes assuming the traced operations support those types and shapes. This parameter may also be a single torch.Tensor, in which case it is automatically wrapped in a tuple.
Optional Keyword Arguments
- Keyword Arguments:
  - compiler_args (list[str]) – List of strings representing neuron-cc compiler arguments. Note that these arguments apply to all subgraphs generated by allowlist partitioning. For example, use compiler_args=['--neuroncore-pipeline-cores', '4'] to set the number of NeuronCores per subgraph to 4. See Neuron compiler CLI Reference Guide (neuron-cc) for more information about compiler options (a usage sketch follows the Return type entry below).
  - compiler_timeout (int) – Timeout in seconds to wait for neuron-cc to complete. Exceeding this timeout will cause a subprocess.TimeoutExpired exception.
  - compiler_workdir (str) – Work directory used by neuron-cc. Useful for debugging and/or inspecting neuron-cc logs and IRs.
  - subgraph_builder_function (callable) – A function which is evaluated on each node during graph partitioning. It takes a torch graph operator node and returns a bool indicating whether the node should be included in the fused Neuron graph. By default the partitioner selects all operators which are supported by Neuron.
  - minimum_segment_size (int) – A parameter used during partitioning. This specifies the minimum number of graph nodes which should be compiled into a Neuron graph (default=2). If the number of nodes is smaller than this size, the operations will run on CPU.
  - single_fusion_ratio_threshold (float) – A parameter used during partitioning. If a single partition contains a fraction of operations greater than this threshold, only one graph partition will be compiled (default=0.6). This is used to avoid compiling many small Neuron graphs. To force compilation of all graphs to Neuron (even when they are very small), a value of 1.0 can be used.
  - fallback (bool) – A parameter to turn off graph partitioning. Indicates whether to attempt to fall back to CPU operations when an operation is not supported by Neuron. By default this is True. If this is set to False and an operation is not supported by Neuron, compilation will fail and raise an AttributeError.
  - dynamic_batch_size (bool) – A flag to allow Neuron graphs to consume variable sized batches of data. Dynamic sizing is restricted to the 0th dimension of a tensor.
  - optimizations (list) – A list of Optimization passes to apply to the model.
  - separate_weights (bool) – A flag to enable compilation of models with more than 1.9 GB of constant parameters. By default this flag is False. If this is set to True and the compiler version is not new enough to support the flag, a NotImplementedError will be raised.
  - **kwargs – All other keyword arguments will be forwarded directly to torch.jit.trace(). This supports flags like strict=False in order to allow dictionary outputs.
- Returns:
  The traced ScriptModule with embedded compiled Neuron sub-graphs. Operations in this module will run on Neuron unless they are not supported by Neuron or manually partitioned to run on CPU. Note that in torch<1.8 this would return a ScriptFunction if the input was a function.
- Return type:
  ScriptModule, ScriptFunction
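For example, the following sketch (the model and work directory path are illustrative; it relies only on the keyword arguments documented above) forwards compiler options to neuron-cc, keeps its logs and IRs in a local work directory, and bounds the compilation time:
import torch
import torch_neuron
import torch.nn as nn
# A small illustrative model
model = nn.Sequential(nn.Conv2d(1, 1, 3), nn.ReLU())
model.eval()
inputs = torch.rand(1, 1, 8, 8)
# Forward options to neuron-cc, keep compiler artifacts for inspection,
# and raise subprocess.TimeoutExpired if compilation exceeds one hour
model_neuron = torch_neuron.trace(
    model,
    inputs,
    compiler_args=['--neuroncore-pipeline-cores', '4'],
    compiler_workdir='./neuron-workdir',  # illustrative path
    compiler_timeout=3600,
)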
- class torch_neuron.Optimization#
A set of optimization passes that can be applied to the model.
- FLOAT32_TO_FLOAT16#
A post-processing pass that converts all torch.float32 tensors to torch.float16 tensors. The advantage of this optimization pass is that input/output tensors will be type cast, which reduces the amount of data copied to and from Inferentia hardware. The resulting traced model will accept both torch.float32 and torch.float16 inputs where the model used torch.float32 inputs during tracing. It is only beneficial to enable this optimization if the throughput of a model is highly dependent upon data transfer speed. This optimization is not recommended if the final application will use torch.float32 inputs, since the torch.float16 type cast will occur on CPU during inference.
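For example, a minimal sketch (the model choice is illustrative) that applies this pass through the optimizations keyword argument documented above:
import torch
import torch_neuron
from torchvision import models
model = models.resnet50(pretrained=True)
model.eval()
image = torch.rand([1, 3, 224, 224])
# Cast input/output tensors so that float16 inputs can also be used,
# reducing data transfer to and from Inferentia
model_neuron = torch_neuron.trace(
    model,
    image,
    optimizations=[torch_neuron.Optimization.FLOAT32_TO_FLOAT16],
)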
Example Usage#
Function Compilation#
import torch
import torch_neuron
def foo(x, y):
return 2 * x + y
# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.neuron.trace(foo, (torch.rand(3), torch.rand(3)))
# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment
torch.jit.save(traced_foo, 'foo.pt')
traced_foo = torch.jit.load('foo.pt')
Module Compilation#
import torch
import torch_neuron
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(1, 1, 3)
def forward(self, x):
return self.conv(x) + 1
n = Net()
n.eval()
inputs = torch.rand(1, 1, 3, 3)
# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
neuron_forward = torch.neuron.trace(n.forward, inputs)
# Trace a module (implicitly traces `forward`) and constructs a
# `ScriptModule` with a single `forward` method
neuron_net = torch.neuron.trace(n, inputs)
Pre-Trained Model Compilation#
The following is an example of the compilation Python API with
default compilation arguments, using a pretrained torch.nn.Module:
import torch
import torch_neuron
from torchvision import models
# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()
# Compile with an example input
image = torch.rand([1, 3, 224, 224])
model_neuron = torch.neuron.trace(model, image)
Compiling models with torch.jit.trace kwargs#
This example uses the strict=False flag to compile a model with
dictionary outputs. Similarly, any other keyword argument of
torch.jit.trace() can be passed directly to torch_neuron.trace()
so that it is passed through to the underlying trace call.
import torch
import torch_neuron
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv = nn.Conv2d(1, 1, 3)
def forward(self, x):
return {'conv': self.conv(x) + 1}
model = Model()
model.eval()
inputs = torch.rand(1, 1, 3, 3)
# use the strict=False kwarg to compile a model with dictionary outputs
# the model output format does not change
model_neuron = torch.neuron.trace(model, inputs, strict=False)
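The traced model can then be called as usual; as a quick hypothetical usage check, the output is still a dictionary:
# The output format is unchanged: a dictionary with a 'conv' key
output = model_neuron(inputs)
print(output['conv'].shape)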
Dynamic Batching#
This example uses the dynamic_batch_size option to support
variable-sized batches at inference time.
import torch
import torch_neuron
from torchvision import models
# Load the model and set it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()
# Compile with an example input of batch size 1
image = torch.rand([1, 3, 224, 224])
model_neuron = torch.neuron.trace(model, image, dynamic_batch_size=True)
# Execute with a batch of 7 images
batch = torch.rand([7, 3, 224, 224])
results = model_neuron(batch)
Manual Partitioning#
The following example uses the optional subgraph_builder_function
parameter to ensure that only a specific convolution layer is compiled to
Neuron. The remaining operations are executed on CPU.
import torch
import torch_neuron
import torch.nn as nn
class ExampleConvolutionLayer(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 1, 3)
def forward(self, x):
return self.conv(x) + 1
class Model(nn.Module):
def __init__(self):
super().__init__()
self.layer = ExampleConvolutionLayer()
def forward(self, x):
return self.layer(x) * 100
def subgraph_builder_function(node) -> bool:
"""Select if the node will be included in the Neuron graph"""
# Node names are tuples of Module names.
if 'ExampleConvolutionLayer' in node.name:
return True
# Ignore all operations not in the example convolution layer
return False
model = Model()
model.eval()
inputs = torch.rand(1, 1, 3, 3)
# Log output shows that `aten::_convolution` and `aten::add` are compiled
# but `aten::mul` is not. This will seamlessly switch between Neuron/CPU
# execution in a single graph.
neuron_model = torch_neuron.trace(
model,
inputs,
subgraph_builder_function=subgraph_builder_function
)
Separate Weights#
This example uses the separate_weights option to support compilation
of models with more than 1.9 GB of constant parameters.
import torch
import torch_neuron
from torchvision import models
# Load the model
model = models.resnet50(pretrained=True)
model.eval()
# Compile with an example input
image = torch.rand([1, 3, 224, 224])
# The model's output format does not change
model_neuron = torch.neuron.trace(model, image, separate_weights=True)
This document is relevant for: Inf1