.. _trace-vs-xla-lazytensor:

Comparison of Traced Inference versus XLA |LazyTensor| Inference (``torch-neuronx``)
=====================================================================================

.. contents:: Table of contents
   :local:
   :depth: 1

Introduction
------------


Using ``torch-neuronx``, there are two ways that a model can be
executed for inference:

- **XLA LazyTensor Inference**: A model is executed on Neuron by calling
  :meth:`~torch.Tensor.to` to move :class:`~torch.nn.parameter.Parameter`
  and :class:`~torch.Tensor` data to the XLA device returned by |device|.
  Operations on that data are recorded, compiled, and executed as a graph by
  torch |LazyTensor|. These are the same mechanisms used for
  :ref:`training <pytorch-neuronx-programming-guide>`.

- **(Recommended) Traced Inference**: A model is traced prior to inference
  using the |trace| API. This trace is similar to :func:`torch.jit.trace` but
  instead creates a Neuron-specific `TorchScript`_ artifact. This artifact
  provides improved performance and portability compared to XLA
  |LazyTensor| inference.




.. _xla_lazytensor:

XLA Lazy Tensor Inference Mechanics
-----------------------------------

XLA |LazyTensor| inference uses Just-In-Time (JIT) compilation for Neuron
execution.

XLA device execution uses the built-in ``torch-xla`` functionality with torch
|LazyTensor| to record torch operations performed on the |device|. The graph of
operations is sent to the |neuronx-cc| compiler upon calling
``xm.mark_step()``. Finally, the compiled graph is transferred to a NeuronCore
and executed in the Neuron backend.

The first inference is slow because the compiler must generate the model
binary in the Neuron Executable File Format (NEFF). On each subsequent call,
the application re-executes the Python code, rebuilds the graph, and checks a
cache for an existing NEFF that matches the graph before attempting to
recompile.
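
The compile-or-reuse behavior described above can be modeled in a few lines of
plain Python. This is a toy sketch, not the ``torch-xla`` implementation: the
class ``ToyJitCache`` and the ``neff-for-...`` strings are hypothetical, and the
cache key here is only the input shape, whereas the real cache is keyed on the
recorded graph.

.. code-block:: python

    from typing import Dict, Tuple

    # Hypothetical key type: this toy identifies a graph by its input shape only.
    # The real torch-xla cache is keyed on the recorded graph itself.
    GraphKey = Tuple[int, ...]


    class ToyJitCache:
        """Toy model of JIT compile-or-reuse; not torch-xla's implementation."""

        def __init__(self) -> None:
            self._compiled: Dict[GraphKey, str] = {}
            self.compile_count = 0

        def _compile(self, key: GraphKey) -> str:
            # Stand-in for the expensive neuronx-cc call that produces a NEFF.
            self.compile_count += 1
            return f"neff-for-{key}"

        def run(self, shape: GraphKey) -> str:
            # Every call rebuilds the graph (here: just derives its key) ...
            key = tuple(shape)
            # ... then checks the cache before compiling.
            if key not in self._compiled:
                self._compiled[key] = self._compile(key)
            return self._compiled[key]


    cache = ToyJitCache()
    cache.run((1, 4))  # first call: compiles
    cache.run((1, 4))  # same graph: cache hit, no compilation
    cache.run((2, 6))  # new shape, new graph: compiles again
    print(cache.compile_count)  # 2

Note that even a cache hit still pays for re-running the Python code and
rebuilding the graph on every call.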

Recording graph operations in Python can become a bottleneck for otherwise
fast models. This overhead is incurred on every call regardless of model size,
but it is less noticeable on larger models, where device execution takes up a
larger share of each call. Note that XLA |LazyTensor| execution performance may
improve significantly with new torch features in the future.
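
To observe the first-call compilation cost and the steady-state per-call
overhead on your own model, a small timing helper such as the following can be
used. ``profile_calls`` is a hypothetical helper written for illustration; it is
not part of ``torch-neuronx`` or ``torch-xla``.

.. code-block:: python

    import statistics
    import time
    from typing import Any, Callable, List, Tuple


    def profile_calls(fn: Callable[[], Any], n_warm: int = 10) -> Tuple[float, float]:
        """Return (first-call seconds, median seconds over the next n_warm calls).

        n_warm must be at least 1.
        """
        start = time.perf_counter()
        fn()  # first call: includes any one-time compilation
        first = time.perf_counter() - start

        samples: List[float] = []
        for _ in range(n_warm):
            start = time.perf_counter()
            fn()  # steady state: per-call overhead plus execution
            samples.append(time.perf_counter() - start)
        return first, statistics.median(samples)

With XLA |LazyTensor| inference, the timed callable must force execution, for
example ``lambda: model(example).cpu()`` inside ``torch.no_grad()``; otherwise
only graph recording is measured. Applying the same helper to a traced module
gives the corresponding numbers for traced inference.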

Example
~~~~~~~

.. tab-set::

    .. tab-item:: Fixed Shape Example

        .. code-block:: python

            import torch
            import torch_neuronx
            import torch_xla.core.xla_model as xm

            # Create XLA device
            device = xm.xla_device()

            # Load example model and inputs to Neuron device
            model = torch.nn.Sequential(
                torch.nn.Linear(784, 120),
                torch.nn.ReLU(),
                torch.nn.Linear(120, 10),
                torch.nn.Softmax(dim=-1),
            )
            model.eval()
            model.to(device)
            example = torch.rand((1, 784), device=device)

            # Inference
            with torch.no_grad():
                result = model(example)
                xm.mark_step()  # Compilation occurs here
                print(result.cpu())


    .. tab-item:: Dynamic Shape Example

        The following is an example of a model that dynamically changes the
        sequence length and batch size of the input token ID tensor to trigger
        recompilations. This kind of workflow would require padding when using
        traced inference.

        .. code-block:: python

            import torch
            import torch_neuronx
            import torch_xla.core.xla_model as xm

            # Create XLA device
            device = xm.xla_device()

            # Load example model and inputs to Neuron device
            model = torch.nn.Sequential(
                torch.nn.Embedding(num_embeddings=30522, embedding_dim=512),
                torch.nn.Linear(512, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 2),
                torch.nn.Softmax(dim=-1),
            )
            model.eval()
            model.to(device)

            token_ids_1 = torch.tensor([
                [1, 28, 748, 0],
            ], device=device)  # shape: [1, 4]
            token_ids_2 = torch.tensor([
                [1, 13087, 10439, 1990, 18912, 0],
                [1, 12009, 7849, 2509, 3500, 0],
            ], device=device)  # shape: [2, 6]

            # Inference
            with torch.no_grad():

                # First compilation/inference
                result = model(token_ids_1)
                xm.mark_step()
                print(result.cpu())  # shape: [1, 4, 2]

                # Recompilation occurs here since token_ids_2 has a different shape. This
                # inference would have failed if the model had been traced with shape [1, 4].
                result = model(token_ids_2)
                xm.mark_step()
                print(result.cpu())  # shape: [2, 6, 2]



Traced Inference Mechanics
--------------------------
Traced inference uses Ahead-Of-Time (AOT) compilation for Neuron execution.

Similar to XLA |LazyTensor| inference, |trace| uses the operation recording
mechanisms provided by ``torch-xla`` to build the graph structure. This graph
structure is also sent to the |neuronx-cc| compiler to produce a binary (NEFF)
that is executable on Neuron.

The main difference is that the call to |trace| returns a *new*, fully
compiled graph as a `TorchScript`_ Module. When this new Module is called, it
does not re-execute the Python code, rebuild the graph, or check the cache for
a matching model; it simply executes the precompiled graph that was preloaded
during tracing. Skipping Python operator recording, graph building, and cache
lookup on every call makes this runtime path significantly faster.

One disadvantage of this interface is that a model never recompiles after it
has been traced. As a result, data-dependent control flow within a
function/module is not supported: only the path taken for the example inputs
is captured. Tensor input/output shapes are fixed to the shapes of the example
inputs passed to the |trace| API. Dynamic batching and bucketing can be used to
work around the limitations of static shapes.
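
Bucketing can be sketched in plain Python: pick the smallest of a fixed set of
sequence lengths that fits the input, then pad up to it. The bucket sizes, the
pad ID of ``0``, and the helper names ``pick_bucket`` and ``pad_batch`` are
assumptions for illustration only.

.. code-block:: python

    from bisect import bisect_left
    from typing import List, Sequence

    # Hypothetical bucket lengths; each would correspond to one traced model.
    BUCKETS = (8, 16, 32, 64)
    PAD_ID = 0  # assumption: 0 is the tokenizer's padding ID


    def pick_bucket(length: int, buckets: Sequence[int] = BUCKETS) -> int:
        """Return the smallest bucket (sorted ascending) that holds `length` tokens."""
        idx = bisect_left(buckets, length)
        if idx == len(buckets):
            raise ValueError(f"length {length} exceeds largest bucket {buckets[-1]}")
        return buckets[idx]


    def pad_batch(batch: List[List[int]], buckets: Sequence[int] = BUCKETS) -> List[List[int]]:
        """Right-pad every sequence to the bucket chosen for the longest one."""
        target = pick_bucket(max(len(seq) for seq in batch), buckets)
        return [seq + [PAD_ID] * (target - len(seq)) for seq in batch]


    padded = pad_batch([[1, 13087, 10439, 1990, 18912, 0], [1, 12009, 7849, 0]])
    print([len(seq) for seq in padded])  # [8, 8]

In a deployment, each bucket length would presumably get its own trace, built
from an example input of that length, and a padded request would be routed to
the matching trace. The batch dimension is also fixed by the trace and can be
padded the same way. Padding only preserves results if the model ignores the
padded positions, for example through an attention mask or by discarding their
outputs.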

Example
~~~~~~~
.. code-block:: python

    import torch
    import torch_neuronx

    # Create example model and inputs
    model = torch.nn.Sequential(
        torch.nn.Linear(784, 120),
        torch.nn.ReLU(),
        torch.nn.Linear(120, 10),
        torch.nn.Softmax(dim=-1),
    )
    model.eval()
    example = torch.rand((1, 784))

    # Create fixed model trace
    trace = torch_neuronx.trace(model, example)

    # Inference
    result = trace(example)  # No recompilation; input shapes must not change
    print(result)



Traced Inference Advantages
---------------------------

Traced inference should be used for nearly all deployment purposes since it
provides some key advantages over XLA |LazyTensor| execution:

- **Reduced Overhead**: There is no overhead associated with graph recording,
  compilation, and model loading since these steps are performed only once
  within the call to |trace|. In contrast, when using XLA |LazyTensor|
  inference, all of these steps are performed just-in-time (with caching to
  improve performance).
- **Serializable**: The TorchScript Module that is produced from the |trace| API
  is serializable using the normal :func:`torch.jit.save` function and can be
  reloaded in an inference environment with :func:`torch.jit.load`.
  In contrast, XLA device inference does not provide a predetermined
  serialization format that includes the pre-compiled NEFF artifacts. These
  must be manually copied to an inference environment to be used.
- **Reduced Dependencies**: When using the traced TorchScript Module in an
  inference environment, the |neuronx-cc| compiler does not need to be
  installed. In contrast, XLA |LazyTensor| execution may need to recompile a
  graph in order to infer, so the compiler must be available at inference time.
- **Static & Predictable**: The resulting module produced by |trace| will
  contain a static model that will consume a predictable amount of Neuron device
  memory and will never require recompilation based on input changes. In
  contrast, since XLA device inference performs just-in-time compilation, it
  can be more difficult to predict memory utilization and the compilations
  that may be required at inference time.
- **C++ Usability**: If the end application is an inference platform using
  ``libtorch``, it is easy to integrate with ``libtorchneuron`` to load
  traced modules. It is not currently possible to set up an environment to use
  torch in C++ in conjunction with Neuron XLA |LazyTensor| execution.
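
The serialization round trip can be sketched as follows. To keep the example
runnable without Neuron hardware, it uses :func:`torch.jit.trace` on CPU as a
stand-in for |trace|; as described in the **Serializable** item above, a module
returned by |trace| is saved and loaded with the same :func:`torch.jit.save`
and :func:`torch.jit.load` calls.

.. code-block:: python

    import io

    import torch

    model = torch.nn.Sequential(
        torch.nn.Linear(784, 120),
        torch.nn.ReLU(),
        torch.nn.Linear(120, 10),
        torch.nn.Softmax(dim=-1),
    )
    model.eval()
    example = torch.rand((1, 784))

    # CPU stand-in for torch_neuronx.trace so the sketch runs anywhere.
    traced = torch.jit.trace(model, example)

    # Serialize to an in-memory buffer; a file path such as "model.pt" also works.
    buffer = io.BytesIO()
    torch.jit.save(traced, buffer)

    # Inference environment: reload without needing the original Python class.
    buffer.seek(0)
    loaded = torch.jit.load(buffer)
    with torch.no_grad():
        result = loaded(example)
    print(result.shape)  # torch.Size([1, 10])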

Tensor Materialization During Tracing
---------------------------------------

Tensor materialization occurs when the value of a lazy tensor must be computed and returned to Python, for example to evaluate an ``if`` condition.
While this is normal in the JIT (XLA |LazyTensor|) workflow, it is not expected during traced inference.
If it happens while tracing, the graph is compiled using the values of the example input tensors, which can lead to unexpected program behavior.
Use PyTorch/XLA's debugging flags to identify where unexpected tensor materialization happens, then change the code to avoid it.


A common issue occurs when tensor values are evaluated during model compilation (traced inference). Consider this example:

.. code-block:: python

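   import torch

   class TestModel(torch.nn.Module):
       def forward(self, tensor):
           # Reading tensor[0] in a Python condition materializes the tensor
           if tensor[0] == 1:
               return tensor
           else:
               return tensor * 2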

While this code compiles and runs, it may lead to unexpected behavior because:

* The tensor value is read during tracing (``tensor[0]``).
* The resulting graph is fixed to the branch selected by the example input's value.
* Developers might incorrectly assume the condition will be evaluated dynamically during inference.

To fix code like this, use the debugging flags below to catch the issue, then modify the code.
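
The same pitfall can be reproduced without Neuron hardware using
:func:`torch.jit.trace` on CPU, which also records only the branch taken for
the example input. This is an analogy rather than the XLA mechanism itself:
``torch.jit.trace`` reports the problem with a ``TracerWarning`` instead of the
PyTorch/XLA debug output. ``BranchOnValue`` is a hypothetical module that
mirrors the example.

.. code-block:: python

    import torch


    class BranchOnValue(torch.nn.Module):
        def forward(self, tensor):
            if tensor[0] == 1:  # tensor value converted to a Python bool
                return tensor
            else:
                return tensor * 2


    model = BranchOnValue()
    # Traced with tensor[0] == 1, so only the first branch is recorded.
    # torch.jit.trace emits a TracerWarning about the bool conversion here.
    traced = torch.jit.trace(model, torch.tensor([1.0, 2.0]))

    x = torch.tensor([3.0, 4.0])
    print(model(x))   # tensor([6., 8.])  eager: condition evaluated per call
    print(traced(x))  # tensor([3., 4.])  traced: branch fixed at trace time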

See the updated code without tensor materialization:

.. code-block:: python

  import torch


  class TestModel(torch.nn.Module):
      def __init__(self, flag=1):
          super().__init__()
          # The flag is pre-determined by the model configuration; it is not
          # a runtime input, so no tensor value is read during tracing.
          self.flag = flag

      def forward(self, tensor):
          if self.flag:
              return tensor
          else:
              return tensor * 2
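
If the choice genuinely has to depend on input values at runtime, one common
alternative (not taken from this guide) is to compute both results and select
between them with :func:`torch.where`, so the decision stays inside the graph
and no value is read back to Python. The sketch uses CPU
:func:`torch.jit.trace` as a stand-in for |trace|; ``SelectInGraph`` is a
hypothetical name, and the trade-off is that both branches are always computed.

.. code-block:: python

    import torch


    class SelectInGraph(torch.nn.Module):
        def forward(self, tensor):
            # Both branches are computed; the graph op picks one per call,
            # so tracing never converts a tensor value to a Python bool.
            return torch.where(tensor[0] == 1, tensor, tensor * 2)


    traced = torch.jit.trace(SelectInGraph(), torch.tensor([1.0, 2.0]))
    print(traced(torch.tensor([1.0, 2.0])))  # tensor([1., 2.])
    print(traced(torch.tensor([3.0, 4.0])))  # tensor([6., 8.])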

Debugging Flags
~~~~~~~~~~~~~~~~

To help catch tensor materialization issues, PyTorch/XLA provides two useful approaches:

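1. Enable warning messages for tensor materialization:

   .. code-block:: python

      import os

      # Set before importing torch_xla so the setting is in place when read
      os.environ['PT_XLA_DEBUG_LEVEL'] = '2'

2. Disable graph execution to catch issues during development:

   .. code-block:: python

      import torch_xla

      # Graph execution (and therefore tensor materialization) is now disallowed
      torch_xla._XLAC._set_allow_execution(False)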

Recommendations
~~~~~~~~~~~~~~~~

Using these flags helps identify potential issues early in the development cycle. The recommended approach is to:

* Use ``PT_XLA_DEBUG_LEVEL=2`` during initial development to identify potential materialization points.
* Apply ``_set_allow_execution(False)`` when you want to ensure that no tensor materialization occurs during tracing.
* When you see warnings or errors related to tensor materialization, inspect the code path that triggered them and change it so it no longer reads tensor values. The example above moved the flag to the ``__init__`` function, which does not depend on the model input at runtime.

For more detailed debugging information, refer to `PyTorch on XLA Devices <https://github.com/pytorch/xla/blob/master/docs/source/learn/pytorch-on-xla-devices.md>`__ in the PyTorch/XLA documentation.


Summary
-------

+----------------+-----------------------+-------------------+
|                | XLA Device Inference  | Traced Inference  |
+================+=======================+===================+
| Compilation    | JIT                   | AOT               |
+----------------+-----------------------+-------------------+
| Serialization  | N/A                   | `TorchScript`_    |
+----------------+-----------------------+-------------------+
| Performance    | Slower                | Faster            |
+----------------+-----------------------+-------------------+
| Dynamic shapes | Yes                   | No                |
+----------------+-----------------------+-------------------+
| C++ Usage      | No                    | Yes               |
+----------------+-----------------------+-------------------+


.. |LazyTensor| replace:: :ref:`Lazy Tensor <xla_lazytensor>`
.. |trace| replace:: :func:`~torch_neuronx.trace`
.. |device| replace:: :code:`xm.xla_device()`
.. |neuronx-cc| replace:: :ref:`neuronx-cc <neuron-compiler-cli-reference-guide>`
.. _TorchScript: https://pytorch.org/docs/stable/jit.html
