This document is relevant for: Inf2, Trn1, Trn2

PyTorch Neuron (torch-neuronx) Weight Replacement API for Inference

torch_neuronx.replace_weights(neuron_model, weights)

Replaces the weights in a Neuron model with split weights. This function emits a warning if the supplied Neuron model does not contain any separated weights.

Warning

The API below applies only to models traced with inline_weights_to_neff=False. This parameter is True by default, so it must be set explicitly when tracing. See torch_neuronx.trace() for details.

Parameters:
  • neuron_model (RecursiveScriptModule) – A Neuron model compiled with split weights

  • weights (Module or Dict[str, Tensor]) – Either the original (untraced) model holding the new weights, or the state_dict of such a model.

Returns:

None. This function performs the weight replacement in place on neuron_model.

Return type:

None

Examples

Using a model

import torch
import torch_neuronx


class Network(torch.nn.Module):
    """Stack of square linear layers applied one after another.

    Args:
        hidden_size: Input and output width of every linear layer.
        layers: Number of linear layers in the stack.
    """

    def __init__(self, hidden_size=4, layers=3) -> None:
        super().__init__()
        # Build the layers first, then wrap them so they run in order.
        linear_stack = [
            torch.nn.Linear(hidden_size, hidden_size) for _ in range(layers)
        ]
        self.layers = torch.nn.Sequential(*linear_stack)

    def forward(self, tensor):
        """Pass ``tensor`` through every layer and return the result."""
        output = self.layers(tensor)
        return output


# Create two independently initialized networks, both in inference mode.
network = Network().eval()
network2 = Network().eval()

# Sample input: a batch of 2 vectors of width 4 (the default hidden size).
inp = torch.rand(2, 4)

# Trace the first network with its weights kept separate, so they can be
# replaced afterwards.
weight_separated_trace = torch_neuronx.trace(
    network,
    inp,
    inline_weights_to_neff=False,
)

# Swap in the weights of the second network.
new_weights = network2.state_dict()
torch_neuronx.replace_weights(weight_separated_trace, new_weights)

# Run the same input through the CPU network and the Neuron model.
out_network2 = network2(inp)
out_neuron = weight_separated_trace(inp)

# Print both outputs so they can be compared.
print(out_network2, out_neuron)

Using safetensors

The safetensors library is useful for storing/loading model tensors safely and quickly.

import torch
import torch_neuronx

from safetensors import safe_open
from safetensors.torch import save_model


class Network(torch.nn.Module):
    """Stack of square linear layers applied one after another.

    Args:
        hidden_size: Input and output width of every linear layer.
        layers: Number of linear layers in the stack.
    """

    def __init__(self, hidden_size=4, layers=3) -> None:
        super().__init__()
        # Build the layers first, then wrap them so they run in order.
        linear_stack = [
            torch.nn.Linear(hidden_size, hidden_size) for _ in range(layers)
        ]
        self.layers = torch.nn.Sequential(*linear_stack)

    def forward(self, tensor):
        """Pass ``tensor`` through every layer and return the result."""
        output = self.layers(tensor)
        return output


# Initialize two networks and put both in inference mode.
network = Network()
network2 = Network()
network.eval()
network2.eval()

inp = torch.rand(2, 4)

# Trace a weight-separated model with the first network.
weight_separated_trace = torch_neuronx.trace(network, inp, inline_weights_to_neff=False)

# Directory the safetensors file is written to; change this to any
# writable location.
directory = "."

# Save the weights of network2 to a safetensors file.
safetensor_path = f"{directory}/network2.safetensors"
save_model(network2, safetensor_path)

# Load the saved tensors back into a dict keyed by tensor name.
with safe_open(safetensor_path, framework="pt") as f:
    tensors = {k: f.get_tensor(k) for k in f.keys()}

# Replace the traced model's weights with those of the second network.
torch_neuronx.replace_weights(weight_separated_trace, tensors)

# Get outputs from the CPU network and the Neuron model.
out_network2 = network2(inp)
out_neuron = weight_separated_trace(inp)

# Print both outputs so they can be compared.
print(out_network2, out_neuron)

Note

For models that are not stored in the safetensors format, load the model with torch.load and pass its state_dict to torch_neuronx.replace_weights, as in the first example.

This document is relevant for: Inf2, Trn1, Trn2