This document is relevant for: Inf2, Trn1, Trn2
Comparison of Traced Inference versus XLA Lazy Tensor Inference (torch-neuronx)
Introduction
Using torch-neuronx, there are two ways that a model can be executed for inference:
XLA Lazy Tensor Inference: A model is executed on Neuron by calling to() to move Parameter and Tensor data to the device returned by xm.xla_device(). Executing operations uses torch Lazy Tensor to record, compile, and execute the graph. These are the same mechanisms used for training.
(Recommended) Traced Inference: A model is traced prior to inference using the trace() API. This trace is similar to torch.jit.trace() but instead creates a Neuron-specific TorchScript artifact. This artifact provides improved performance and portability compared to XLA Lazy Tensor inference.
XLA Lazy Tensor Inference Mechanics
XLA Lazy Tensor inference uses Just-In-Time (JIT) compilation for Neuron execution.
XLA device execution uses the built-in torch-xla functionality with torch Lazy Tensor to record torch operations performed on the device returned by xm.xla_device(). The graph of operations is sent to the neuronx-cc compiler upon calling xm.mark_step(). Finally, the compiled graph is transferred to a NeuronCore and executed in the Neuron backend.
The initial model inference will be very slow because the compiler must first generate the model binary in the Neuron Executable File Format (NEFF). On each subsequent call to the model, the application re-executes the Python code, rebuilds the graph, and checks a cache for an existing NEFF that matches the graph before attempting to recompile.
The process of recording graph operations in Python can become a bottleneck for otherwise fast models. This overhead affects performance regardless of model size, but may be less noticeable on larger models. Note that XLA Lazy Tensor execution performance may improve significantly with new torch features in the future.
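The caching behavior described above can be illustrated with a toy simulation in plain Python (no torch, torch-xla, or Neuron hardware required). This is a conceptual sketch under simplifying assumptions, not the torch-xla implementation: the class and method names are hypothetical, and a "graph signature" is modeled simply as the list of operations plus the input shape. Every call re-records the graph, but the stand-in compiler only runs on a cache miss.

```python
# Conceptual sketch only: a toy model of the JIT flow described above.
# It is NOT how torch-xla is implemented; all names here are hypothetical.

class ToyLazyTensorRuntime:
    """Re-records the graph on every call and compiles only on a cache miss."""

    def __init__(self):
        self.cache = {}          # graph signature -> stand-in "compiled NEFF"
        self.compilations = 0    # how many times the "compiler" ran
        self.recordings = 0      # how many times Python re-recorded the graph

    def record(self, ops, input_shape):
        # Every call re-executes Python to rebuild the graph; the signature
        # depends on both the operations and the input shape.
        self.recordings += 1
        return (tuple(ops), tuple(input_shape))

    def run(self, ops, input_shape):
        signature = self.record(ops, input_shape)
        if signature not in self.cache:
            # Cache miss: stand-in for invoking neuronx-cc to build a NEFF.
            self.compilations += 1
            self.cache[signature] = f"neff-{self.compilations}"
        return self.cache[signature]


ops = ["linear", "relu", "linear", "softmax"]
runtime = ToyLazyTensorRuntime()
for shape in [(1, 784), (1, 784), (4, 784), (1, 784)]:
    runtime.run(ops, shape)

print(runtime.recordings, runtime.compilations)  # 4 recordings, 2 compilations
```

Repeated calls with an already-seen shape skip compilation but still pay the recording cost, which is the per-call overhead that traced inference avoids.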
Example
import torch
import torch_neuronx
import torch_xla.core.xla_model as xm
# Create XLA device
device = xm.xla_device()
# Load example model and inputs to Neuron device
model = torch.nn.Sequential(
    torch.nn.Linear(784, 120),
    torch.nn.ReLU(),
    torch.nn.Linear(120, 10),
    torch.nn.Softmax(dim=-1),
)
model.eval()
model.to(device)
example = torch.rand((1, 784), device=device)
# Inference
with torch.no_grad():
    result = model(example)
    xm.mark_step()  # Compilation occurs here
    print(result.cpu())
The following example passes input token ID tensors with different batch sizes and sequence lengths to the same model, which triggers a recompilation for each new input shape. This kind of workflow would require padding when using traced inference.
import torch
import torch_neuronx
import torch_xla.core.xla_model as xm
# Create XLA device
device = xm.xla_device()
# Load example model to Neuron device
model = torch.nn.Sequential(
    torch.nn.Embedding(num_embeddings=30522, embedding_dim=512),
    torch.nn.Linear(512, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),
    torch.nn.Softmax(dim=-1),
)
model.eval()
model.to(device)
# Create example inputs on the Neuron device (same device as the model)
token_ids_1 = torch.tensor([
    [1, 28, 748, 0],
], device=device)  # shape: [1, 4]
token_ids_2 = torch.tensor([
    [1, 13087, 10439, 1990, 18912, 0],
    [1, 12009, 7849, 2509, 3500, 0],
], device=device)  # shape: [2, 6]
# Inference
with torch.no_grad():
    # First compilation/inference
    result = model(token_ids_1)
    xm.mark_step()
    print(result.cpu())  # shape: [1, 4, 2]
    # Recompilation occurs here since token_ids_2 has a different shape. This
    # inference would have failed if the model had been traced with shape [1, 4]
    result = model(token_ids_2)
    xm.mark_step()
    print(result.cpu())  # shape: [2, 6, 2]
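As noted above, traced inference would need padding to serve both inputs in this example with a single traced model. A minimal sketch of that padding step, using plain Python lists instead of tensors, might look like the following. pad_to_shape is a hypothetical helper, not a torch-neuronx API, and pad_id=0 is an assumption; in practice, use your tokenizer's padding ID.

```python
# Hypothetical helper (not part of torch-neuronx): pad a ragged batch of token
# ID lists to one fixed [batch_size, seq_len] shape so that a single traced
# model could serve every request. pad_id=0 is an assumption.

def pad_to_shape(batch, batch_size, seq_len, pad_id=0):
    if len(batch) > batch_size:
        raise ValueError(f"batch of {len(batch)} exceeds fixed batch size {batch_size}")
    padded = []
    for row in batch:
        if len(row) > seq_len:
            raise ValueError(f"sequence of {len(row)} exceeds fixed length {seq_len}")
        # Right-pad each sequence up to the fixed length.
        padded.append(list(row) + [pad_id] * (seq_len - len(row)))
    # Fill any missing batch rows entirely with padding.
    padded.extend([[pad_id] * seq_len for _ in range(batch_size - len(batch))])
    return padded


token_ids_1 = [[1, 28, 748, 0]]                      # shape: [1, 4]
token_ids_2 = [[1, 13087, 10439, 1990, 18912, 0],
               [1, 12009, 7849, 2509, 3500, 0]]      # shape: [2, 6]

for ids in (token_ids_1, token_ids_2):
    padded = pad_to_shape(ids, batch_size=2, seq_len=8)
    print(len(padded), len(padded[0]))  # both print: 2 8
```

For a per-token model like the one above, outputs at padded positions and padded batch rows can simply be sliced off. Models that mix information across positions (for example, attention layers) would typically also need an attention mask so that padding does not change the results.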
Traced Inference Mechanics
Traced inference uses Ahead-Of-Time (AOT) compilation for Neuron execution.
Similar to XLA Lazy Tensor inference, trace() uses the operation recording mechanisms provided by torch-xla to build the graph structure. This graph structure is also sent to the neuronx-cc compiler to produce a binary (NEFF) that is executable on Neuron.
The main difference is that the call to trace() returns a new, fully compiled graph as a TorchScript Module. When this new Module is called, it does not re-execute the Python code, rebuild the graph, or check the cache for a matching model; it simply executes the precompiled graph that was preloaded during tracing. This is a significantly more efficient runtime because it avoids Python operator tracing, graph building, and cache lookups on every call.
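The difference between the two runtimes can be sketched with a toy simulation in plain Python (no torch or Neuron required). toy_trace is hypothetical and only mimics the idea described above, not the torch_neuronx.trace implementation: the Python forward function runs once at trace time, and every later call replays the frozen result. As a simplification, the returned callable also rejects inputs whose shape differs from the example shape, mirroring the fixed shapes of a traced module.

```python
# Conceptual sketch only: "trace once, then replay" versus re-running Python.
# toy_trace and its internals are hypothetical, not the torch_neuronx.trace API.

def toy_trace(forward, example_shape):
    """Run the Python forward once to record ops, then return a replay function."""
    recorded_ops = []
    forward(recorded_ops.append)          # Python executes exactly once, here
    frozen_ops = tuple(recorded_ops)      # stand-in for the compiled NEFF
    fixed_shape = tuple(example_shape)

    def compiled(input_shape):
        # Shapes are fixed at trace time: a different shape is an error,
        # never a recompilation.
        if tuple(input_shape) != fixed_shape:
            raise ValueError(f"expected shape {fixed_shape}, got {tuple(input_shape)}")
        return frozen_ops                 # replay the precompiled graph as-is

    return compiled


python_calls = 0

def forward(emit):
    # Counts how often the Python forward body actually executes.
    global python_calls
    python_calls += 1
    for op in ("linear", "relu", "linear", "softmax"):
        emit(op)


compiled = toy_trace(forward, (1, 784))
for _ in range(3):
    compiled((1, 784))
print(python_calls)  # 1: the Python forward ran only during tracing
```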
One disadvantage of this interface is that a model will never dynamically recompile after a trace. This means that dynamic control flow is not supported within a function or module. Tensor input and output shapes are fixed to the shapes of the example inputs passed to the trace() API. Dynamic batching and bucketing can be used to mitigate the limitations of static shapes.
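Bucketing can be sketched as follows: trace one model per sequence length ahead of time, then route each request to the smallest bucket that fits and pad it up to that length. The bucket sizes, pad_id=0, and the helper names select_bucket and pad_to_bucket are illustrative assumptions, not part of torch-neuronx.

```python
# Hypothetical bucketing sketch (not a torch-neuronx API). Bucket sizes and
# pad_id=0 are illustrative assumptions.
import bisect

BUCKETS = [8, 16, 32, 64]  # one traced model per length, chosen ahead of time

def select_bucket(seq_len, buckets=BUCKETS):
    """Return the smallest bucket >= seq_len (buckets must be sorted)."""
    i = bisect.bisect_left(buckets, seq_len)
    if i == len(buckets):
        raise ValueError(f"sequence length {seq_len} exceeds largest bucket {buckets[-1]}")
    return buckets[i]

def pad_to_bucket(token_ids, pad_id=0, buckets=BUCKETS):
    """Pad a single sequence to its bucket length; return (padded, bucket)."""
    bucket = select_bucket(len(token_ids), buckets)
    return list(token_ids) + [pad_id] * (bucket - len(token_ids)), bucket

padded, bucket = pad_to_bucket([1, 13087, 10439, 1990, 18912, 0])
print(bucket, len(padded))  # 8 8
```

Choosing buckets is a trade-off: fewer, coarser buckets mean fewer traced models to keep loaded in Neuron device memory, but more wasted computation on padding.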
Example
import torch
import torch_neuronx
# Create example model and inputs
model = torch.nn.Sequential(
    torch.nn.Linear(784, 120),
    torch.nn.ReLU(),
    torch.nn.Linear(120, 10),
    torch.nn.Softmax(dim=-1),
)
model.eval()
example = torch.rand((1, 784))
# Create fixed model trace
trace = torch_neuronx.trace(model, example)
# Inference
result = trace(example)  # No recompilation. Input shapes must not change
print(result)
Traced Inference Advantages
Traced inference should be used for nearly all deployment purposes since it provides some key advantages over XLA Lazy Tensor execution:
Reduced Overhead: There is no per-inference overhead for graph recording, compilation, and model loading, since these steps are performed only once within the call to trace(). In contrast, when using XLA Lazy Tensor inference, all of these steps are performed just-in-time (with caching to improve performance).
Serializable: The TorchScript Module that is produced by the trace() API is serializable using the normal torch.jit.save() function, and it can be reloaded in an inference environment with torch.jit.load(). In contrast, XLA device inference does not provide a predetermined serialization format that includes the precompiled NEFF artifacts; these must be manually copied to an inference environment to be used.
Reduced Dependencies: When using the traced TorchScript Module in an inference environment, the neuronx-cc compiler no longer needs to be installed. In contrast, XLA Lazy Tensor execution may require a recompile to successfully infer, so the compiler must remain available.
Static & Predictable: The module produced by trace() contains a static model that consumes a predictable amount of Neuron device memory and never requires recompilation based on input changes. In contrast, since XLA device inference performs just-in-time compilation, it can be more difficult to predict memory utilization and the compilations that may be required at inference time.
C++ Usability: If the end application is an inference platform using libtorch, it is easy to integrate with libtorchneuron to load traced modules. It is not currently possible to set up an environment that uses torch in C++ in conjunction with Neuron XLA Lazy Tensor execution.
Tensor Materialization During Tracing
Tensor materialization (forcing a tensor's concrete value to be computed, for example to evaluate a Python condition) is normal in the JIT workflow, but it is not expected during traced inference. If tensor materialization happens while tracing, the graph is compiled based on the value of the example input tensor, which can lead to unexpected program behavior. Use PyTorch/XLA's debugging flags to identify when unexpected tensor materialization happens, and change the code to avoid it.
A common issue occurs when tensor values are evaluated during model compilation (traced inference). Consider this example:
def forward(self, tensor):
    if tensor[0] == 1:
        return tensor
    else:
        return tensor * 2
While this code can compile and run, it may lead to unexpected behavior because:
The tensor value is accessed during tracing (tensor[0]).
The resulting graph becomes fixed based on the tensor value available during tracing.
Developers might incorrectly assume the condition will be evaluated dynamically during inference.
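The following plain-Python simulation (no torch or Neuron required) illustrates why the branch gets fixed. A real tracer records tensor operations, but ordinary Python control flow such as the if statement runs only once, at trace time, using the example input's concrete values. toy_trace is hypothetical and only mimics this effect; it is not how torch_neuronx.trace is implemented.

```python
# Conceptual simulation only (no torch or Neuron required). toy_trace is
# hypothetical: it mimics how Python control flow is resolved once at trace time.

def eager_forward(tensor):
    # Plain Python: the condition is re-evaluated on every call.
    if tensor[0] == 1:
        return list(tensor)
    return [x * 2 for x in tensor]

def toy_trace(example):
    # The `if` is resolved ONCE, here, with the example's value; only the ops of
    # the branch that was taken end up in the recorded graph.
    if example[0] == 1:
        return lambda t: list(t)              # recorded graph: identity
    return lambda t: [x * 2 for x in t]       # recorded graph: multiply by 2

traced = toy_trace([1, 5, 7])
print(eager_forward([2, 5, 7]))  # [4, 10, 14]
print(traced([2, 5, 7]))         # [2, 5, 7]: branch was baked in at trace time
```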
To fix this, use the debugging flags described below to catch the issue, then modify the code so that the condition does not depend on a tensor value.
The updated code below avoids tensor materialization:
import torch
class TestModel(torch.nn.Module):
    def __init__(self, flag=1):
        super().__init__()
        # The flag should be predetermined based on the model configuration;
        # it should not be an input of the model at runtime
        self.flag = flag
    def forward(self, tensor):
        # Branch on a plain Python attribute, not on a tensor value
        if self.flag:
            return tensor
        else:
            return tensor * 2
Debugging Flags
To help catch tensor materialization issues, PyTorch/XLA provides two useful approaches:
Enable warning messages for tensor materialization:
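import os
# Set this early, before torch_xla is imported
os.environ['PT_XLA_DEBUG_LEVEL'] = '2'
Disable graph execution to catch issues during development:
import torch_xla
# With execution disallowed, an attempt to execute a graph (for example, to materialize a tensor) surfaces as an error
torch_xla._XLAC._set_allow_execution(False)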
Recommendations
Using these flags during development can help identify potential issues early in the development cycle. The recommended approach is to:
Use PT_XLA_DEBUG_LEVEL=2 during initial development to identify potential materialization points.
Apply _set_allow_execution(False) when you want to ensure that no tensor materialization occurs during tracing.
When you see warnings or errors related to tensor materialization, look into the code path and make appropriate changes. The example above moved the flag to the __init__ function, so the branch no longer depends on the model input at runtime.
For more detailed debugging information, refer to the PyTorch on XLA Devices guide in the PyTorch/XLA documentation.
Summary
| | XLA Device Inference | Traced Inference |
|---|---|---|
| Compilation | JIT | AOT |
| Serialization | N/A | TorchScript |
| Performance | Slower | Faster |
| Dynamic | Yes | No |
| C++ Usage | No | Yes |