This document is relevant for: Inf2, Trn1, Trn2
PyTorch Neuron (torch-neuronx) Weight Replacement API for Inference
- torch_neuronx.replace_weights(neuron_model, weights)
Replaces the weights in a Neuron model that was traced with split weights. This function emits a warning if the supplied Neuron model does not contain any separated weights.
Warning
This API is only applicable to models traced with the parameter inline_weights_to_neff=False, which is True by default. See torch_neuronx.trace() for details.
- Parameters:
    neuron_model (RecursiveScriptModule) – A Neuron model compiled with split weights.
    weights (Module, Dict[str, Tensor]) – Either the original model with the new weights, or the state_dict of a model.
- Returns:
    None; the weight replacement is performed in place on neuron_model.
- Return type:
    None
Examples
Using a model
    import torch
    import torch_neuronx

    class Network(torch.nn.Module):
        def __init__(self, hidden_size=4, layers=3) -> None:
            super().__init__()
            self.layers = torch.nn.Sequential(
                *(torch.nn.Linear(hidden_size, hidden_size) for _ in range(layers)))

        def forward(self, tensor):
            return self.layers(tensor)

    # initialize two networks
    network = Network()
    network2 = Network()
    network.eval()
    network2.eval()

    inp = torch.rand(2, 4)

    # trace weight separated model with first network
    weight_separated_trace = torch_neuronx.trace(network, inp, inline_weights_to_neff=False)

    # replace with weights from second network
    torch_neuronx.replace_weights(weight_separated_trace, network2.state_dict())

    # get outputs from neuron and cpu networks
    out_network2 = network2(inp)
    out_neuron = weight_separated_trace(inp)

    # check that they are equal
    print(out_network2, out_neuron)
Using safetensors
The safetensors library is useful for storing/loading model tensors safely and quickly.
    import torch
    import torch_neuronx
    from safetensors import safe_open
    from safetensors.torch import save_model

    class Network(torch.nn.Module):
        def __init__(self, hidden_size=4, layers=3) -> None:
            super().__init__()
            self.layers = torch.nn.Sequential(
                *(torch.nn.Linear(hidden_size, hidden_size) for _ in range(layers)))

        def forward(self, tensor):
            return self.layers(tensor)

    # initialize two networks
    network = Network()
    network2 = Network()
    network.eval()
    network2.eval()

    inp = torch.rand(2, 4)

    # trace weight separated model with first network
    weight_separated_trace = torch_neuronx.trace(network, inp, inline_weights_to_neff=False)

    # directory for the safetensors file (placeholder; use any writable path)
    directory = "."

    # save network2 weights to safetensors
    safetensor_path = f"{directory}/network2.safetensors"
    save_model(network2, safetensor_path)

    # load safetensors from network2 into a dictionary of tensors
    tensors = {}
    with safe_open(safetensor_path, framework="pt") as f:
        for k in f.keys():
            tensors[k] = f.get_tensor(k)

    # replace with weights from second network
    torch_neuronx.replace_weights(weight_separated_trace, tensors)

    # get outputs from neuron and cpu networks
    out_network2 = network2(inp)
    out_neuron = weight_separated_trace(inp)

    # check that they are equal
    print(out_network2, out_neuron)
Note
For models that are not stored as safetensors, use torch.load to load the model, then pass its state_dict to torch_neuronx.replace_weights as in the first example.
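The torch.load path described in the note could look roughly like the sketch below. It assumes the checkpoint was written with torch.save(model.state_dict(), path), so that torch.load returns a plain dictionary of tensors. The file name network2.pt and the helper replace_from_checkpoint are hypothetical names used only for illustration. The torch_neuronx import is deferred into the helper so that the host-side save/load half can be tried on a machine without Neuron devices.

```python
import os
import tempfile

import torch


def replace_from_checkpoint(neuron_model, checkpoint_path):
    """Hypothetical helper: load a state_dict saved with torch.save and
    pass it to torch_neuronx.replace_weights."""
    # Deferred import: torch_neuronx is only installed on Neuron instances.
    import torch_neuronx

    # Assumes the file holds a state_dict, not a pickled module.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    torch_neuronx.replace_weights(neuron_model, state_dict)


# Host-side half only: save a state_dict and read it back with torch.load.
source = torch.nn.Linear(4, 4)
with tempfile.TemporaryDirectory() as tmp:
    checkpoint_path = os.path.join(tmp, "network2.pt")  # hypothetical file name
    torch.save(source.state_dict(), checkpoint_path)
    loaded = torch.load(checkpoint_path, map_location="cpu")

# The loaded object is a dictionary keyed by parameter name.
print(sorted(loaded.keys()))
```

On a Neuron instance, replace_from_checkpoint(weight_separated_trace, checkpoint_path) would then stand in for the replace_weights call in the first example, provided the model was traced with inline_weights_to_neff=False.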