This document is relevant for: Inf2, Trn1, Trn1n
Distributed Strategies APIs#
NeuronX Distributed Core (NxD Core) is an XLA-based library for distributed training and inference on Neuron devices. The library supports 3D parallelism: tensor parallelism, pipeline parallelism, and data parallelism. It also supports the ZeRO-1 optimizer, which shards the optimizer states. To support tensor parallelism on Neuron, we adopted the Apex library built for CUDA devices and modified its implementations to work with XLA. This document lists the APIs and modules provided by the library.
Parallel Model State:#
Initialize Model Parallelism:#
def neuronx_distributed.parallel_state.initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
)
This function initializes model parallelism for distributed training and lets users set the tensor-parallel and pipeline-parallel world sizes.
Parameters:
tensor_model_parallel_size
: The number of tensor-parallel workers. Defaults to 1.
pipeline_model_parallel_size
: The number of pipeline-parallel workers. Defaults to 1.
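The sketch below shows the world-size arithmetic behind these two settings. It assumes that every worker belongs to exactly one tensor-parallel group, one pipeline-parallel group, and one data-parallel group, so that world_size = tp * pp * dp. The helper parallel_sizes is hypothetical and is not part of neuronx_distributed. In the library, get_data_parallel_size() reports the derived data-parallel size.

```python
def parallel_sizes(world_size: int, tp: int = 1, pp: int = 1) -> dict:
    """Hypothetical helper: derive the data-parallel size from the global world size.

    Assumes world_size == tp * pp * dp, i.e. every worker sits in exactly one
    tensor-parallel, one pipeline-parallel and one data-parallel group.
    """
    if tp < 1 or pp < 1:
        raise ValueError("parallel sizes must be >= 1")
    if world_size % (tp * pp) != 0:
        raise ValueError(f"world_size={world_size} is not divisible by tp*pp={tp * pp}")
    return {"tp": tp, "pp": pp, "dp": world_size // (tp * pp)}


# 32 workers with 8-way tensor parallelism leave 4-way data parallelism.
print(parallel_sizes(32, tp=8))  # {'tp': 8, 'pp': 1, 'dp': 4}
```

A world size that is not a multiple of tp * pp is rejected, because the leftover workers could not be placed in complete groups.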
Other helper APIs:#
neuronx_distributed.parallel_state.get_data_parallel_size()
: Returns the data-parallel world size, derived from the number of global workers and the number of tensor-parallel and pipeline-parallel workers.
neuronx_distributed.parallel_state.get_tensor_model_parallel_size()
: Returns the tensor-parallel world size.
neuronx_distributed.parallel_state.get_tensor_model_parallel_rank()
: Returns the rank of the worker within the tensor-parallel group.
neuronx_distributed.parallel_state.get_pipeline_model_parallel_size()
: Returns the pipeline-parallel world size.
neuronx_distributed.parallel_state.get_pipeline_model_parallel_rank()
: Returns the rank of the worker within the pipeline-parallel group.
neuronx_distributed.parallel_state.get_data_parallel_rank()
: Returns the rank of the worker within the data-parallel group.
neuronx_distributed.parallel_state.get_data_parallel_group(as_list=False)
: Returns the data-parallel group, taking into account the tensor-parallel size and the global world size. If as_list is True, the group is returned as a List[List]; otherwise, it is returned as a torch.distributed group.
neuronx_distributed.parallel_state.get_tensor_model_parallel_group(as_list=False)
: Returns the tensor-parallel group, taking into account the tensor-parallel size and the global world size. If as_list is True, the group is returned as a List[List]; otherwise, it is returned as a torch.distributed group.
neuronx_distributed.parallel_state.get_pipeline_model_parallel_group(as_list=False)
: Returns the pipeline-parallel group, taking into account the pipeline-parallel size and the global world size. If as_list is True, the group is returned as a List[List]; otherwise, it is returned as a torch.distributed group.
move_model_to_device(model, device)
: Moves the model to the given device while preserving its tensor-parallel attributes.
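The next sketch makes the as_list=True return shape of the group helpers concrete. It builds tensor-, pipeline-, and data-parallel groups as lists of rank lists for a given world size. The sketch assumes a Megatron-style rank layout:

- Tensor-parallel ranks are contiguous.
- Data-parallel ranks are strided by the TP size within a pipeline stage.
- Pipeline-parallel ranks are strided by the number of ranks per stage.

This ordering is not confirmed here as NxD's actual layout. The functions parallel_groups and rank_in are hypothetical. rank_in returns the analogue of what the get_*_rank() helpers report under this layout.

```python
from typing import List, Tuple

Groups = List[List[int]]


def parallel_groups(world_size: int, tp: int, pp: int) -> Tuple[Groups, Groups, Groups]:
    """Hypothetical helper: build (tp_groups, pp_groups, dp_groups) as lists of rank lists.

    Assumes a Megatron-style layout: TP ranks are contiguous, DP ranks are
    strided by tp inside a pipeline stage, and PP ranks are strided by the
    number of ranks per stage (tp * dp).
    """
    if world_size % (tp * pp) != 0:
        raise ValueError("world_size must be divisible by tp * pp")
    stage_size = world_size // pp  # ranks per pipeline stage == tp * dp
    tp_groups = [list(range(s, s + tp)) for s in range(0, world_size, tp)]
    pp_groups = [list(range(s, world_size, stage_size)) for s in range(stage_size)]
    dp_groups = [
        list(range(stage * stage_size + j, (stage + 1) * stage_size, tp))
        for stage in range(pp)
        for j in range(tp)
    ]
    return tp_groups, pp_groups, dp_groups


def rank_in(groups: Groups, rank: int) -> int:
    """Position of `rank` inside the group that contains it (the analogue of get_*_rank())."""
    for group in groups:
        if rank in group:
            return group.index(rank)
    raise ValueError(f"rank {rank} is not in any group")


tp_groups, pp_groups, dp_groups = parallel_groups(world_size=8, tp=2, pp=2)
print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(pp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(dp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
```

Under this layout, rank 5 has TP rank 1, DP rank 0, and PP rank 1.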
Parallel Layers:#
The majority of parameters in a transformer-based model reside in the Embedding and Linear layers. To reduce the number of parameters held on a single device, the library provides sharded Embedding and Linear layers.
Parallel Embedding:#
class neuronx_distributed.parallel_layers.ParallelEmbedding(
num_embeddings, embedding_dim, init_method=init.normal_,
dtype=torch.float32, device=None)
This module is intended to replace torch.nn.Embedding. When the vocabulary size is too large, the embedding table can be sharded across workers. Note: The embedding table is sharded across all the tensor-parallel workers.
Parameters:
num_embeddings (int)
: Size of the dictionary of embeddings.
embedding_dim (int)
: Size of each embedding vector.
init_method (torch.nn.init)
: Initialization function for the embedding weights.
dtype (dtype)
: Datatype for the weights.
device (torch.device)
: Device to initialize the weights on. By default, the weights are initialized on CPU.
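The following NumPy sketch shows how a vocabulary-sharded lookup can reproduce the full embedding. Each tensor-parallel rank owns a contiguous slice of rows and looks up only the token IDs that fall in its slice. It zeroes the rows for all other tokens, and the partial results are summed, which plays the role of an all-reduce.

The sketch relies on three assumptions, made for illustration only and not taken from ParallelEmbedding's internals:

- The vocabulary is split into contiguous slices.
- num_embeddings is evenly divisible by the TP size.
- The all-reduce is simulated in a single process.

```python
import numpy as np


def vocab_range(num_embeddings: int, rank: int, tp: int) -> tuple:
    """Contiguous [start, end) slice of vocabulary rows owned by `rank` (assumed layout)."""
    per_rank = num_embeddings // tp
    return rank * per_rank, (rank + 1) * per_rank


def sharded_embedding_lookup(ids: np.ndarray, table: np.ndarray, tp: int) -> np.ndarray:
    """Simulate a vocab-parallel embedding lookup on `tp` ranks in one process."""
    num_embeddings = table.shape[0]
    if num_embeddings % tp != 0:
        raise ValueError("this sketch assumes num_embeddings is divisible by tp")
    partials = []
    for rank in range(tp):
        start, end = vocab_range(num_embeddings, rank, tp)
        shard = table[start:end]                       # rows held by this rank
        outside = (ids < start) | (ids >= end)         # tokens owned by other ranks
        local_ids = np.where(outside, 0, ids - start)  # any valid index for masked tokens
        out = shard[local_ids]                         # fancy indexing returns a copy
        out[outside] = 0.0                             # contribute nothing for foreign tokens
        partials.append(out)
    return np.sum(partials, axis=0)                    # stands in for all-reduce(sum)


rng = np.random.default_rng(0)
table = rng.normal(size=(8, 4))      # vocabulary of 8, embedding_dim of 4
ids = np.array([[0, 3, 7], [5, 1, 6]])
assert np.allclose(sharded_embedding_lookup(ids, table, tp=4), table[ids])
```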
ColumnParallel Linear Layer:#
class neuronx_distributed.parallel_layers.ColumnParallelLinear(
input_size, output_size, bias=True, gather_output=True,
sequence_parallel_enabled=False, dtype=torch.float32, device=None)
This module performs a column-wise partition of the weight matrix. The linear layer is defined as $Y = XA + b$, where $A$ is parallelized along its second dimension as $A = [A_1, A_2, \dots, A_p]$. Note: This layer is designed to operate on 3-dimensional inputs.
Parameters:
input_size (int)
: First dimension of the weight matrix.
output_size (int)
: Second dimension of the weight matrix.
bias (bool)
: If set to True, a bias is added.
gather_output (bool)
: If True, call all-gather on the output so that Y is available on all Neuron devices; otherwise, each Neuron device keeps only its own output, $Y_i = XA_i$.
sequence_parallel_enabled (bool)
: When sequence parallelism is enabled, the layer gathers the inputs from the sequence-parallel region and performs the forward and backward passes.
dtype (dtype)
: Datatype for the weights.
device (torch.device)
: Device to initialize the weights on. By default, the weights are initialized on CPU.
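The NumPy sketch below, which is not NxD code, simulates the column-wise split in a single process. Each of tp hypothetical ranks holds a column block A_i and computes its own Y_i. Concatenating the shards along the last dimension stands in for the all-gather performed when gather_output=True. The sketch makes two assumptions: the bias is split alongside A, and output_size is divisible by tp.

```python
import numpy as np


def column_parallel_linear(X, A, b, tp, gather_output=True):
    """Simulate Y = XA + b with A split column-wise across `tp` ranks.

    Assumes output_size is divisible by tp; each rank holds A_i and its bias slice b_i.
    """
    A_shards = np.split(A, tp, axis=1)   # A = [A_1, ..., A_p]
    b_shards = np.split(b, tp)
    Y_shards = [X @ A_i + b_i for A_i, b_i in zip(A_shards, b_shards)]  # Y_i = XA_i + b_i
    if gather_output:
        return np.concatenate(Y_shards, axis=-1)  # stands in for all-gather
    return Y_shards                               # each rank keeps only its own Y_i


rng = np.random.default_rng(0)
X = rng.normal(size=(2, 3, 4))   # (batch, seq, input_size): a 3-D input
A = rng.normal(size=(4, 6))      # (input_size, output_size)
b = rng.normal(size=6)
assert np.allclose(column_parallel_linear(X, A, b, tp=3), X @ A + b)
```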
RowParallel Linear Layer:#
class neuronx_distributed.parallel_layers.RowParallelLinear(
input_size, output_size, bias=True, input_is_parallel=False,
sequence_parallel_enabled=False, dtype=torch.float32, device=None
)
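The linear layer is defined as $Y = XA + b$. $A$ is parallelized along its first dimension and $X$ along its last dimension. Each device therefore computes a partial product $X_i A_i$, and the partial results are summed: $Y = \sum_{i=1}^{p} X_i A_i + b$. Note: This layer is designed to operate on 3-dimensional inputs.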
Parameters:
input_size (int)
: First dimension of the weight matrix.
output_size (int)
: Second dimension of the weight matrix.
bias (bool)
: If set to True, a bias is added.
input_is_parallel (bool)
: If True, the input is assumed to be already split across the Neuron devices and is not split again. This is useful when a ColumnParallelLinear layer directly precedes the RowParallelLinear layer.
sequence_parallel_enabled (bool)
: When sequence parallelism is enabled, the layer gathers the inputs from the sequence-parallel region and performs the forward and backward passes.
dtype (dtype)
: Datatype for the weights.
device (torch.device)
: Device to initialize the weights on. By default, the weights are initialized on CPU.
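The following NumPy simulation illustrates the row-wise split. It also shows why input_is_parallel=True pairs naturally with a preceding column-parallel layer: the unsplit shards X @ A1_i produced by the first layer are exactly the X_i that the row-parallel layer expects. As a result, only one sum (the simulated all-reduce) is needed across the two layers. row_parallel_linear is a hypothetical helper. The sketch assumes input_size is divisible by tp and that the bias is added once, after the reduction.

```python
import numpy as np


def row_parallel_linear(X, A, b, tp, input_is_parallel=False):
    """Simulate Y = XA + b with A split row-wise and X split along its last dim.

    If input_is_parallel is True, X is already a list of per-rank shards X_i.
    """
    A_shards = np.split(A, tp, axis=0)                               # rows of A
    X_shards = X if input_is_parallel else np.split(X, tp, axis=-1)
    partials = [X_i @ A_i for X_i, A_i in zip(X_shards, A_shards)]  # X_i A_i on each rank
    return np.sum(partials, axis=0) + b   # all-reduce(sum), then add the bias once


rng = np.random.default_rng(0)
X = rng.normal(size=(2, 3, 4))
A1 = rng.normal(size=(4, 6))   # first layer, split column-wise
A2 = rng.normal(size=(6, 5))   # second layer, split row-wise
b2 = rng.normal(size=5)

# Column-parallel output without gathering: one shard X @ A1_i per rank.
hidden_shards = [X @ A1_i for A1_i in np.split(A1, 3, axis=1)]
Y = row_parallel_linear(hidden_shards, A2, b2, tp=3, input_is_parallel=True)
assert np.allclose(Y, X @ A1 @ A2 + b2)
```

An elementwise activation could be applied to each hidden shard between the two layers without any communication, because it acts on each column independently.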
Padding Tensor-Parallel Layers#
def neuronx_distributed.parallel_layers.pad.pad_model(
model, tp_degree, n_heads, wrapped_classes=(), pad_hook_fn=None)
Pads a generic model so that it can run at a desired tensor-parallel degree by padding the number of attention heads. Returns the original model, modified with padding. Uses a 1-axis padding strategy: the sharded dimension of each ParallelLinear layer is padded to the size it would have for the padded number of heads.
Parameters:
model (torch.nn.Module)
: Model to be padded.
tp_degree (int)
: Tensor-parallel degree.
n_heads (int)
: The number of attention heads in the model to be padded. This can typically be found in the model config.
wrapped_classes (Tuple[any], *optional*, defaults to `()`)
: Tuple of classes (and their submodules) that should be padded.
pad_hook_fn (Callable[any, float], *optional*, defaults to `None`)
: A hook function that is called whenever a class to pad is encountered. It receives an instance of the class to pad and the tgt_src_ratio (num_heads_padded / num_heads) as its arguments.
Usage:
When modifying the attention layer, you typically divide the number of heads by the TP degree:
self.num_heads = neuronx_dist_utils.divide(self.num_heads, get_tensor_model_parallel_size())
To support padding, modify this line to include the extra heads:
self.num_heads = neuronx_dist_utils.divide(
    self.num_heads + get_number_of_extra_heads(self.num_heads, get_tensor_model_parallel_size()),
    get_tensor_model_parallel_size(),
)
Then, after initializing the model, call pad_model:
model = get_model(config=desired_config)
model = pad_model(model, tp_degree=32, n_heads=desired_config.num_heads)
# Use the model as desired after this point
To avoid padding unnecessarily, you can restrict padding to specific layers or classes. Typically, this is your attention layer:
model = pad_model(model, tp_degree=32, n_heads=desired_config.num_heads, wrapped_classes=(MyAttention,))
You can also specify a pad_hook_fn. It is called whenever an instance of a wrapped class is encountered and receives that instance, along with the tgt_src_ratio (num_heads_padded / num_heads):
def my_hook(attention_to_pad, tgt_src_ratio):
    # Scale the split size of the instance being padded by the head-padding ratio
    attention_to_pad.split_size = int(attention_to_pad.split_size * tgt_src_ratio)

model = pad_model(
    model,
    tp_degree=32,
    n_heads=desired_config.num_heads,
    wrapped_classes=(MyAttention,),
    pad_hook_fn=my_hook,
)
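The head-count arithmetic behind padding can be sketched in plain Python. extra_heads below is our own hypothetical re-implementation. It assumes that heads are padded up to the next multiple of the TP degree. It is not the library's get_number_of_extra_heads, whose exact behavior is not described in this section.

```python
def extra_heads(num_heads: int, tp_degree: int) -> int:
    """Hypothetical re-implementation: heads to add so num_heads becomes divisible by tp_degree."""
    return (tp_degree - num_heads % tp_degree) % tp_degree


def padded_head_layout(num_heads: int, tp_degree: int) -> dict:
    """Summarize the padded head count, heads per TP rank, and tgt_src_ratio."""
    padded = num_heads + extra_heads(num_heads, tp_degree)
    return {
        "padded_heads": padded,
        "heads_per_rank": padded // tp_degree,
        "tgt_src_ratio": padded / num_heads,  # num_heads_padded / num_heads, as passed to pad_hook_fn
    }


print(padded_head_layout(25, 32))  # 25 heads padded to 32: one head per rank
print(padded_head_layout(40, 32))  # 40 heads padded to 64: two heads per rank
```

A model whose head count already divides the TP degree evenly gets zero extra heads and a ratio of 1.0.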
Loss functions:#
When the final MLP layer is sharded with tensor parallelism, you can use the parallel cross-entropy loss function instead of gathering the full outputs from every TP rank. This function takes the sharded (parallel) logits produced by the final parallel MLP layer and computes the loss, taking into account that the logits are sharded across multiple workers.
def neuronx_distributed.parallel_layers.loss_functions.parallel_cross_entropy(
parallel_logits, labels, label_smoothing=0.0)
Parameters:
parallel_logits (Tensor)
: Sharded logits from the previous MLP layer.
labels (Tensor)
: Label for each token. Labels should not be sharded; parallel_cross_entropy takes care of sharding the labels internally.
label_smoothing (float)
: A float in [0.0, 1.0] that specifies the amount of smoothing when computing the loss, where 0.0 means no smoothing.
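The sketch below simulates, in NumPy, one way to compute the loss from vocab-sharded logits without gathering them. It uses three reductions across ranks:

1. A max over each rank's shard (all-reduce max).
2. A sum of exponentials (all-reduce sum).
3. A masked lookup of the target logit (all-reduce sum).

The sketch follows the Megatron-style vocab-parallel cross-entropy formulation, assumes contiguous, equal vocabulary slices per rank, and omits label_smoothing. That NxD's implementation is structured exactly this way is an assumption. parallel_cross_entropy_sim is a hypothetical name.

```python
import numpy as np


def parallel_cross_entropy_sim(logit_shards, labels):
    """Hypothetical helper: per-token cross-entropy from vocab-sharded logits.

    logit_shards[r] holds rank r's contiguous vocab slice, shape (..., vocab // tp).
    labels holds full-vocabulary token IDs and is not sharded. No label smoothing.
    """
    vocab_per_rank = logit_shards[0].shape[-1]
    # all-reduce(max): global max over the full vocabulary, for numerical stability
    global_max = np.max([shard.max(axis=-1) for shard in logit_shards], axis=0)
    sum_exp = np.zeros_like(global_max)
    target_logit = np.zeros_like(global_max)
    for rank, shard in enumerate(logit_shards):
        shifted = shard - global_max[..., None]
        sum_exp += np.exp(shifted).sum(axis=-1)                # all-reduce(sum)
        start = rank * vocab_per_rank
        owned = (labels >= start) & (labels < start + vocab_per_rank)
        local = np.where(owned, labels - start, 0)             # safe index for foreign labels
        picked = np.take_along_axis(shifted, local[..., None], axis=-1)[..., 0]
        target_logit += np.where(owned, picked, 0.0)           # all-reduce(sum)
    # loss = log(sum_j exp(z_j - m)) - (z_target - m)
    return np.log(sum_exp) - target_logit


rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, 8))                 # (batch, seq, vocab)
labels = rng.integers(0, 8, size=(2, 3))
loss = parallel_cross_entropy_sim(np.split(logits, 4, axis=-1), labels)

reference = np.log(np.exp(logits).sum(-1)) - np.take_along_axis(logits, labels[..., None], -1)[..., 0]
assert np.allclose(loss, reference)
```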
Pipeline parallelism:#
Neuron Distributed Pipeline Model#
class NxDPPModel(
module: torch.nn.Module,
transformer_layer_cls: Optional[Any] = None,
num_microbatches: int = 1,
virtual_pipeline_size: int = 1,
output_loss_value_spec: Optional[Union[Dict, Tuple]] = None,
return_mb_loss: bool = False,
broadcast_and_average_loss: bool = False,
pipeline_cuts: Optional[List[str]] = None,
input_names: Optional[List[str]] = None,
leaf_module_cls: Optional[List[Any]] = None,
autowrap_functions: Optional[Tuple[Callable, ...]] = None,
autowrap_modules: Optional[Tuple[ModuleType]] = None,
tracer_cls: Optional[Union[str, Any]] = None,
param_init_fn: Optional[Any] = None,
trace_file_path: Optional[str] = None,
use_zero1_optimizer: bool = False,
auto_partition: Optional[bool] = False,
deallocate_pipeline_outputs: bool = False,
)
Parameters:
module
: Module to be distributed with pipeline parallelism.
transformer_layer_cls
: The module class of the transformer layers.
num_microbatches
: Number of pipeline microbatches.
virtual_pipeline_size
: Virtual pipeline size. If greater than 1, the interleaved pipeline schedule is used.
output_loss_value_spec
: Specifies which value in the output of forward is the loss value on which NxDPPModel should apply backpropagation. For example, if your forward returns a tuple (loss, model_out), you can specify output_loss_value_spec=(True, False). Or, if your forward returns a dict {'loss': loss_value, 'model_out': model_out}, you can specify output_loss_value_spec={'loss': True, 'model_out': False}.
return_mb_loss
: Whether to return a list of losses, one for each microbatch.
broadcast_and_average_loss
: Whether to broadcast the loss to all PP ranks and average it across DP ranks. When set to True, return_mb_loss must be False.
pipeline_cuts
: A list of layer names used to annotate pipeline stage boundaries.
input_names
: The input names used for tracing. These must be the same as the model inputs at runtime.
leaf_module_cls
: A list of module classes that should be treated as leaf nodes during tracing. Note that the transformer layer class is treated as a leaf node by default.
autowrap_modules
: (symbolic tracing only) Python modules whose functions should be wrapped automatically without needing to use fx.wrap().
autowrap_functions
: (symbolic tracing only) Python functions that should be wrapped automatically without needing to use fx.wrap().
tracer_cls
: User-provided tracer class for symbolic tracing. It can be “hf”, “torch”, or any tracer class the user created.
param_init_fn
: Function used to initialize parameters. This is useful if you want to use the meta device for delayed parameter initialization. param_init_fn should take a module as input and initialize only the parameters that belong to that module (not to its submodules).
use_zero1_optimizer
: Whether to use the ZeRO-1 optimizer. When set to True, gradient averaging is handed over to the ZeRO-1 optimizer.
auto_partition
: Whether to partition the model automatically. When set to True, the pipeline cuts used as pipeline stage boundaries are determined automatically based on the transformer layer names, and the pipeline_cuts parameter should not be set.
deallocate_pipeline_outputs
: Whether to deallocate the pipeline outputs after they are sent. After a send, the output tensor is only needed for its ‘.grad_fn’ field, not for its ‘.data’.
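To illustrate the output_loss_value_spec semantics described above, the plain-Python sketch below picks the loss out of a tuple or dict output using a matching spec of booleans. select_loss is a hypothetical helper and not part of NxDPPModel. Requiring exactly one value marked True is an assumption of this sketch.

```python
from typing import Any, Dict, Tuple, Union

LossSpec = Union[Tuple[bool, ...], Dict[str, bool]]


def select_loss(output: Any, spec: LossSpec) -> Any:
    """Hypothetical helper: return the one element of `output` that `spec` marks as the loss."""
    if isinstance(spec, dict):
        chosen = [output[key] for key, is_loss in spec.items() if is_loss]
    else:
        chosen = [value for value, is_loss in zip(output, spec) if is_loss]
    if len(chosen) != 1:
        raise ValueError("output_loss_value_spec must mark exactly one value as the loss")
    return chosen[0]


print(select_loss((0.25, "model_out"), (True, False)))                                   # 0.25
print(select_loss({"loss": 0.5, "model_out": None}, {"loss": True, "model_out": False}))  # 0.5
```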
Commonly used APIs#
NxDPPModel.run_train(**kwargs)
Trains the model with the pipeline-parallel schedule, running both the forward and backward passes in a pipelined manner. The kwargs should match the input_names provided for tracing. Returns the loss selected by output_loss_value_spec.
NxDPPModel.run_eval(**kwargs)
Evaluates the model with the pipeline-parallel schedule, running the forward pass only. The kwargs should match the input_names provided for tracing. Returns the loss selected by output_loss_value_spec.
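To give a rough picture of what running the forward pass "with the pipeline-parallel schedule" means, the sketch below tabulates a simple forward-only pipelined schedule: at time step t, stage s works on microbatch t - s. This GPipe-style picture is an illustration only. It does not model the schedules NxD actually uses for training or the interleaved schedule selected by virtual_pipeline_size > 1. forward_schedule is a hypothetical helper.

```python
from typing import List, Optional


def forward_schedule(num_stages: int, num_microbatches: int) -> List[List[Optional[int]]]:
    """Hypothetical helper: table[t][s] is the microbatch stage s runs at step t (None = idle)."""
    num_steps = num_stages + num_microbatches - 1  # pipeline fill + steady state + drain
    table = []
    for t in range(num_steps):
        row = []
        for s in range(num_stages):
            mb = t - s  # stage s starts s steps after stage 0
            row.append(mb if 0 <= mb < num_microbatches else None)
        table.append(row)
    return table


for step, row in enumerate(forward_schedule(num_stages=2, num_microbatches=3)):
    print(step, row)
# 0 [0, None]
# 1 [1, 0]
# 2 [2, 1]
# 3 [None, 2]
```

The idle (None) slots form the pipeline bubble. Increasing num_microbatches shrinks the bubble's share of the total steps.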
NxDPPModel.local_named_parameters(**kwargs)
The parameters that are local to this PP rank. This must be called after the model is partitioned.
NxDPPModel.local_named_modules(**kwargs)
The modules that are local to this PP rank. This must be called after the model is partitioned.