This document is relevant for: Inf2, Trn1, Trn1n
Inference APIs#
Model Trace:#
We can use the tensor parallel layers to perform large model inference as well. For inference, we can reuse the parallel model built above for training and then use the trace APIs provided by the neuronx_distributed package to trace it for inference. The following APIs can be used to run distributed inference:
def neuronx_distributed.trace.parallel_model_trace(func, example_inputs, compiler_workdir=None, compiler_args=None, inline_weights_to_neff=True, bucket_config=None, tp_degree=1, max_parallel_compilations=None)
This API launches tensor parallel workers, each of which traces its own model. The traced models are wrapped in a single TensorParallelModel module, which can then be used like any other traced model. A usage sketch follows the parameter list below.
Parameters:
func: Callable
: A function that returns a Model object and a dictionary of states. The parallel_model_trace API calls this function inside each worker and runs trace against the returned model. Note: this differs from torch_neuronx.trace, which requires a model object to be passed directly.
example_inputs: (torch.Tensor like)
: The inputs that need to be passed to the model. If you are using bucket_config, then this must be a list of inputs for each bucket model. This configuration is similar to torch_neuronx.bucket_model_trace().
compiler_workdir: Optional[str, pathlib.Path]
: Work directory used by neuronx-cc. This can be useful for debugging and inspecting intermediary neuronx-cc outputs.
compiler_args: Optional[Union[List[str], str]]
: List of strings representing neuronx-cc compiler arguments. See the Neuron Compiler CLI Reference Guide (neuronx-cc) for more information about compiler options.
inline_weights_to_neff: bool
: A boolean indicating whether the weights should be inlined into the NEFF. If set to False, weights are kept separate from the NEFF. The default is True.
bucket_config: torch_neuronx.BucketModelConfig
: The config object that defines bucket selection behavior. See torch_neuronx.BucketModelConfig() for more details.
tp_degree: (int)
: The number of devices to use for tensor parallel sharding.
max_parallel_compilations: Optional[int]
: If specified, this function traces only this number of models in parallel, which can be necessary to prevent out-of-memory errors while tracing. The default is None, which means the number of parallel compilations equals tp_degree.
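For illustration, the following is a minimal sketch of tracing a tensor parallel model for inference. MyParallelModel is a hypothetical module built with the neuronx_distributed tensor parallel layers, and the input shape and tp_degree are illustrative:

    import torch
    from neuronx_distributed.trace import parallel_model_trace

    def get_model():
        # Called inside each tensor parallel worker; each worker builds and
        # traces its own copy of the model. MyParallelModel is a placeholder
        # for a module built with the tensor parallel layers.
        model = MyParallelModel()
        model.eval()
        # The second return value is a dictionary of states; an empty dict
        # is fine if there is nothing to pass back.
        return model, {}

    example_inputs = torch.rand(1, 128)

    traced_model = parallel_model_trace(
        get_model,
        example_inputs,
        tp_degree=2,
    )
    output = traced_model(example_inputs)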
Trace Model Save/Load:#
Save:#
def neuronx_distributed.trace.parallel_model_save(model, save_dir)
This API saves the traced model in save_dir. Each shard is saved in its own directory inside save_dir. A usage sketch follows the parameter list below.
Parameters:
model: (TensorParallelModel)
: Traced model produced using the parallel_model_trace API.
save_dir: (str)
: The directory where the model will be saved.
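A minimal usage sketch, continuing with the hypothetical traced_model from the tracing example above:

    from neuronx_distributed.trace import parallel_model_save

    # Each tensor parallel shard is written to its own subdirectory of tp_2_model.
    parallel_model_save(traced_model, "tp_2_model")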
Load:#
def neuronx_distributed.trace.parallel_model_load(load_dir)
This API loads the sharded traced model into a TensorParallelModel for inference. A short example follows the parameter list.
Parameters:
load_dir: (str)
: Directory which contains the traced model.
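A minimal usage sketch, assuming the model was saved to the hypothetical tp_2_model directory from the save example above:

    from neuronx_distributed.trace import parallel_model_load

    # Reload the sharded traced model; the returned TensorParallelModel is callable.
    model = parallel_model_load("tp_2_model")
    output = model(example_inputs)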