.. _trace-vs-xla-lazytensor:

Comparison of Traced Inference versus XLA |LazyTensor| Inference (``torch-neuronx``)
======================================================================================

.. contents:: Table of contents
   :local:
   :depth: 1

Introduction
------------

Using ``torch-neuronx``, there are two ways that a model can be executed for inference:

- **XLA LazyTensor Inference**: A model is executed on Neuron by calling
  :meth:`~torch.Tensor.to` to move :class:`~torch.nn.parameter.Parameter` and
  :class:`~torch.Tensor` data using the |device|. Executing operations uses torch
  |LazyTensor| to record, compile, and execute the graph. These are the same
  mechanisms used for :ref:`training `.

- **(Recommended) Traced Inference**: A model is traced prior to inference using the
  |trace| API. This trace is similar to :func:`torch.jit.trace` but instead creates a
  Neuron-specific `TorchScript`_ artifact. This artifact provides improved performance
  and portability compared to XLA |LazyTensor| inference.

.. _xla_lazytensor:

XLA Lazy Tensor Inference Mechanics
-----------------------------------

XLA |LazyTensor| inference uses Just-In-Time (JIT) compilation for Neuron execution.
XLA device execution uses the built-in ``torch-xla`` functionality with torch
|LazyTensor| to record torch operations using the |device|. The graph of operations
is sent to the |neuronx-cc| compiler upon calling ``xm.mark_step()``. Finally, the
compiled graph is transferred to a NeuronCore and executed in the Neuron backend.

The initial model inference will be very slow since the model binary file in the
Neuron Executable File Format (NEFF) must first be generated by the compiler. Upon
each subsequent call to the model, the application re-executes the Python code,
rebuilds the graph, and checks a cache to see if an existing NEFF file is available
for the given graph before attempting to recompile.

The process of recording graph operations in Python can become a bottleneck for
otherwise fast models. This overhead affects performance regardless of model size but
may be less noticeable on larger models. Note that XLA |LazyTensor| execution
performance may improve significantly with new torch features in the future.

Example
~~~~~~~

.. tab-set::

   .. tab-item:: Fixed Shape Example

      .. code-block:: python

         import torch
         import torch_neuronx
         import torch_xla.core.xla_model as xm

         # Create XLA device
         device = xm.xla_device()

         # Load example model and inputs to Neuron device
         model = torch.nn.Sequential(
             torch.nn.Linear(784, 120),
             torch.nn.ReLU(),
             torch.nn.Linear(120, 10),
             torch.nn.Softmax(dim=-1),
         )
         model.eval()
         model.to(device)
         example = torch.rand((1, 784), device=device)

         # Inference
         with torch.no_grad():
             result = model(example)
             xm.mark_step()  # Compilation occurs here
             print(result.cpu())

   .. tab-item:: Dynamic Shape Example

      The following is an example of a model that dynamically changes the sequence
      length and batch size of the input token ID tensor to trigger recompilations.
      This kind of workflow would require padding when using traced inference.

      .. code-block:: python

         import torch
         import torch_neuronx
         import torch_xla.core.xla_model as xm

         # Create XLA device
         device = xm.xla_device()

         # Load example model and inputs to Neuron device
         model = torch.nn.Sequential(
             torch.nn.Embedding(num_embeddings=30522, embedding_dim=512),
             torch.nn.Linear(512, 128),
             torch.nn.ReLU(),
             torch.nn.Linear(128, 2),
             torch.nn.Softmax(dim=-1),
         )
         model.eval()
         model.to(device)

         token_ids_1 = torch.tensor([
             [1, 28, 748, 0],
         ], device=device)  # shape: [1, 4]
         token_ids_2 = torch.tensor([
             [1, 13087, 10439, 1990, 18912, 0],
             [1, 12009, 7849, 2509, 3500, 0],
         ], device=device)  # shape: [2, 6]

         # Inference
         with torch.no_grad():
             # First compilation/inference
             result = model(token_ids_1)
             xm.mark_step()
             print(result.cpu())  # shape: [1, 4, 2]

             # Recompilation occurs here since token_ids_2 is a different shape.
             # This inference would have failed if the model had been traced with
             # shape [1, 4]
             result = model(token_ids_2)
             xm.mark_step()
             print(result.cpu())  # shape: [2, 6, 2]
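
To make the just-in-time compilation cost described above visible, the fixed shape
example can be wrapped in a simple timing loop. The following is a minimal sketch
(the loop and the ``time.perf_counter`` measurements are illustrative additions, not
part of the example above): the first iteration includes graph compilation, while
subsequent iterations with the same input shape reuse the cached NEFF.

.. code-block:: python

   import time

   import torch
   import torch_neuronx
   import torch_xla.core.xla_model as xm

   device = xm.xla_device()

   model = torch.nn.Sequential(
       torch.nn.Linear(784, 120),
       torch.nn.ReLU(),
       torch.nn.Linear(120, 10),
       torch.nn.Softmax(dim=-1),
   )
   model.eval()
   model.to(device)
   example = torch.rand((1, 784), device=device)

   with torch.no_grad():
       for step in range(3):
           start = time.perf_counter()
           result = model(example)
           xm.mark_step()         # Graph is cut here; compiled only on the first pass
           output = result.cpu()  # Copying to CPU waits for execution to complete
           print(f"Inference {step}: {time.perf_counter() - start:.3f} seconds")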

Traced Inference Mechanics
--------------------------

Traced inference uses Ahead-Of-Time (AOT) compilation for Neuron execution. Similar
to XLA |LazyTensor| inference, |trace| uses the operation recording mechanisms
provided by ``torch-xla`` to build the graph structure. This graph structure is also
sent to the |neuronx-cc| compiler to produce a binary (NEFF) that is executable on
Neuron. The main difference is that the call to |trace| returns a *new* fully
compiled graph as a `TorchScript`_ Module. Upon calling this new Module, rather than
re-executing the Python code, rebuilding the graph, and checking the cache for a
matching model, the new Module simply executes the precompiled graph that was
preloaded during tracing. This is a significantly more optimized runtime since it
avoids Python operator tracing, graph building, and cache lookups.

One disadvantage of this interface is that a model will never dynamically recompile
after a trace. This means that dynamic control flow is not supported within a
function/module. Tensor input/output shapes are fixed to the shapes passed to the
|trace| API. Dynamic batching and bucketing can be used to avoid the pitfalls of
static shapes (a padding sketch follows the example below).

Example
~~~~~~~

.. code-block:: python

   import torch
   import torch_neuronx

   # Create example model and inputs
   model = torch.nn.Sequential(
       torch.nn.Linear(784, 120),
       torch.nn.ReLU(),
       torch.nn.Linear(120, 10),
       torch.nn.Softmax(dim=-1),
   )
   model.eval()
   example = torch.rand((1, 784))

   # Create fixed model trace
   trace = torch_neuronx.trace(model, example)

   # Inference
   result = trace(example)  # No recompilation. Input shapes must not change
   print(result)
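
Because a traced module only accepts the input shapes it was traced with,
variable-length workloads (such as the token ID example in the previous section) are
typically padded up to a fixed maximum shape before execution. The following is a
minimal sketch of that padding approach; the ``MAX_LENGTH`` and ``PAD_TOKEN_ID``
values and the ``pad_and_infer`` helper are illustrative choices made for this
sketch, not part of the ``torch-neuronx`` API.

.. code-block:: python

   import torch
   import torch_neuronx

   MAX_LENGTH = 128   # Hypothetical maximum sequence length chosen for this sketch
   PAD_TOKEN_ID = 0   # Hypothetical padding token ID

   model = torch.nn.Sequential(
       torch.nn.Embedding(num_embeddings=30522, embedding_dim=512),
       torch.nn.Linear(512, 128),
       torch.nn.ReLU(),
       torch.nn.Linear(128, 2),
       torch.nn.Softmax(dim=-1),
   )
   model.eval()

   # Trace once using the fixed (padded) input shape
   example = torch.zeros((1, MAX_LENGTH), dtype=torch.long)
   trace = torch_neuronx.trace(model, example)

   def pad_and_infer(token_ids):
       # Pad the sequence dimension on the right up to the traced length
       pad_amount = MAX_LENGTH - token_ids.shape[1]
       padded = torch.nn.functional.pad(token_ids, (0, pad_amount), value=PAD_TOKEN_ID)
       result = trace(padded)
       # For this per-position model the outputs produced for the padding
       # positions can simply be sliced off
       return result[:, :token_ids.shape[1]]

   token_ids = torch.tensor([[1, 28, 748, 0]])  # shape: [1, 4]
   print(pad_and_infer(token_ids))              # shape: [1, 4, 2]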

Traced Inference Advantages
---------------------------

Traced inference should be used for nearly all deployment purposes since it provides
some key advantages over XLA |LazyTensor| execution:

- **Reduced Overhead**: There is no overhead associated with graph recording,
  compilation, and model loading since these steps are performed only once within the
  call to |trace|. In contrast, when using XLA |LazyTensor| inference, all of these
  steps are performed just-in-time (with caching to improve performance).

- **Serializable**: The TorchScript Module produced by the |trace| API is serializable
  using the normal :func:`torch.jit.save` function and can be reloaded in an inference
  environment with :func:`torch.jit.load`. In contrast, XLA device inference does not
  provide a predetermined serialization format that includes the pre-compiled NEFF
  artifacts. These must be manually copied to an inference environment to be used.

- **Reduced Dependencies**: When using the traced TorchScript Module in an inference
  environment, it is no longer required to install the |neuronx-cc| compiler. In
  contrast, when using XLA |LazyTensor| execution, an inference may require a
  recompilation to succeed.

- **Static & Predictable**: The module produced by |trace| contains a static model
  that consumes a predictable amount of Neuron device memory and never requires
  recompilation based on input changes. In contrast, since XLA device inference
  performs just-in-time compilation, it can be more difficult to predict memory
  utilization and the compilations that may be required at inference time.

- **C++ Usability**: If the end application is an inference platform using
  ``libtorch``, it is easy to integrate with ``libtorchneuron`` to load traced
  modules. It is not currently possible to set up an environment to use torch in C++
  in conjunction with Neuron XLA |LazyTensor| execution.

Summary
-------

+----------------+-----------------------+-------------------+
|                | XLA Device Inference  | Traced Inference  |
+================+=======================+===================+
| Compilation    | JIT                   | AOT               |
+----------------+-----------------------+-------------------+
| Serialization  | N/A                   | `TorchScript`_    |
+----------------+-----------------------+-------------------+
| Performance    | Slower                | Faster            |
+----------------+-----------------------+-------------------+
| Dynamic shapes | Yes                   | No                |
+----------------+-----------------------+-------------------+
| C++ Usage      | No                    | Yes               |
+----------------+-----------------------+-------------------+

.. |LazyTensor| replace:: :ref:`Lazy Tensor `
.. |trace| replace:: :func:`~torch_neuronx.trace`
.. |device| replace:: :code:`xm.xla_device()`
.. |neuronx-cc| replace:: :ref:`neuronx-cc `
.. _TorchScript: https://pytorch.org/docs/stable/jit.html