This document is relevant for: Inf2, Trn1, Trn2

ModelBuilderV2 API Reference#

APIs#

neuronx_distributed.trace.model_builder.trace#

neuronx_distributed.trace.model_builder.trace(
    model: Union[Callable, torch.nn.Module],
    args: Union[None, torch.Tensor, Tuple[torch.Tensor, ...]] = None,
    kwargs: Optional[Dict[str, torch.Tensor]] = None,
    spmd: bool = True,
    preserve_parameters: bool = True,
) -> TraceArtifacts

The trace() function is a fundamental unit in the ModelBuilderV2 framework that handles the tracing of PyTorch models for execution on Neuron devices. It processes example inputs as both positional and keyword arguments, validates model parameters, and generates necessary trace artifacts such as HLOs.

Parameters#

  • model: Union[Callable, torch.nn.Module] — The PyTorch model or callable function to be traced. Must have explicitly defined parameters (no *args or **kwargs). Must have at least one parameter.

  • args: Union[None, torch.Tensor, Tuple[torch.Tensor, …]] = None — Example inputs as positional arguments. Can be None, a single tensor, or a tuple of tensors. Must match the model’s positional parameter requirements.

  • kwargs: Optional[Dict[str, torch.Tensor]] = None — Example inputs as keyword arguments. Must be a dictionary mapping parameter names to tensor values. Cannot override parameters provided in args.

  • spmd: bool = True — Whether to use SPMD (Single Program Multiple Data) for tracing. Currently, only True is supported.

  • preserve_parameters: bool = True — Whether to preserve module buffers across multi-bucket trace.

Returns#

Returns a TraceArtifacts object containing:

neuronx_distributed.trace.model_builder_utils.TraceArtifacts(
    hlo: Any,                                 # HLO representation
    metaneff: Any,                            # Meta information for NEFF
    flattener: Any,                           # Function to flatten inputs
    packer: Any,                              # Function to pack outputs
    weight_name_to_idx: Dict[str, int],       # Maps weight names to indices
    weight_names_to_skip: Set,                # Weight names excluded from optimization
    provided_args: List[ProvidedArgInfo],     # Information about provided arguments
    model_params: List[ModelParamInfo],       # Information about model parameters
)

ProvidedArgInfo object contains:

neuronx_distributed.trace.model_builder_utils.ProvidedArgInfo(
     param_name: str,       # Name of the parameter this argument corresponds to
     is_positional: bool,   # Whether this argument is positional (required) or keyword (optional)
     tensor: torch.Tensor,  # The tensor value provided for this argument
)

ModelParamInfo object contains:

neuronx_distributed.trace.model_builder_utils.ModelParamInfo(
     param_name: str,      # Name of the parameter in the function signature
     is_positional: bool,  # Whether this parameter is positional (required) or keyword (optional)
)
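
For example, a minimal sketch of calling trace() directly and inspecting the returned metadata (the function and input shapes are illustrative):

import torch
from neuronx_distributed.trace.model_builder import trace

def add_fn(a, b):
    return a + b

# Trace with example keyword inputs; shapes and dtypes are captured in the HLO.
trace_artifacts = trace(add_fn, kwargs={'a': torch.rand(2, 2), 'b': torch.rand(2, 2)})

# Inspect the argument and parameter metadata recorded during tracing.
for arg in trace_artifacts.provided_args:
    print(arg.param_name, arg.is_positional, arg.tensor.shape)
for param in trace_artifacts.model_params:
    print(param.param_name, param.is_positional)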

neuronx_distributed.trace.model_builder.compile#

neuronx_distributed.trace.model_builder.compile(
    hlo_module: hlo_pb2.HloModuleProto,
    metaneff: Any,
    compiler_workdir: Optional[Union[str, pathlib.Path]] = None,
    compiler_args: Optional[str] = None,
    key: Optional[str] = None
) -> CompilationArtifacts

The compile() function is a fundamental unit in the ModelBuilderV2 framework that compiles traced models using the Neuron Compiler, and generates Neuron Executable File Format (NEFF) files. It handles compiler configurations, workdir management, and produces compilation artifacts.

Parameters#

  • hlo_module: hlo_pb2.HloModuleProto — The HLO module representing the computational graph to be compiled. Generated from the trace() function.

  • metaneff: Any — Meta information for the Neuron Executable File Format (NEFF)

  • compiler_workdir: Optional[Union[str, pathlib.Path]] = None — Directory path to store compiler artifacts. If None, uses a default path. Creates timestamped subdirectories (in UTC format) for each compilation.

  • compiler_args: Optional[str] = None — Compiler flags for neuronx-cc. If None, uses default compiler flags. Can include optimization levels and other compiler options.

  • key: Optional[str] = None — Key to tag the bucket with a meaningful name. If None, generates a hash from the HLO module. Used for logging and artifact organization.

Returns#

Returns a CompilationArtifacts object containing:

neuronx_distributed.trace.model_builder_utils.CompilationArtifacts(
    neff_filepath: str    # Path to the compiled NEFF file
)

Default Compiler Flags#

If no compiler_args are provided, the following defaults are used:

--enable-saturate-infinity --auto-cast=none --model-type=transformer -O1

Directory Structure#

This creates the following directory structure:

compiler_workdir/
└── {key}/
    └── {timestamp}/
        ├── model/
        │   └── graph.hlo
        ├── graph.neff
        ├── metaneff.pb
        ├── command.txt
        └── log-neuron-cc.txt
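
Putting trace() and compile() together, a minimal sketch (the key and workdir are illustrative; compiler_args is omitted, so the defaults above apply):

import torch
from neuronx_distributed.trace.model_builder import trace, compile

def add_fn(a, b):
    return a + b

trace_artifacts = trace(add_fn, kwargs={'a': torch.rand(2, 2), 'b': torch.rand(2, 2)})

# Compile the traced HLO into a NEFF; artifacts land under compiler_workdir/{key}/{timestamp}/.
compilation_artifacts = compile(
    hlo_module=trace_artifacts.hlo,
    metaneff=trace_artifacts.metaneff,
    compiler_workdir="./compiler_workdir",
    key="add_fn",
)
print(compilation_artifacts.neff_filepath)  # .../add_fn/{timestamp}/graph.neff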

neuronx_distributed.shard_checkpoint#

neuronx_distributed.shard_checkpoint(
    checkpoint: Dict[str, torch.Tensor],
    model: torch.nn.Module,
    start_rank: Optional[int] = None,
    end_rank: Optional[int] = None,
    load_on_device: bool = False,
    serialize_path: Optional[str] = None
) -> List[Dict[str, torch.Tensor]]

The shard_checkpoint() function shards a model checkpoint across tensor parallel ranks for distributed execution. It supports options for serialization (pre-shard) and direct loading onto Neuron devices (shard-on-load).

Parameters#

  • checkpoint: Dict[str, torch.Tensor] — The model checkpoint dictionary. Maps parameter names to tensor values. Must contain all model parameters.

  • model: torch.nn.Module — The PyTorch model to be sharded. Used for determining sharding strategy.

  • start_rank: Optional[int] = None — Starting rank for sharding. Must be in range [0, tp_degree). Defaults to 0 if None.

  • end_rank: Optional[int] = None — Ending rank for sharding. Must be in range [start_rank, tp_degree). Defaults to (tp_degree - 1) if None.

  • load_on_device: bool = False — Whether to load sharded tensors onto Neuron devices. Requires running on supported Neuron instance. Defaults to False.

  • serialize_path: Optional[str] = None — Path to save sharded checkpoints. If provided, saves as safetensors files. Creates directory if it doesn’t exist.

Returns#

Returns a List[Dict[str, torch.Tensor]] where:

  • Each dictionary represents a sharded checkpoint for a rank

  • Dictionary keys are parameter names

  • Dictionary values are sharded tensor values

  • List length is (end_rank - start_rank + 1)
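
A sketch of pre-sharding a checkpoint to disk via serialize_path, following the NxDParallelState and init_on_device pattern from the usage examples below (Model and model_checkpoint are assumed to be defined as in those examples):

import torch
from neuronx_distributed import NxDParallelState, shard_checkpoint
from neuronx_distributed.utils.model_utils import init_on_device

# Model and model_checkpoint are defined as in the usage examples below.
with NxDParallelState(world_size=2, tensor_model_parallel_size=2), init_on_device(torch.device("meta")):
    sharded_checkpoint = shard_checkpoint(
        checkpoint=model_checkpoint,   # full, unsharded state dict
        model=Model(),                 # parallel model that determines the sharding strategy
        serialize_path="./sharded",    # also saves each rank's shard as safetensors
    )

print(len(sharded_checkpoint))  # end_rank - start_rank + 1; here the full tp_degree of 2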

neuronx_distributed.ModelBuilder#

class ModelBuilderV2:
    def __init__(
        self,
        model: Union[Callable, torch.nn.Module],
    )

ModelBuilderV2 is a high-level class that provides a fluent interface for tracing and compiling PyTorch models for Neuron devices. It supports SPMD (Single Program Multiple Data) execution and distributed model execution.

Constructor Parameters#

  • model: Union[Callable, torch.nn.Module] — The PyTorch model to be traced and compiled. Can be a torch.nn.Module or a callable function. Must have explicitly defined parameters (no *args or **kwargs). Must have at least one parameter.

neuronx_distributed.ModelBuilder.trace#

neuronx_distributed.ModelBuilder.trace(
    self,
    args: Union[None, torch.Tensor, Tuple[torch.Tensor, ...]] = None,
    kwargs: Optional[Dict[str, torch.Tensor]] = None,
    tag: Optional[str] = None,
    spmd: bool = True,
) -> ModelBuilderV2

Traces the model with given inputs and stores trace artifacts. Leverages the neuronx_distributed.trace.model_builder.trace fundamental unit.

Parameters#

  • args: Union[None, torch.Tensor, Tuple[torch.Tensor, …]] = None — Example inputs as positional arguments. Can be None, a single tensor, or a tuple of tensors. Must match the model’s positional parameter requirements.

  • kwargs: Optional[Dict[str, torch.Tensor]] = None — Example inputs as keyword arguments

  • tag: Optional[str] = None — Unique identifier for this trace. Corresponding bucket will be tagged with this name. If None, generates a hash from the HLO module.

  • spmd: bool = True — Whether to use SPMD (Single Program Multiple Data) for tracing. Currently, only True is supported.

Returns#

Self reference for method chaining.

neuronx_distributed.ModelBuilder.compile#

neuronx_distributed.ModelBuilder.compile(
    self,
    priority_model_key: Optional[str] = None,
    compiler_workdir: Optional[Union[str, pathlib.Path]] = None,
    compiler_args: Optional[Union[str, Dict[str, str]]] = None,
    max_workers: Optional[int] = None,
) -> NxDModel

Compiles traced models using the Neuron compiler. Leverages the neuronx_distributed.trace.model_builder.compile fundamental unit.

Parameters#

  • priority_model_key: Optional[str] = None — Key of the model to prioritize for weight layout optimization (WLO).

  • compiler_workdir: Optional[Union[str, pathlib.Path]] = None — Directory for compiler artifacts.

  • compiler_args: Optional[Union[str, Dict[str, str]]] = None — Compiler flags as string or dictionary mapping tags to flags.

  • max_workers: Optional[int] = None — Maximum worker threads for parallel compilation. If None, uses the default value from ThreadPoolExecutor.

Returns#

A built and configured NxDModel instance.
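
Because compiler_args also accepts a dictionary, compiler flags can be set per traced bucket by tag. A sketch (the tags, flags, and priority key are illustrative):

import torch
from neuronx_distributed import ModelBuilder

def func(a, b):
    return a + b

nxd_model = ModelBuilder(func) \
    .trace(kwargs={'a': torch.rand(2, 2), 'b': torch.rand(2, 2)}, tag="bucket1") \
    .trace(kwargs={'a': torch.rand(4, 4), 'b': torch.rand(4, 4)}, tag="bucket2") \
    .compile(
        priority_model_key="bucket1",  # bucket to prioritize for WLO
        compiler_args={                # flags applied per tag
            "bucket1": "--auto-cast=none --model-type=transformer -O2",
            "bucket2": "--auto-cast=none --model-type=transformer -O1",
        },
    )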

neuronx_distributed.trace.nxd_model.base_nxd_model.StateInitializer#

class StateInitializer(torch.nn.Module):
    def __init__(
        self,
        shapes: Dict[str, List[int]],
        dtypes: Dict[str, torch.dtype],
        local_ranks_size: int
    ):

A TorchScript-compatible module to initialize state buffers onto Neuron.

Constructor Parameters#

  • shapes: Dict[str, List[int]] — Dict of shape lists associated with a specific stateful tensor by key

  • dtypes: Dict[str, torch.dtype] — Dict of torch dtypes associated with a specific stateful tensor by key

  • local_ranks_size: int — The number of ranks per instance in a distributed setting. Unless you are running a multi-instance data parallel setup, this is typically equal to the world_size the model was compiled for.
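
A sketch of constructing a StateInitializer for a single cache-like state tensor (the name, shape, and dtype are illustrative); the instance can then be passed to the NxDModel constructor described below:

import torch
from neuronx_distributed.trace.nxd_model.base_nxd_model import StateInitializer

# One stateful tensor named "cache", initialized on each of 2 local ranks.
state_initializer = StateInitializer(
    shapes={"cache": [2, 128, 64]},
    dtypes={"cache": torch.bfloat16},
    local_ranks_size=2,
)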

neuronx_distributed.NxDModel#

class NxDModel(torch.nn.Module, BaseNxDModel):
    def __init__(
        self,
        world_size: int,
        start_rank: Optional[int] = None,
        local_ranks_size: Optional[int] = None,
        state_initializer: Optional[StateInitializer] = None,
        layout_transformer: Optional[LayoutTransformerArtifacts] = None
    )

An executor class that runs models compiled either with ModelBuilder or with the trace() and compile() fundamental units.

Constructor Parameters#

  • world_size: int — Total number of ranks/processes in the distributed setup.

  • start_rank: Optional[int], default=None — Starting rank for this instance. If None, defaults to 0.

  • local_ranks_size: Optional[int], default=None — Number of local ranks. Must be specified if start_rank is provided.

  • state_initializer: Optional[StateInitializer], default=None — Initializer for model states. If not provided, stateful model tensors will be initialized with zeros.

neuronx_distributed.NxDModel.add#

@torch.jit.unused
def add(
    self,
    key: str,
    trace_artifacts: TraceArtifacts,
    compilation_artifacts: Union[CompilationArtifacts, WLOArtifacts],
) -> "NxDModel"

Add a compiled submodel to this NxDModel instance.

Notes:

  • Creates a StateInitializer if state tensors are present in the metaneff and none was provided in the NxDModel constructor.

  • Sets up SPMDModel instances and input/output processing components

Parameters#

  • key: str — Unique identifier for this submodel within the NxDModel

  • trace_artifacts: TraceArtifacts — Artifacts produced from the trace() function

  • compilation_artifacts: Union[CompilationArtifacts, WLOArtifacts] — Artifacts produced from the compile() or compile_wlo() functions

Returns#

NxDModel self reference, enabling builder-style method chaining.

neuronx_distributed.NxDModel.get_neff#

@torch.jit.unused
def get_neff(self, key: str) -> bytes

Retrieves the NEFF (Neuron Executable File Format) from the specified model. Requires the associated model to already be added using the add() method.

Raises:

KeyError: If the specified key is not found in the available keys.

RuntimeError: If there is an error retrieving the NEFF.

Parameters#

  • key: str — The identifier for the model whose NEFF should be retrieved.

Returns#

bytes — The NEFF for the specified model.

neuronx_distributed.NxDModel.get_metaneff#

@torch.jit.unused
def get_metaneff(self, key: str) -> metaneff_pb2.MetaNeff

Retrieves the metaneff from the specified model. Requires the associated model to already be added using the add() method.

Raises:

KeyError: If the specified key is not found in the available keys.

RuntimeError: If there is an error retrieving the metaneff.

Parameters#

  • key: str — The identifier for the model whose metaneff should be retrieved.

Returns#

metaneff_pb2.MetaNeff — The metaneff proto object for the specified model.

neuronx_distributed.NxDModel.get_hlo#

@torch.jit.unused
def get_hlo(self, key: str) -> hlo_pb2.HloModuleProto

Retrieves the HLO from the specified model. Requires the associated model to already be added using the add() method.

Raises:

KeyError: If the specified key is not found in the available keys.

RuntimeError: If there is an error retrieving the HLO.

Parameters#

  • key: str — The identifier for the model whose HLO should be retrieved.

Returns#

hlo_pb2.HloModuleProto — The HLO module proto object for the specified model.
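
A sketch of retrieving artifacts from a submodel previously registered with add() (the key "bucket1" is illustrative):

# nxd_model is an NxDModel with a submodel added under key "bucket1".
neff_bytes = nxd_model.get_neff("bucket1")    # raw NEFF bytes
metaneff = nxd_model.get_metaneff("bucket1")  # metaneff_pb2.MetaNeff proto
hlo_module = nxd_model.get_hlo("bucket1")     # hlo_pb2.HloModuleProto proto

# The NEFF can be written out for inspection with Neuron tooling.
with open("graph.neff", "wb") as f:
    f.write(neff_bytes)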

neuronx_distributed.NxDModel.set_weights#

@torch.jit.export
def set_weights(
    self,
    sharded_checkpoint: List[Dict[str, torch.Tensor]]
)

Set the model’s weights from a sharded checkpoint.

This function initializes the model’s weights using a sharded checkpoint. The checkpoint is processed and loaded using either a layout transformer (if provided) or a direct parallel loading mechanism.

This function should only be called before the model is loaded onto a Neuron device. Once the model is loaded, use the replace_weights() method to update the weights.

Raises:

ValueError: If the model is already loaded on a Neuron device.

Parameters#

  • sharded_checkpoint: List[Dict[str, torch.Tensor]] — A list of state dicts mapping parameter names to their corresponding tensor values for each rank.

Returns#

None

neuronx_distributed.NxDModel.to_neuron#

@torch.jit.export
def to_neuron(self)

Loads the model onto Neuron Devices.

This function initializes the model onto Neuron hardware. It must be called before executing the model; otherwise, the forward method raises a RuntimeError.

Returns#

None

neuronx_distributed.NxDModel.replace_weights#

@torch.jit.export
def replace_weights(
    self,
    sharded_checkpoint: List[Dict[str, torch.Tensor]]
)

Replace the model’s weights and reload onto Neuron devices.

This method should be used instead of set_weights() when the model is already loaded on Neuron devices and weights need to be updated.

Parameters#

  • sharded_checkpoint: List[Dict[str, torch.Tensor]] — A list of state dicts mapping parameter names to their corresponding tensor values for each rank.

Returns#

None
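
A sketch of swapping weights on a model that is already on device (Model and new_checkpoint are assumed; the sharding mirrors the usage examples below):

import torch
from neuronx_distributed import NxDParallelState, shard_checkpoint
from neuronx_distributed.utils.model_utils import init_on_device

# nxd_model is already loaded via to_neuron(), so set_weights() would raise ValueError here.
with NxDParallelState(world_size=2, tensor_model_parallel_size=2), init_on_device(torch.device("meta")):
    new_sharded_checkpoint = shard_checkpoint(checkpoint=new_checkpoint, model=Model())

nxd_model.replace_weights(new_sharded_checkpoint)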

neuronx_distributed.NxDModel.read_from_neuron_buffer#

@torch.jit.export
def read_from_neuron_buffer(
    self,
    buffer_key: str,
    rank: int
) -> torch.Tensor

Reads a tensor value from a Neuron device buffer to CPU, based on the given key and rank.

Raises:

AssertionError: If this method is called before to_neuron().

KeyError: If the specified buffer_key does not exist in the states for the given rank.

Parameters#

  • buffer_key: str — The key identifying the specific buffer to retrieve.

  • rank: int — The rank from which to retrieve the buffer.

Returns#

torch.Tensor — The requested tensor buffer copied to host memory.

neuronx_distributed.NxDModel.write_to_neuron_buffer#

@torch.jit.export
def write_to_neuron_buffer(
    self,
    tensor: torch.Tensor,
    buffer_key: str,
    rank: int
)

Write a tensor to a specific Neuron device buffer.

This function updates a state buffer on a Neuron device by copying values from the provided tensor. The destination buffer must already exist and have the same shape as the input tensor.

Raises:

AssertionError: If this method is called before to_neuron().

KeyError: If the specified buffer_key does not exist in the states for the given rank, or if the shapes of the input tensor and target buffer do not match.

Parameters#

  • tensor: torch.Tensor — The tensor containing the data to be written to the buffer.

  • buffer_key: str — The key identifying the specific buffer to update.

  • rank: int — The rank where the buffer is located.

Returns#

None
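
A sketch of inspecting and patching a state buffer after to_neuron(); the buffer key "cache" matches the in-place update examples below, and rank 0 is illustrative:

import torch

# Copy the current value of the "cache" buffer on rank 0 back to host memory.
cache_cpu = nxd_model.read_from_neuron_buffer(buffer_key="cache", rank=0)

# Overwrite the device buffer with zeros of the same shape and dtype.
nxd_model.write_to_neuron_buffer(
    tensor=torch.zeros_like(cache_cpu),
    buffer_key="cache",
    rank=0,
)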

neuronx_distributed.NxDModel.forward#

def forward(
    self,
    *args,
    model_name: Optional[str] = None,
    forward_mode='default',
    **kwargs
):

The forward method of the NxDModel class, which takes in inputs and runs the corresponding NEFF.

Raises: AssertionError, RuntimeError, KeyError

Parameters#

  • args: Union[torch.Tensor, List[torch.Tensor]] — Positional tensor inputs to the model. List form must be used if forward_mode != 'default'.

  • model_name: Optional[str] — Key of the specific submodel to execute. This must be used in cases of ambiguous routing.

  • forward_mode: str, default='default' — There are three supported modes: default, ranked, and async.

    • default: This takes in inputs, replicates them across ranks, executes the model, and only returns the outputs from rank 0

    • ranked: This takes in inputs in ranked form, meaning each individual tensor input (i.e., each arg in *args) must be a list of tensors whose length is equal to the world size of the compiled model. The model executes and returns a ranked output, which is a list of all outputs by rank (i.e., a List[List[torch.Tensor]]).

    • async: Like ranked, this takes in inputs and returns outputs in ranked form; the difference is that outputs are returned immediately as references to buffers that the model writes to once the NEFF finishes executing. To block on the NEFF call, call .cpu() on each tensor in the output.

  • kwargs: Union[torch.Tensor, List[torch.Tensor]] — Keyword arguments corresponding to specific input tensors to the model. List form must be used if forward_mode != 'default'.

Returns#

The return type depends on the forward_mode setting:

  • default: Tensor outputs in the format of what was originally traced.

  • ranked or async: List[List[torch.Tensor]] of shape (num_out_tensors, world_size).
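
A sketch of the three modes for a model compiled with world_size=2 (shapes are illustrative; nxd_model is assumed to be loaded via to_neuron()):

import torch

x = torch.rand(2, 2)

# default: the input is replicated across ranks; only rank 0's outputs are returned.
out = nxd_model(x)

# ranked: each input becomes a list with one tensor per rank (length == world size);
# the result is a List[List[torch.Tensor]] of shape (num_out_tensors, world_size).
ranked_out = nxd_model([x, x], forward_mode='ranked')

# async: same ranked form, but returns immediately with references to output buffers;
# calling .cpu() on an output tensor blocks until the NEFF finishes executing.
async_out = nxd_model([x, x], forward_mode='async')
results = [[t.cpu() for t in tensor_by_rank] for tensor_by_rank in async_out]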

neuronx_distributed.NxDModel.save#

def save(self, path_to_save: str, save_weights: bool = False)

Saves the model as a TorchScript module to the specified path. The saved artifact can be loaded with NxDModel.load or torch.jit.load (NxDModel.load is preferable).

Parameters#

  • path_to_save: str — The file path where the TorchScript model should be saved.

  • save_weights: Optional[bool], default=False — If True, preserves the weights within the TorchScript model.

Returns#

None

neuronx_distributed.NxDModel.load#

@classmethod
def load(
    cls,
    path_to_model: str,
    start_rank: Optional[int] = None,
    local_ranks_size: Optional[int] = None
) -> Union["NxDModel", torch.jit.ScriptModule]

Attempts to load and restore an NxDModel from a saved TorchScript model.

This classmethod tries to reconstruct an NxDModel instance from a previously saved TorchScript model. If the restoration process fails, it returns the loaded TorchScript model instead, as backwards compatibility is not guaranteed across different versions of NxD.

Raises:

ValueError: If the provided model was not originally saved using NxDModel.save().

AssertionError: If start_rank/local_ranks_size parameters are inconsistently set.

Parameters#

  • path_to_model: str — Path to the saved TorchScript model file.

  • start_rank: Optional[int], default=None — Starting rank for distributed processing. If None while local_ranks_size is set, an AssertionError is raised.

  • local_ranks_size: Optional[int], default=None — Number of local ranks for distributed processing. Must be set if start_rank is provided.

Returns#

Union[NxDModel, torch.jit.ScriptModule]: Either the restored NxDModel instance, or the loaded TorchScript model if restoration fails.
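
A sketch of the save/load round trip (the path is illustrative):

from neuronx_distributed import NxDModel

# Persist the model as TorchScript, embedding the weights.
nxd_model.save("nxd_model.pt", save_weights=True)

# Restore in a fresh process; this may fall back to a plain torch.jit.ScriptModule
# if restoration is not possible for this version of NxD.
restored = NxDModel.load("nxd_model.pt")
restored.to_neuron()  # load onto Neuron devices before executing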

Usage Notes#

In-place buffer updates#

Description#

ModelBuilderV2 enables users to update model buffers in-place during their model’s forward pass. In-place updates enable users to efficiently utilize memory when caching values during the forward pass. An example use case for in-place updates is the population of a model’s KV Cache.

Under the hood, ModelBuilderV2 detects when buffers are mutated during forward while tracing a model, and uses XLA’s aliasing to ensure that buffers are mutated in-place.

Supported Usage#

In-place updates are currently supported for the following combinations of torch.Tensor subclasses and torch operations:

Tensor class                       | Out-of-place torch operation | In-place torch operation
-----------------------------------|------------------------------|-------------------------
torch.nn.Buffer, persistent=True   | Supported                    | Not Supported
torch.nn.Buffer, persistent=False  | Supported                    | Not Supported
torch.nn.Parameter                 | Not Supported                | Not Supported

Additionally, the following forms of updates are not supported, because these mutations change the memory utilization or memory layout of the mutated tensor:

  • Updating the dtype of a buffer or parameter during forward.

  • Updating the shape of a buffer or parameter during forward.

Example#

import torch
import torch.nn as nn

class ExampleModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.register_buffer("buffer_persistent", torch.zeros(10), dtype=torch.bfloat16, persistent=True)
        self.register_buffer("buffer_nonpersistent", torch.zeros(10), dtype=torch.bfloat16, persistent=False)
        self.parameter = nn.Parameter(torch.zeros(10), dtype=torch.bfloat16)

    def forward(self, x, dim_tensor, index, src):
        # supported: buffers with out of place torch operations
        self.buffer_persistent = self.buffer_persistent + 1
        self.buffer_nonpersistent = torch.scatter(self.buffer_nonpersistent, dim_tensor, index, src)

        # not supported: buffers with inplace torch operations
        self.buffer_persistent.scatter_(dim_tensor, index, src)
        self.buffer_nonpersistent.index_copy_(dim_tensor, index, src)

        # not supported: parameters
        self.parameter = torch.scatter(self.parameter, dim_tensor, index, src)
        self.parameter.scatter_(dim_tensor, index, src)

        # not supported: dtype updates
        self.buffer_persistent = self.buffer_persistent.to(torch.float32)

        # not supported: shape changes
        self.buffer_persistent = torch.reshape(self.buffer_persistent, (2, 5))

Usage Examples#

E2E with ModelBuilder APIs#

With Callable

import torch
from neuronx_distributed import ModelBuilder

torch.manual_seed(0)

def func(a, b):
    return a + b

nxd_model = ModelBuilder(func) \
    .trace(kwargs={'a': torch.rand(2,2), 'b': torch.rand(2,2)}, tag="key1") \
    .compile()

nxd_model.to_neuron()
input = (torch.rand(2, 2), torch.rand(2, 2))
cpu_out = func(a=input[0], b=input[1])
neuron_out = nxd_model(a=input[0], b=input[1])

torch.testing.assert_close(cpu_out, neuron_out)

With ``torch`` module

import torch
import torch.nn as nn
from neuronx_distributed.utils.model_utils import init_on_device
from neuronx_distributed import NxDParallelState, shard_checkpoint, ModelBuilder
from neuronx_distributed.parallel_layers import ColumnParallelLinear, RowParallelLinear

torch.manual_seed(0)

class Model(nn.Module):
    def __init__(self, is_distributed=True):
        super().__init__()
        if is_distributed:
            self.layer1 = ColumnParallelLinear(1024, 1024, gather_output=False)
            self.layer2 = RowParallelLinear(1024, 1024, input_is_parallel=True)
        else:
            self.layer1 = nn.Linear(1024, 1024)
            self.layer2 = nn.Linear(1024, 1024)
    def forward(self, x):
        x = self.layer1(x)
        return self.layer2(x)

cpu_model = Model(is_distributed=False)
model_checkpoint = cpu_model.state_dict()

with NxDParallelState(world_size=32, tensor_model_parallel_size=32):
    model = Model()

    example_inputs = torch.rand(32, 1024)

    nxd_model = ModelBuilder(model) \
        .trace(args=example_inputs, tag="key1") \
        .compile()

with NxDParallelState(world_size=32, tensor_model_parallel_size=32), init_on_device(torch.device("meta")):
    sharded_checkpoint = shard_checkpoint(
        checkpoint=model_checkpoint,
        model=Model()
    )

nxd_model.set_weights(sharded_checkpoint)
nxd_model.to_neuron()

input = torch.ones(32, 1024)
cpu_out = cpu_model(input)
neuron_out = nxd_model(x=input)

torch.testing.assert_close(cpu_out, neuron_out)

Multi-bucket trace

import torch
import torch.nn as nn
from neuronx_distributed.utils.model_utils import init_on_device
from neuronx_distributed import NxDParallelState, shard_checkpoint, ModelBuilder
from neuronx_distributed.parallel_layers import ColumnParallelLinear

torch.manual_seed(0)

class Model(nn.Module):
    def __init__(self, is_distributed=True):
        super().__init__()
        if is_distributed:
            self.layer1 = ColumnParallelLinear(1024, 1024, gather_output=True)
            self.layer2 = ColumnParallelLinear(1024, 1024, gather_output=True)
        else:
            self.layer1 = nn.Linear(1024, 1024)
            self.layer2 = nn.Linear(1024, 1024)
    def forward(self, x):
        x = self.layer1(x)
        return self.layer2(x)

cpu_model = Model(is_distributed=False)
model_checkpoint = cpu_model.state_dict()

with NxDParallelState(world_size=32, tensor_model_parallel_size=32):
    model = Model()

    example_inputs1 = torch.rand(32, 1024)
    example_inputs2 = torch.rand(16, 1024)

    nxd_model = ModelBuilder(model) \
        .trace(args=example_inputs1, tag="bucket1") \
        .trace(args=example_inputs2, tag="bucket2") \
        .compile()


with NxDParallelState(world_size=32, tensor_model_parallel_size=32), init_on_device(torch.device("meta")):
    sharded_checkpoint = shard_checkpoint(
        checkpoint=model_checkpoint,
        model=Model()
    )

nxd_model.set_weights(sharded_checkpoint)
nxd_model.to_neuron()

input1 = torch.rand(32, 1024)
input2 = torch.rand(16, 1024)

for input in [input1, input2]:
    cpu_out = cpu_model(input)
    neuron_out = nxd_model(input)
    torch.testing.assert_close(cpu_out, neuron_out)

Example inputs supplied as kwargs

import torch
import torch.nn as nn
from neuronx_distributed.utils.model_utils import init_on_device
from neuronx_distributed import NxDParallelState, shard_checkpoint, ModelBuilder
from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear

torch.manual_seed(0)

class Model(nn.Module):
    def __init__(self, is_distributed=True):
        super().__init__()
        if is_distributed:
            self.layer1 = ColumnParallelLinear(5, 10, gather_output=True)
            self.layer2 = ColumnParallelLinear(20, 10, gather_output=True)
        else:
            self.layer1 = nn.Linear(5, 10)
            self.layer2 = nn.Linear(20, 10)

    def forward(self, x, y):
        return self.layer1(x) + self.layer2(y)

cpu_model = Model(is_distributed=False)
model_checkpoint = cpu_model.state_dict()

with NxDParallelState(world_size=2, tensor_model_parallel_size=2):
    model = Model()

    example_inputs1 = {'x': torch.rand(10, 5), 'y': torch.rand(10, 20)}
    example_inputs2 = {'x': torch.rand(50, 5), 'y': torch.rand(50, 20)}

    nxd_model = ModelBuilder(model) \
        .trace(kwargs=example_inputs1, tag="bucket1") \
        .trace(kwargs=example_inputs2, tag="bucket2") \
        .compile()


with NxDParallelState(world_size=2, tensor_model_parallel_size=2), init_on_device(torch.device("meta")):
    sharded_checkpoint = shard_checkpoint(
        checkpoint=model_checkpoint,
        model=Model()
    )

nxd_model.set_weights(sharded_checkpoint)
nxd_model.to_neuron()

input1 = (torch.rand(10, 5), torch.rand(10, 20))
input2 = (torch.rand(50, 5), torch.rand(50, 20))

for input in [input1, input2]:
    cpu_out = cpu_model(input[0], input[1])
    neuron_out = nxd_model(x=input[0], y=input[1])
    torch.testing.assert_close(cpu_out, neuron_out)

With in-place buffer updates

import torch
from neuronx_distributed import ModelBuilder

torch.manual_seed(0)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('cache', torch.tensor([0], dtype=torch.float32), persistent=True)

    def forward(self, x, update_value):
        self.cache = torch.add(self.cache, update_value)
        return x + self.cache

cpu_model = Model()

model = Model()

example_inputs1 = {'x': torch.zeros(1, dtype=torch.float32), 'update_value': torch.zeros(1, dtype=torch.float32)}

nxd_model = ModelBuilder(model) \
    .trace(kwargs=example_inputs1, tag="bucket1") \
    .compile()

state_dict = [
    {
        "cache": torch.tensor([0])
    }
]
nxd_model.set_weights(state_dict)
nxd_model.to_neuron()

input1 = (torch.tensor([1], dtype=torch.float32), torch.tensor([5], dtype=torch.float32))
input2 = (torch.tensor([2], dtype=torch.float32), torch.tensor([10], dtype=torch.float32))

model_iteration = 0
for input in [input1, input2]:
    cpu_out = cpu_model(input[0], input[1])
    neuron_out = nxd_model(x=input[0], update_value=input[1])

    torch.testing.assert_close(cpu_out, neuron_out)
    model_iteration += 1
    print(f"Iteration {model_iteration} matches!")

E2E with Fundamental Units#

With Callable

import torch
from neuronx_distributed import NxDModel
from neuronx_distributed.trace.model_builder import trace, compile

torch.manual_seed(0)

def func(a, b):
    return a + b

trace_artifacts = trace(func, kwargs={'a': torch.rand(2,2), 'b': torch.rand(2,2)})
compilation_artifacts = compile(trace_artifacts.hlo, trace_artifacts.metaneff)

nxd_model = NxDModel(world_size=1)
nxd_model.add('func', trace_artifacts, compilation_artifacts)
nxd_model.to_neuron()

cpu_out = func(torch.ones(2, 2), torch.ones(2, 2))
neuron_out = nxd_model(torch.ones(2,2), torch.ones(2,2))
torch.testing.assert_close(cpu_out, neuron_out)

With ``torch`` module

import torch
import torch.nn as nn

from neuronx_distributed.utils.model_utils import init_on_device
from neuronx_distributed import NxDParallelState, shard_checkpoint, NxDModel
from neuronx_distributed.parallel_layers import ColumnParallelLinear, RowParallelLinear
from neuronx_distributed.trace.model_builder import (
    trace,
    compile,
)

torch.manual_seed(0)

class Model(nn.Module):
    def __init__(self, is_distributed=True):
        super().__init__()
        if is_distributed:
            self.layer1 = ColumnParallelLinear(1024, 1024, gather_output=False)
            self.layer2 = RowParallelLinear(1024, 1024, input_is_parallel=True)
        else:
            self.layer1 = nn.Linear(1024, 1024)
            self.layer2 = nn.Linear(1024, 1024)
    def forward(self, x):
        x = self.layer1(x)
        return self.layer2(x)

cpu_model = Model(is_distributed=False)
model_checkpoint = cpu_model.state_dict()

with NxDParallelState(world_size=32, tensor_model_parallel_size=32):
    model = Model()

    example_inputs = torch.rand(32, 1024)

    trace_artifacts = {
        "bucket1": trace(model, args=example_inputs),
    }

    compilation_artifacts_priority = compile(
        hlo_module=trace_artifacts["bucket1"].hlo,
        metaneff=trace_artifacts["bucket1"].metaneff,
        key="bucket1"
    )

with NxDParallelState(world_size=32, tensor_model_parallel_size=32), init_on_device(torch.device("meta")):
    sharded_checkpoint = shard_checkpoint(
        checkpoint=model_checkpoint,
        model=Model()
    )

nxd_model = NxDModel(world_size=32)
nxd_model.add(key="bucket1", trace_artifacts=trace_artifacts["bucket1"], compilation_artifacts=compilation_artifacts_priority)

nxd_model.set_weights(sharded_checkpoint)
nxd_model.to_neuron()

input = torch.rand(32, 1024)

cpu_out = cpu_model(input)
neuron_out = nxd_model(input)
torch.testing.assert_close(cpu_out, neuron_out)

Multi-bucket trace

import torch
import torch.nn as nn

from neuronx_distributed.utils.model_utils import init_on_device
from neuronx_distributed import NxDParallelState, shard_checkpoint, NxDModel
from neuronx_distributed.parallel_layers import ColumnParallelLinear, RowParallelLinear
from neuronx_distributed.trace.model_builder import (
    trace,
    compile,
)

torch.manual_seed(0)

class Model(nn.Module):
    def __init__(self, is_distributed=True):
        super().__init__()
        if is_distributed:
            self.layer1 = ColumnParallelLinear(1024, 1024, gather_output=False)
            self.layer2 = RowParallelLinear(1024, 1024, input_is_parallel=True)
        else:
            self.layer1 = nn.Linear(1024, 1024)
            self.layer2 = nn.Linear(1024, 1024)
    def forward(self, x):
        x = self.layer1(x)
        return self.layer2(x)

cpu_model = Model(is_distributed=False)
model_checkpoint = cpu_model.state_dict()

with NxDParallelState(world_size=32, tensor_model_parallel_size=32):
    model = Model()

    example_inputs1 = torch.rand(32, 1024)
    example_inputs2 = torch.rand(16, 1024)

    trace_artifacts = {
        "bucket1": trace(model, args=example_inputs1),
        "bucket2": trace(model, args=example_inputs2),
    }

    compilation_artifacts_bucket1 = compile(
        hlo_module=trace_artifacts["bucket1"].hlo,
        metaneff=trace_artifacts["bucket1"].metaneff,
        key="bucket1"
    )
    compilation_artifacts_bucket2 = compile(
        hlo_module=trace_artifacts["bucket2"].hlo,
        metaneff=trace_artifacts["bucket2"].metaneff,
        key="bucket2"
    )

with NxDParallelState(world_size=32, tensor_model_parallel_size=32), init_on_device(torch.device("meta")):
    sharded_checkpoint = shard_checkpoint(
        checkpoint=model_checkpoint,
        model=Model()
    )

nxd_model = NxDModel(world_size=32)
nxd_model.add(key="bucket1", trace_artifacts=trace_artifacts["bucket1"], compilation_artifacts=compilation_artifacts_bucket1)
nxd_model.add(key="bucket2", trace_artifacts=trace_artifacts["bucket2"], compilation_artifacts=compilation_artifacts_bucket2)

nxd_model.set_weights(sharded_checkpoint)
nxd_model.to_neuron()

input1 = torch.rand(32, 1024)
input2 = torch.rand(16, 1024)

for input in [input1, input2]:
    cpu_out = cpu_model(input)
    neuron_out = nxd_model(input)
    torch.testing.assert_close(cpu_out, neuron_out)

Example inputs supplied as kwargs

import torch
import torch.nn as nn

from neuronx_distributed.utils.model_utils import init_on_device
from neuronx_distributed import NxDParallelState, shard_checkpoint, NxDModel
from neuronx_distributed.parallel_layers import ColumnParallelLinear
from neuronx_distributed.trace.model_builder import (
    trace,
    compile,
)

torch.manual_seed(0)

class Model(nn.Module):
    def __init__(self, is_distributed=True):
        super().__init__()
        if is_distributed:
            self.linear1 = ColumnParallelLinear(5, 10, gather_output=True)
            self.linear2 = ColumnParallelLinear(20, 10, gather_output=True)
        else:
            self.linear1 = nn.Linear(5, 10)
            self.linear2 = nn.Linear(20, 10)

    def forward(self, x, y):
        return self.linear1(x) + self.linear2(y)

cpu_model = Model(is_distributed=False)
model_checkpoint = cpu_model.state_dict()

with NxDParallelState(world_size=2, tensor_model_parallel_size=2):
    model = Model()

    example_inputs1 = {'x': torch.rand(10, 5), 'y': torch.rand(10, 20)}
    example_inputs2 = {'x': torch.rand(50, 5), 'y': torch.rand(50, 20)}

    trace_artifacts = {
        "bucket1": trace(model, kwargs=example_inputs1),
        "bucket2": trace(model, kwargs=example_inputs2),
    }

    compilation_artifacts_bucket1 = compile(
        hlo_module=trace_artifacts["bucket1"].hlo,
        metaneff=trace_artifacts["bucket1"].metaneff,
        key="bucket1"
    )
    compilation_artifacts_bucket2 = compile(
        hlo_module=trace_artifacts["bucket2"].hlo,
        metaneff=trace_artifacts["bucket2"].metaneff,
        key="bucket2"
    )

with NxDParallelState(world_size=2, tensor_model_parallel_size=2), init_on_device(torch.device("meta")):
    sharded_checkpoint = shard_checkpoint(
        checkpoint=model_checkpoint,
        model=Model()
    )

nxd_model = NxDModel(world_size=2)
nxd_model.add(key="bucket1", trace_artifacts=trace_artifacts["bucket1"], compilation_artifacts=compilation_artifacts_bucket1)
nxd_model.add(key="bucket2", trace_artifacts=trace_artifacts["bucket2"], compilation_artifacts=compilation_artifacts_bucket2)

nxd_model.set_weights(sharded_checkpoint)
nxd_model.to_neuron()

input1 = (torch.rand(10, 5), torch.rand(10, 20))
input2 = (torch.rand(50, 5), torch.rand(50, 20))

for input in [input1, input2]:
    cpu_out = cpu_model(input[0], input[1])
    neuron_out = nxd_model(x=input[0], y=input[1])
    torch.testing.assert_close(cpu_out, neuron_out)

With in-place buffer updates

import torch

from neuronx_distributed import NxDModel
from neuronx_distributed.trace.model_builder import (
    trace,
    compile,
)

torch.manual_seed(0)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('cache', torch.tensor([0], dtype=torch.float32), persistent=True)

    def forward(self, x, update_value):
        self.cache = torch.add(self.cache, update_value)
        return x + self.cache

cpu_model = Model()

model = Model()

example_inputs1 = {'x': torch.zeros(1, dtype=torch.float32), 'update_value': torch.zeros(1, dtype=torch.float32)}

trace_artifacts = {
    "bucket1": trace(model, kwargs=example_inputs1),
}

compilation_artifacts_bucket1 = compile(
    hlo_module=trace_artifacts["bucket1"].hlo,
    metaneff=trace_artifacts["bucket1"].metaneff,
    key="bucket1"
)


nxd_model = NxDModel(world_size=1)
nxd_model.add(key="bucket1", trace_artifacts=trace_artifacts["bucket1"], compilation_artifacts=compilation_artifacts_bucket1)

state_dict = [
    {
        "cache": torch.tensor([0], dtype=torch.float32)
    }
]
nxd_model.set_weights(state_dict)
nxd_model.to_neuron()

input1 = (torch.tensor([1], dtype=torch.float32), torch.tensor([5], dtype=torch.float32))
input2 = (torch.tensor([2], dtype=torch.float32), torch.tensor([10], dtype=torch.float32))

model_iteration = 0
for input in [input1, input2]:
    cpu_out = cpu_model(input[0], input[1])
    neuron_out = nxd_model(x=input[0], update_value=input[1])

    torch.testing.assert_close(cpu_out, neuron_out)
    model_iteration += 1
    print(f"Iteration {model_iteration} matches!")

This document is relevant for: Inf2, Trn1, Trn2