This document is relevant for: Inf1
Troubleshooting Guide for PyTorch Neuron (torch-neuron)#
Patching PyTorch version 1.13 for CVEs#
PyTorch version 1.13 has the following CVEs:
- CVE-2025-32434
- CVE-2024-31580
- CVE-2024-31583
To patch PyTorch version 1.13, run the following on a CPU instance with an Ubuntu 22 AMI (the full build takes about 30 minutes on a c5.4xlarge):
git clone --recursive https://github.com/pytorch/pytorch -b v1.13.1
cd pytorch
git cherry-pick b5c3a17c2c207ebefcb85043f0cf94be9b2fef81
git cherry-pick 9c7071b0e324f9fb68ab881283d6b8d388a4bcd2
wget https://github.com/user-attachments/files/22013116/patch_v113.txt
git apply patch_v113.txt
To build the pip wheel, see build steps. A condensed version is provided below.
Install Miniconda by following installation steps and run the following commands:
source ~/miniconda3/bin/activate
conda create --name conda_py39 python=3.9
conda activate conda_py39
conda install astunparse numpy==1.19.5 ninja pyyaml setuptools cmake cffi typing_extensions future six requests dataclasses
conda install mkl mkl-include
# CUDA only: Add LAPACK support for the GPU if needed
conda install -c pytorch magma-cuda110 # or the magma-cuda* that matches your CUDA version from https://anaconda.org/pytorch/repo
sudo apt install cmake g++
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
PYTORCH_BUILD_VERSION=1.13.2 PYTORCH_BUILD_NUMBER=1 python setup.py bdist_wheel
# the PyTorch pip wheel will be in dist directory
General Torch-Neuron issues#
If you see an error about “Unknown builtin op: neuron::forward_1” like the one below, ensure that the line “import torch_neuron” (which registers the Neuron custom operations) appears in the inference script before the call to torch.jit.load.
Unknown builtin op: neuron::forward_1.
Could not find any similar ops to neuron::forward_1. This op may not exist or may not be currently supported in TorchScript.
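For example, a previously compiled and saved Neuron model can be loaded for inference as follows (a minimal sketch; model_neuron.pt is a placeholder for your own saved TorchScript file):

import torch
import torch_neuron  # must be imported before torch.jit.load to register the neuron:: custom operations

# Load a previously compiled Neuron model (placeholder file name)
model_neuron = torch.jit.load('model_neuron.pt')

# Run inference as usual
example = torch.rand(1, 1, 3, 3)
output = model_neuron(example)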
torch.jit.trace issues#
The PyTorch-Neuron compilation API (/neuron-guide/neuron-frameworks/pytorch-neuron/api-compilation-python-api.rst) uses the PyTorch torch.jit.trace() function to generate ScriptModule models for execution on Inferentia. Consequently, a PyTorch model must be torch-jit-traceable before it can be compiled for Inferentia. If your model is not traceable, you can try modifying the underlying PyTorch model code to make it traceable. If it’s not possible to change your model code, you can write a wrapper around your model that makes it torch-jit-traceable and compile the wrapped model for Inferentia.
Please visit torch.jit.trace() to review the properties that a model must have to be torch-jit-traceable. The PyTorch-Neuron trace API torch_neuron.trace() accepts **kwargs for torch.jit.trace(). For example, you can use the strict=False flag to compile models with dictionary outputs.
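As an illustration, a model that returns a dictionary of tensors can be compiled by passing strict=False through torch_neuron.trace() (a minimal sketch; the DictOutputModel below is illustrative only):

import torch
import torch_neuron
import torch.nn as nn

class DictOutputModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        # Dictionary outputs require strict=False at trace time
        return {'out1': self.conv(x) + 1, 'out2': self.conv(x) + 2}

model = DictOutputModel()
model.eval()
inputs = torch.rand(1, 1, 3, 3)

# strict=False is forwarded to torch.jit.trace, allowing the dictionary output
model_neuron = torch_neuron.trace(model, inputs, strict=False)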
Compiling models with outputs that are not torch-jit-traceable#
To compile a model whose outputs are not torch-jit-traceable, you can write a wrapper that converts the model’s output into a form that is torch-jit-traceable and then compile the wrapped model for Inferentia using torch_neuron.trace().
The following example uses a wrapper to compile a model with non torch-jit-traceable outputs. This model cannot be compiled for Inferentia in its current form because it outputs a list of tuples and tensors, which is not torch-jit-traceable.
import torch
import torch_neuron
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        a = self.conv(x) + 1
        b = self.conv(x) + 2
        c = self.conv(x) + 3
        # An output that is a list of tuples and tensors is not torch-traceable
        return [(a, b), c]

model = Model()
model.eval()

inputs = torch.rand(1, 1, 3, 3)

# Try to compile the model
model_neuron = torch.neuron.trace(model, inputs)  # ERROR: This cannot be traced, we must change the output format
To compile this model for Inferentia, we can write a wrapper around the model to convert its outputs into a tuple of tensors, which is torch-jit-traceable.
class NeuronCompatibilityWrapper(nn.Module):
    def __init__(self):
        super(NeuronCompatibilityWrapper, self).__init__()
        self.model = Model()

    def forward(self, x):
        out = self.model(x)
        # An output that is a tuple of tuples and tensors is torch-jit-traceable
        return tuple(out)
Now, we can successfully compile the model for Inferentia using the NeuronCompatibilityWrapper wrapper as follows:
model = NeuronCompatibilityWrapper()
model.eval()
# Compile the traceable wrapped model
model_neuron = torch.neuron.trace(model, inputs)
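With the wrapper in place, the compiled model returns a tuple of tensors instead of the original list (a brief usage sketch using the same example inputs):

# Run inference with the compiled wrapped model
output = model_neuron(inputs)

# The outputs are now returned as a nested tuple: ((a, b), c)
(a, b), c = output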
If the model’s outputs must be in the original form, a second wrapper can be used to transform the outputs after compilation for Inferentia. The following example uses the OutputFormatWrapper wrapper to convert the compiled model’s output back into the original form of a list of tuples and tensors.
class OutputFormatWrapper(nn.Module):
    def __init__(self):
        super(OutputFormatWrapper, self).__init__()
        self.traceable_model = NeuronCompatibilityWrapper()

    def forward(self, x):
        out = self.traceable_model(x)
        # Return the output in the original format of Model()
        return list(out)
model = OutputFormatWrapper()
model.eval()
# Compile the traceable wrapped model
model.traceable_model = torch.neuron.trace(model.traceable_model, inputs)
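Running the top-level wrapper now executes the compiled submodule on Inferentia while restoring the original output structure (a brief usage sketch):

# Inference goes through the compiled submodule, and the top-level
# wrapper converts the outputs back to the original list format
output = model(inputs)

# output is a list of the form [(a, b), c], matching the original Model()
(a, b), c = output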
Compiling a submodule in a model that is not torch-jit-traceable#
The following example shows how to compile a submodule that is part of a non torch-jit-traceable model. In this example, the top-level model Outer uses a dynamic flag, which is not torch-jit-traceable. However, the submodule Inner is torch-jit-traceable and can be compiled for Inferentia.
import torch
import torch_neuron
import torch.nn as nn

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x) + 1

class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x, add_offset: bool = False):
        base = self.inner(x)
        if add_offset:
            return base + 1
        return base
model = Outer()
inputs = torch.rand(1, 1, 3, 3)
# Compile the traceable wrapped submodule
model.inner = torch.neuron.trace(model.inner, inputs)
# TorchScript the model for serialization
script = torch.jit.script(model)
torch.jit.save(script, 'model.pt')
loaded = torch.jit.load('model.pt')
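Because torch.jit.script preserves the dynamic add_offset flag, the loaded model can still be called either way while the Inner submodule runs on Inferentia (a brief usage sketch):

# The dynamic flag still works after scripting, saving, and loading
base = loaded(inputs)                     # add_offset defaults to False
offset = loaded(inputs, add_offset=True)  # takes the conditional branch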
Alternatively, for usage scenarios in which the model configuration is static during inference, the dynamic flags can be hardcoded in a wrapper to make the model torch-jit-traceable and enable compiling the entire model for Inferentia. In this example, we assume the add_offset flag is always True during inference, so we can hardcode this conditional path in the Static wrapper to remove the dynamic behavior and compile the entire model for Inferentia.
class Static(nn.Module):
    def __init__(self):
        super().__init__()
        self.outer = Outer()

    def forward(self, x):
        # hardcode `add_offset=True`
        output = self.outer(x, add_offset=True)
        return output
model = Static()
# We can now compile the entire model because `add_offset=True` is hardcoded in the Static wrapper
model_neuron = torch.neuron.trace(model, inputs)
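As with other compiled models, the fully compiled Static wrapper can be saved and reloaded for deployment (a brief sketch; the file name is illustrative):

# Save the fully compiled model and reload it for inference
torch.jit.save(model_neuron, 'static_model_neuron.pt')
restored = torch.jit.load('static_model_neuron.pt')
output = restored(inputs)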
This document is relevant for: Inf1