This document is relevant for: Inf2, Trn1, Trn2

Training APIs

Neuronx-Distributed Training APIs:

In Neuronx-Distributed, we provide a series of APIs directly under neuronx_distributed that help users apply NxD Core optimizations easily. These APIs cover configuration, model and optimizer initialization, and checkpoint saving and loading.

Initialize NxD Core config:

def neuronx_distributed.trainer.neuronx_distributed_config(
    tensor_parallel_size=1,
    pipeline_parallel_size=1,
    pipeline_config=None,
    optimizer_config=None,
    activation_checkpoint_config=None,
    pad_model=False,
    sequence_parallel=False,
    model_init_config=None,
    lora_config=None,
)

This method initializes the NxD Core training config and initializes model parallelism. The config holds all optimization options for distributed training and is global (the same for all processes).

Parameters:

  • tensor_parallel_size (int) : Tensor model parallel size. Default: 1.

  • pipeline_parallel_size (int) : Pipeline model parallel size. Default: 1.

  • pipeline_config (dict): Pipeline parallel config. For details, refer to the pipeline parallel guidance. Default: None.

  • optimizer_config (dict) : Optimizer config. Default: {"zero_one_enabled": False, "grad_clipping": True, "max_grad_norm": 1.0}.

    • Enable ZeRO-1 by setting zero_one_enabled to True.

    • Enable gradient clipping by setting grad_clipping to True.

    • Change the maximum gradient norm by setting max_grad_norm.

  • activation_checkpoint_config (str or torch.nn.Module): Activation checkpoint config.
    Accepted values: "full", None, or any torch.nn.Module. When set to "full", regular activation checkpointing is enabled (every transformer layer is recomputed). When set to None, activation checkpointing is disabled. When set to a torch.nn.Module, selective activation checkpointing is enabled and the provided module is recomputed. Default: None.

  • pad_model (bool) : Whether to pad attention heads of model. Default: False.

  • sequence_parallel (bool) : Whether to enable sequence parallel. Default: False.

  • model_init_config (dict) : Model initialization config. Default: {"sequential_move_factor": 11, "meta_device_init": False, "param_init_fn": None}.

    • sequential_move_factor: number of processes instantiating the model on the host at the same time. This is done to avoid host OOM. Range: 1-32.

    • meta_device_init: whether to initialize the model on the meta device.

    • param_init_fn: function that initializes the parameters of modules. It must be provided when meta_device_init is True.

  • lora_config: LoRA configuration. Default: None (LoRA disabled).
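The model_init_config constraints above can be checked up front, before calling neuronx_distributed_config. The following is a minimal sketch: check_model_init_config and init_weights are hypothetical helpers (not part of NxD), and overlaying a partial dict on the documented defaults is an assumption made for the illustration.

```python
from typing import Any, Dict, Optional

# Documented defaults of model_init_config.
DEFAULT_MODEL_INIT_CONFIG: Dict[str, Any] = {
    "sequential_move_factor": 11,
    "meta_device_init": False,
    "param_init_fn": None,
}


def check_model_init_config(cfg: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical pre-flight check; assumes missing keys take the documented defaults."""
    merged = {**DEFAULT_MODEL_INIT_CONFIG, **(cfg or {})}
    factor = merged["sequential_move_factor"]
    if not 1 <= factor <= 32:
        raise ValueError(f"sequential_move_factor must be in 1-32, got {factor}")
    if merged["meta_device_init"] and merged["param_init_fn"] is None:
        raise ValueError("param_init_fn must be provided when meta_device_init is True")
    return merged


def init_weights(module):
    """Hypothetical param_init_fn; a real one would initialize the module's parameters."""
    pass


cfg = check_model_init_config({"meta_device_init": True, "param_init_fn": init_weights})
print(cfg["sequential_move_factor"])  # 11
```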

Initialize NxD Core Model Wrapper:

def neuronx_distributed.trainer.initialize_parallel_model(nxd_config, model_fn, *model_args, **model_kwargs)

This method initializes the NxD Core model wrapper and returns a wrapped model that can be used as a regular torch.nn.Module, with all the model optimizations in the config applied. The wrapper is designed to hide the complexity of optimizations such as pipeline model parallelism, so that users can use the wrapped model just like the unwrapped version.

Parameters:

  • nxd_config (dict): config generated by neuronx_distributed_config.

  • model_fn (callable): user provided function to get the model for training.

  • *model_args and **model_kwargs: arguments passed to model_fn.

Model wrapper class and its methods:

class neuronx_distributed.trainer.model.NxDModel(torch.nn.Module):
    def local_module(self):
        # Return the unwrapped local module.
        ...

    def run_train(self, *args, **kwargs):
        # Run one training iteration. When pipeline parallel is enabled,
        # users must call this instead of forward + backward.
        ...

    def named_parameters(self, *args, **kwargs):
        # Return only the parameters on the local rank.
        # The same applies to `parameters`, `named_buffers`, and `buffers`.
        ...

    def named_modules(self, *args, **kwargs):
        # Return only the modules on the local rank.
        # The same applies to `modules`, `named_children`, and `children`.
        ...

Note

As a shortcut, users can access model.config or model.dtype on the wrapped model if the original model is a Hugging Face Transformers pre-trained model.

Initialize NxD Core Optimizer Wrapper:

def neuronx_distributed.trainer.initialize_parallel_optimizer(nxd_config, optimizer_class, parameters, **defaults)

This method initializes the NxD Core optimizer wrapper and returns a wrapped optimizer that can be used as a regular torch.optim.Optimizer, with all the optimizer optimizations in the config applied.

This optimizer wrapper inherits from torch.optim.Optimizer. It takes the nxd_config and configures the optimizer to work with different distributed training regimes.

The step method of the wrapped optimizer contains the necessary all-reduce operations and gradient clipping. Other methods and variables work the same as in the unwrapped optimizer.

Parameters:

  • nxd_config (dict): config generated by neuronx_distributed_config.

  • optimizer_class (Type[torch.optim.Optimizer]): optimizer class to create the optimizer.

  • parameters (iterable): parameters passed to the optimizer.

  • **defaults: optimizer options passed to the optimizer (for example, lr).

Enable LoRA fine-tuning:

LoRA model wrapper

class LoRAModel(module, LoraConfig)

Parameters:

  • module (torch.nn.Module): Module to be wrapped with LoRA

  • LoraConfig: The LoRA configuration defined in neuronx_distributed.modules.lora.LoraConfig

The flags in LoraConfig to initialize LoRA adapter:

  • enable_lora (bool): Enable LoRA fine-tuning.

  • lora_rank (int): The rank of LoRA adapter. A small LoRA rank reduces the memory footprint during fine-tuning, but it may harm the model quality.

  • lora_alpha (float): The alpha parameter for LoRA scaling, i.e., scaling LoRA weights against base model weights.

  • lora_dropout (float): The dropout probability for LoRA layers.

  • bias (str): Bias type for LoRA. Can be "none", "all", or "lora_only".

  • target_modules (List[str]): Keywords identifying the modules that need LoRA; a module is selected if its name contains any of the keywords.

  • use_rslora (bool): If True, uses Rank-Stabilized LoRA, which sets the adapter scaling factor to lora_alpha/math.sqrt(lora_rank).

  • init_lora_weights (str): Weight initialization of the LoRA adapter. Can be "default" (initialized with torch.nn.init.kaiming_uniform_()) or "gaussian" (initialized with torch.nn.init.normal_()).
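The effect of use_rslora on the adapter scaling can be compared numerically. This is a minimal sketch: lora_scaling is a hypothetical helper, and for use_rslora=False it assumes the conventional LoRA scaling lora_alpha / lora_rank, which this page does not state explicitly.

```python
import math


def lora_scaling(lora_alpha: float, lora_rank: int, use_rslora: bool = False) -> float:
    """Factor applied to the LoRA update (B @ A) before it is added to the base output."""
    if lora_rank <= 0:
        raise ValueError("lora_rank must be positive")
    if use_rslora:
        # Rank-Stabilized LoRA, as described above: lora_alpha / sqrt(lora_rank).
        return lora_alpha / math.sqrt(lora_rank)
    # Assumption: conventional LoRA scaling, lora_alpha / lora_rank.
    return lora_alpha / lora_rank


print(lora_scaling(32, 16))                   # 2.0
print(lora_scaling(32, 16, use_rslora=True))  # 8.0
```

Under rsLoRA the factor shrinks more slowly as lora_rank grows, so larger ranks keep a larger effective update.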

Usage:

First, define the LoRA configuration for fine-tuning. If target_modules is ["q_proj", "v_proj", "k_proj"], LoRA is applied to every module whose name contains any of these keywords. For example:

from neuronx_distributed.modules.lora import LoraConfig

lora_config = LoraConfig(
   enable_lora=True,
   lora_rank=16,
   lora_alpha=32,
   lora_dropout=0.05,
   bias="none",
   target_modules=["q_proj", "v_proj", "k_proj"],
)

Then pass lora_config to neuronx_distributed_config to enable LoRA fine-tuning:

nxd_config = neuronx_distributed.neuronx_distributed_config(
  ...
  lora_config=lora_config,
)
model = neuronx_distributed.initialize_parallel_model(nxd_config, ...)

The NxD model is then initialized with the LoRA adapter enabled.
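Because target_modules entries are matched as keywords inside module names, a short keyword can select more modules than intended. The sketch below mimics the substring rule described above; matches_target and the Llama-style module names are hypothetical.

```python
from typing import List


def matches_target(module_name: str, target_modules: List[str]) -> bool:
    """True if any target keyword occurs anywhere in the module name."""
    return any(keyword in module_name for keyword in target_modules)


targets = ["q_proj", "v_proj", "k_proj"]
names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.o_proj",
    "model.layers.0.mlp.up_proj",
    "model.layers.0.self_attn.k_proj",
]
print([n for n in names if matches_target(n, targets)])
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj']
```

A keyword such as "proj" would also match o_proj and up_proj, so specific keywords are safer.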

Save Checkpoint:

Saves a checkpoint and returns None.

This method saves checkpoints for the model, optimizer, scheduler, and user contents sequentially. Model states are saved on data-parallel rank 0 only. When the ZeRO-1 optimizer is not enabled, optimizer states are saved the same way; when ZeRO-1 is enabled, optimizer states are saved on all ranks. Scheduler and user contents are saved on the master rank only.

Users can set use_xser=True to speed up saving and avoid host OOM. This works by saving tensors one by one, simultaneously, while keeping the original data structure. However, the resulting checkpoint cannot be loaded with PyTorch's load API. Users can also set async_save=True to speed up saving further. This works by saving tensors in separate processes alongside computation. Setting async_save to True uses more host memory and therefore increases the risk of the application crashing because the system runs out of memory.

def neuronx_distributed.trainer.save_checkpoint(
    path,
    tag="",
    model=None,
    optimizer=None,
    scheduler=None,
    user_content=None,
    num_workers=8,
    use_xser=False,
    num_kept_ckpts=None,
    async_save=False,
    zero1_optimizer=False
)

Parameters:

  • path (str): path to save the checkpoints.

  • tag (str): tag to save the checkpoints.

  • model (torch.nn.Module): model to save, optional.

  • optimizer (torch.optim.Optimizer): optimizer to save, optional.

  • scheduler: scheduler to save, optional.

  • user_content: user contents to save, optional.

  • num_workers (int): number of processes saving data on the host at the same time.
    This is done to avoid host OOM. Range: 1-32.

  • use_xser (bool): whether to use torch-xla serialization. When enabled, num_workers
    is ignored and the maximum number of workers is used. Default: False.

  • num_kept_ckpts (int): number of checkpoints to keep on disk, optional. Default: None.

  • async_save (bool): whether to use the asynchronous saving method. Default: False.

  • zero1_optimizer (bool): whether the optimizer state is from a ZeRO-1 optimizer; used when optimizer is a dict. Default: False.
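The rank rules described at the start of this section can be summarized as a predicate. The sketch below is illustrative only: saves_component is a hypothetical helper, and treating the "master rank" as global rank 0 is an assumption.

```python
def saves_component(component: str, dp_rank: int, global_rank: int, zero1_enabled: bool) -> bool:
    """Return True if this rank writes `component`, following the rules described above."""
    if component == "model":
        # Model states are written by data-parallel rank 0 only.
        return dp_rank == 0
    if component == "optimizer":
        # With ZeRO-1 every rank owns a shard of the optimizer state, so every rank writes.
        return zero1_enabled or dp_rank == 0
    if component in ("scheduler", "user_content"):
        # Assumption: the "master rank" is global rank 0.
        return global_rank == 0
    raise ValueError(f"unknown component: {component}")


# A rank with data-parallel rank 1 and global rank 8, with ZeRO-1 enabled:
for name in ("model", "optimizer", "scheduler"):
    print(name, saves_component(name, dp_rank=1, global_rank=8, zero1_enabled=True))
# model False
# optimizer True
# scheduler False
```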

Save LoRA Checkpoint:

NxD also uses neuronx_distributed.trainer.save_checkpoint() to save LoRA models; set save_lora_base and merge_lora in LoraConfig to control how the LoRA checkpoint is saved. There are three modes for LoRA checkpoint saving:

  • save_lora_base=False, merge_lora=False: Save the LoRA adapter only.

  • save_lora_base=True, merge_lora=False: Save the base model and the LoRA adapter separately.

  • save_lora_base=True, merge_lora=True: Merge the LoRA adapter into the base model and then save the base model.

In addition to the adapter, NxD also needs to save the LoRA configuration file for LoRA loading. The configuration can be saved into the same checkpoint as the adapter, or saved as a separate JSON file.

  • save_lora_config_adapter (bool): If False, save the configuration as a separate JSON file.

Note that if the LoRA configuration file is saved separately, it is named lora_adapter/adapter_config.json.

For example, to save only the LoRA adapter:

lora_config = neuronx_distributed.modules.lora.LoraConfig(
   ...
   save_lora_base=False,
   merge_lora=False,
   save_lora_config_adapter=True,
)
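With merge_lora=True, the adapter is folded into the base weights before saving. Assuming NxD follows the standard LoRA formulation, the merged weight is W + s * (B @ A), where B has shape (out_features, lora_rank), A has shape (lora_rank, in_features), and s is the scaling factor. The sketch uses plain Python lists instead of tensors; matmul and merge_lora_weight are hypothetical names.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive matrix product of a (m x k) and b (k x n)."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def merge_lora_weight(w: Matrix, lora_b: Matrix, lora_a: Matrix, scaling: float) -> Matrix:
    """Return W + scaling * (B @ A), the base weight with the adapter folded in."""
    delta = matmul(lora_b, lora_a)  # (out, rank) @ (rank, in) -> (out, in)
    return [[w[i][j] + scaling * delta[i][j] for j in range(len(w[0]))] for i in range(len(w))]


w = [[1.0, 0.0], [0.0, 1.0]]  # base weight, shape (2, 2)
b = [[1.0], [2.0]]            # LoRA B, shape (2, rank=1)
a = [[0.5, 0.5]]              # LoRA A, shape (rank=1, 2)
print(merge_lora_weight(w, b, a, scaling=2.0))  # [[2.0, 1.0], [2.0, 3.0]]
```

After merging, the saved base model no longer needs the adapter at inference time, which is why this mode saves only the base model.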

Load Checkpoint:

Loads a checkpoint saved by save_checkpoint. Returns the user contents if they exist, otherwise None. If tag is not provided, it tries to use the newest tag tracked by save_checkpoint.

Note that the checkpoint to be loaded must have the same model parallel degrees as the current run, and if the ZeRO-1 optimizer is used, it must also have the same data parallel degree.

def neuronx_distributed.trainer.load_checkpoint(
    path,
    tag=None,
    model=None,
    optimizer=None,
    scheduler=None,
    num_workers=8,
    strict=True,
)

Parameters:

  • path (str): path to load the checkpoints.

  • tag (str): tag to load the checkpoints.

  • model (torch.nn.Module): model to load, optional.

  • optimizer (torch.optim.Optimizer): optimizer to load, optional.

  • scheduler: scheduler to load, optional.

  • num_workers (int): number of processes loading data on the host at the same time.
    This is done to avoid host OOM. Range: 1-32.

  • strict (bool): whether to use strict mode when loading the model checkpoint. Default: True.
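The compatibility rule for load_checkpoint can be expressed as a small check to run before resuming. ParallelDegrees and can_load are hypothetical names for this sketch, and reading "model parallel degrees" as the tensor and pipeline parallel sizes is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ParallelDegrees:
    """Hypothetical record of the parallel degrees a run or checkpoint uses."""
    tensor_parallel: int
    pipeline_parallel: int
    data_parallel: int


def can_load(saved: ParallelDegrees, current: ParallelDegrees, zero1_enabled: bool) -> bool:
    """Check the compatibility rule stated for load_checkpoint."""
    same_model_parallel = (
        saved.tensor_parallel == current.tensor_parallel
        and saved.pipeline_parallel == current.pipeline_parallel
    )
    if not same_model_parallel:
        return False
    # ZeRO-1 shards optimizer state over data-parallel ranks, so DP must match too.
    return not zero1_enabled or saved.data_parallel == current.data_parallel


saved = ParallelDegrees(tensor_parallel=8, pipeline_parallel=1, data_parallel=4)
print(can_load(saved, ParallelDegrees(8, 1, 2), zero1_enabled=False))  # True
print(can_load(saved, ParallelDegrees(8, 1, 2), zero1_enabled=True))   # False
```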

Load LoRA Checkpoint:

NxD loads LoRA checkpoints according to flags set in LoraConfig:

  • load_lora_from_ckpt: If True, resume from a LoRA checkpoint.

  • lora_save_dir: Load the LoRA checkpoint from the specified folder.

  • lora_load_tag: Load the LoRA checkpoint with the specified tag.

An example is:

import neuronx_distributed as nxd
from neuronx_distributed.modules.lora import LoraConfig

lora_config = LoraConfig(
   enable_lora=True,
   load_lora_from_ckpt=True,
   lora_save_dir=checkpoint_dir,  # checkpoint path
   lora_load_tag=tag,  # sub-directory under checkpoint path
)
nxd_config = nxd.neuronx_distributed_config(
   ...
   lora_config=lora_config,
)
model = nxd.initialize_parallel_model(nxd_config, ...)

The NxD model will be initialized with LoRA enabled and the LoRA weights loaded. The LoRA-related configuration is the same as the one stored with the LoRA adapter checkpoint.

Sample usage:

import neuronx_distributed as nxd
from torch.optim import AdamW

# create config
nxd_config = nxd.neuronx_distributed_config(
    tensor_parallel_size=8,
    optimizer_config={"zero_one_enabled": True, "grad_clipping": True, "max_grad_norm": 1.0},
)

# wrap model (get_model is a user-provided function that builds the model)
model = nxd.initialize_parallel_model(nxd_config, get_model)

# wrap optimizer
optimizer = nxd.initialize_parallel_optimizer(nxd_config, AdamW, model.parameters(), lr=1e-3)

...
# loading checkpoint (auto-resume); scheduler is user-provided
user_content = nxd.load_checkpoint(
    "ckpts",
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
)

...
# training loop (train_loader is user-provided)
for inputs in train_loader:
    loss = model.run_train(inputs)
    optimizer.step()
    optimizer.zero_grad()

...
# saving checkpoint (save_checkpoint does not take nxd_config)
nxd.save_checkpoint(
    "ckpts",
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    user_content={"total_steps": total_steps},
)

Modules:

GQA-QKV Linear Module:

class neuronx_distributed.modules.qkv_linear.GQAQKVColumnParallelLinear(
    input_size, output_size, bias=True, gather_output=True,
    sequence_parallel_enabled=False, dtype=torch.float32, device=None, kv_size_multiplier=1, fuse_qkv=True)

This module parallelizes the Q, K, and V linear projections using ColumnParallelLinear layers, replacing three separate linear layers with a single QKV module. In GQA, there are N times as many Q attention heads as K and V attention heads, and the K and V heads are replicated after projection to match the number of Q heads. This reduces the size of the K and V weights and is especially useful during inference. During training, however, it restricts the tensor-parallel degree that can be used, because the number of attention heads must be divisible by the tensor-parallel degree.

To mitigate this bottleneck, GQAQKVColumnParallelLinear takes a kv_size_multiplier argument: the module replicates the K and V weights kv_size_multiplier times, allowing a higher tensor-parallel degree. Note that instead of replicating the projection N/tp_degree times, the module ends up replicating the weights kv_size_multiplier times. This produces the same result and allows a higher tensor-parallel degree, but consumes extra memory.

Parameters:

  • input_size: (int) : First dimension of the weight matrix

  • output_sizes: (List[int]) : A list of second dimension of the Q and K/V weight matrix

  • bias: (bool): If set to True, bias would be added

  • gather_output: (bool): If True, call all-gather on the output and make Y available to all Neuron devices; otherwise, every Neuron device keeps its own output Y_i = XA_i.

  • sequence_parallel_enabled: (bool): When sequence parallel is enabled, the module gathers the inputs from the sequence-parallel region and performs the forward and backward passes.

  • init_method: (torch.nn.init) : Initialization function for the Q and K/V weights.

  • dtype: (dtype) : Datatype for the weights.

  • device: (torch.device): Device to initialize the weights on. By default, the weights are initialized on CPU.

  • kv_size_multiplier: (int): Factor by which the K and V weights would be replicated along the first dimension

  • fuse_qkv: (bool): When fuse_qkv is enabled, a single fused tensor is used for QKV. By default, this parameter is True.
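The kv_size_multiplier needed for a given tensor-parallel degree can be worked out from the head counts. This is a sketch under stated assumptions: plan_kv_replication is a hypothetical helper, and it assumes the Q head count must be divisible by tp_degree and the replicated K/V head count (num_kv_heads * kv_size_multiplier) must also be divisible by tp_degree.

```python
import math


def plan_kv_replication(num_q_heads: int, num_kv_heads: int, tp_degree: int):
    """Return (kv_size_multiplier, replicated K/V head count) for a tensor-parallel degree."""
    if num_q_heads % tp_degree != 0:
        raise ValueError("num_q_heads must be divisible by tp_degree")
    # Smallest multiplier m such that num_kv_heads * m is divisible by tp_degree.
    multiplier = tp_degree // math.gcd(num_kv_heads, tp_degree)
    return multiplier, num_kv_heads * multiplier


# Hypothetical GQA model: 64 Q heads, 8 K/V heads, trained with tp_degree=32.
print(plan_kv_replication(num_q_heads=64, num_kv_heads=8, tp_degree=32))  # (4, 32)
```

In this example the K and V weights are replicated 4 times, so they take 4 times the memory of the unreplicated weights, which is the trade-off described above.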

Checkpointing:

These APIs save and load checkpoints. They save and load the shard that corresponds to the tensor-parallel rank of each worker.

Save Checkpoint:

def neuronx_distributed.parallel_layers.save(state_dict, save_dir, save_serially=True, save_xser: bool=False, down_cast_bf16=False)

Note

This method will be deprecated, use neuronx_distributed.trainer.save_checkpoint instead.

This API saves the model from each tensor-parallel rank in save_dir. Only workers with data-parallel rank 0 save checkpoints. Each tensor-parallel rank creates a tp_rank_ii_pp_rank_ii folder inside save_dir and saves its shard in that folder. If save_xser is enabled, the folder is named tp_rank_ii_pp_rank_ii.tensors, and a reference data file named tp_rank_ii_pp_rank_ii is created in save_dir for each rank.

Parameters:

  • state_dict: (dict) : Model state dict. It's the same dict that you would save using torch.save.

  • save_dir: (str) : Model save directory.

  • save_serially: (bool): This flag saves checkpoints one model-parallel rank at a time. This is particularly useful when checkpointing large models.

  • save_xser: (bool): This flag saves the model with torch-xla serialization. This can significantly reduce checkpoint saving time for large models, so enabling xser is recommended when the model is large. Note that a checkpoint saved with save_xser must be loaded with load_xser, and vice versa.

  • down_cast_bf16: (bool): This flag would downcast the state_dict to bf16 before saving.
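The directory layout described above can be previewed for a given parallel configuration, for example to check that a checkpoint is complete. This is a sketch under stated assumptions: expected_shard_paths is a hypothetical helper, the files inside each folder are not modeled, and the two-digit zero padding of the rank numbers ("ii") is assumed rather than documented here.

```python
from typing import List


def expected_shard_paths(save_dir: str, tp_size: int, pp_size: int, save_xser: bool) -> List[str]:
    """List the per-rank entries that save() is described as creating in save_dir."""
    paths = []
    for pp_rank in range(pp_size):
        for tp_rank in range(tp_size):
            # Assumption: "ii" is a two-digit, zero-padded rank.
            name = f"tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:02d}"
            if save_xser:
                paths.append(f"{save_dir}/{name}.tensors/")  # folder holding the tensors
                paths.append(f"{save_dir}/{name}")  # reference data file
            else:
                paths.append(f"{save_dir}/{name}/")  # folder holding this rank's shard
    return paths


for p in expected_shard_paths("ckpt", tp_size=2, pp_size=1, save_xser=True):
    print(p)
```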

Load Checkpoint

def neuronx_distributed.parallel_layers.load(
    load_dir, model_or_optimizer=None, model_key='model', load_xser=False, sharded=True)

Note

This method will be deprecated, use neuronx_distributed.trainer.load_checkpoint instead.

This API automatically loads the checkpoint that matches the tensor-parallel rank. For large models, pass the model object to the load API so that the weights are loaded directly into the model. This helps avoid host OOM, because the load API loads the checkpoint for one tensor-parallel rank at a time.

Parameters:

  • load_dir: (str) : Directory where the checkpoint is saved.

  • model_or_optimizer: (torch.nn.Module or torch.optim.Optimizer): Model or Optimizer object.

  • model: (torch.nn.Module or torch.optim.Optimizer): Model or Optimizer object, equivalent to model_or_optimizer.

  • model_key: (str) : The model key used when saving the model in the state_dict.

  • load_xser: (bool) : Load the model with torch-xla serialization. Note that a checkpoint saved with save_xser must be loaded with load_xser, and vice versa.

  • sharded: (bool) : If the checkpoint is not sharded, pass False. This is useful (especially during inference) when the model was trained with a different strategy and you end up with a single unsharded checkpoint. You can then load this unsharded checkpoint onto the sharded model. When this is set to False, you must pass the model object. Note: the keys in the state dict must match the names in the model object; otherwise an error is raised.
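Conceptually, loading an unsharded checkpoint (sharded=False) onto a sharded model means each tensor-parallel rank keeps only its slice of every full tensor. The following sketch is not the NxD loading code: shard_rows is hypothetical, and it assumes a column-parallel layer whose weight of shape (out_features, in_features) is split evenly along the first dimension, as in Megatron-style tensor parallelism. Row-parallel layers would be split along the second dimension instead.

```python
from typing import List, Sequence


def shard_rows(full_weight: Sequence[Sequence[float]], tp_size: int, tp_rank: int) -> List[List[float]]:
    """Return the slice of an unsharded (out_features, in_features) weight owned by tp_rank.

    Assumes a column-parallel layer, whose weight is split evenly along dim 0.
    """
    out_features = len(full_weight)
    if out_features % tp_size != 0:
        raise ValueError("out_features must be divisible by tp_size")
    chunk = out_features // tp_size
    return [list(row) for row in full_weight[tp_rank * chunk:(tp_rank + 1) * chunk]]


full = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]  # unsharded weight, shape (4, 2)
print(shard_rows(full, tp_size=2, tp_rank=1))  # [[4.0, 5.0], [6.0, 7.0]]
```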

Gradient Clipping:

With tensor parallelism, gradient clipping must accumulate the total norm across all tensor-parallel ranks. Use the following API to handle this:

def neuronx_distributed.parallel_layers.clip_grad_norm(
    parameters, max_norm, norm_type=2)

Parameters:

  • parameters (Iterable[Tensor] or Tensor) : an iterable of Tensors or a single Tensor that will have gradients normalized

  • max_norm (float or int) : max norm of the gradients

  • norm_type (float or int) : type of the used p-norm. Can be ‘inf’ for infinity norm.
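To see why the norm must be accumulated across ranks, consider gradients split between two tensor-parallel ranks: the global 2-norm of [3, 4] is 5, while each rank alone would see 3 or 4 and compute a different clipping factor. The sketch below simulates the reduction with Python lists; clip_coefficient is hypothetical, and it assumes every gradient element is a distinct shard (parameters replicated across tensor-parallel ranks would have to be counted once).

```python
import math
from typing import List


def clip_coefficient(grads_per_rank: List[List[float]], max_norm: float, norm_type: float = 2.0) -> float:
    """Clip coefficient from the global norm of gradients sharded across tensor-parallel ranks."""
    if math.isinf(norm_type):
        # Infinity norm: each rank's local max is combined with an all-reduce MAX.
        total_norm = max(abs(g) for grads in grads_per_rank for g in grads)
    else:
        # Each rank computes sum(|g|^p) locally; an all-reduce SUM combines the partials.
        partials = [sum(abs(g) ** norm_type for g in grads) for grads in grads_per_rank]
        total_norm = sum(partials) ** (1.0 / norm_type)
    # Same form as torch.nn.utils.clip_grad_norm_: max_norm / (norm + eps), capped at 1.
    return min(1.0, max_norm / (total_norm + 1e-6))


# Two tensor-parallel ranks, each holding a different shard of the gradients.
print(round(clip_coefficient([[3.0], [4.0]], max_norm=1.0), 4))  # 0.2
```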

Neuron Zero1 Optimizer:

In Neuronx-Distributed, we built a wrapper around the ZeRO-1 optimizer (Zero1Optimizer) in torch-xla.

class NeuronZero1Optimizer(Zero1Optimizer)

This wrapper takes the tensor-parallel degree into account and computes the gradient norm accordingly. It also provides two APIs: save_sharded_state_dict and load_sharded_state_dict. As the model grows, saving the optimizer state from a single rank can cause OOMs, so save_sharded_state_dict saves the states from each data-parallel rank. To load this sharded optimizer state, the corresponding load_sharded_state_dict lets each rank pick its own shard from the checkpoint directory.

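from torch.optim import AdamW

from neuronx_distributed.optimizer import NeuronZero1Optimizer
from neuronx_distributed.parallel_layers import parallel_state

# `model` and `flags` come from the surrounding training script.
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm"]  # example parameter-name keywords exempt from weight decay

optimizer_grouped_parameters = [
    {
        "params": [
            p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.01,
    },
    {
        "params": [
            p for n, p in param_optimizer if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]

optimizer = NeuronZero1Optimizer(
    optimizer_grouped_parameters,
    AdamW,
    lr=flags.lr,
    pin_layout=False,
    sharding_groups=parallel_state.get_data_parallel_group(as_list=True),
    grad_norm_groups=parallel_state.get_tensor_model_parallel_group(as_list=True),
)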

The interface is the same as Zero1Optimizer in torch-xla.
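Optimizer state can be saved per data-parallel rank because ZeRO-1 gives each rank ownership of only a slice of the parameters' optimizer state. The sketch below shows that partitioning idea with plain lists; zero1_shard is hypothetical, and the even split with zero padding is an assumption for illustration, not a description of the exact layout used by torch-xla's Zero1Optimizer.

```python
import math
from typing import List


def zero1_shard(flat_params: List[float], dp_size: int, dp_rank: int) -> List[float]:
    """Return the slice of a flattened parameter buffer owned by one data-parallel rank.

    Assumes the buffer is zero-padded to a multiple of dp_size and split evenly.
    """
    shard_size = math.ceil(len(flat_params) / dp_size)
    padded = flat_params + [0.0] * (shard_size * dp_size - len(flat_params))
    return padded[dp_rank * shard_size:(dp_rank + 1) * shard_size]


params = [float(i) for i in range(1, 11)]  # 10 parameters
for rank in range(4):
    print(rank, zero1_shard(params, dp_size=4, dp_rank=rank))
```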

save_sharded_state_dict(output_dir, save_serially = True)

Note

This method will be deprecated, use neuronx_distributed.trainer.save_checkpoint instead.

Parameters:

  • output_dir (str) : Checkpoint directory where the sharded optimizer states need to be saved

  • save_serially (bool): Whether to save the states one data-parallel rank at a time. This is especially useful when checkpointing large models.

load_sharded_state_dict(output_dir, num_workers_per_step = 8)

Note

This method will be deprecated, use neuronx_distributed.trainer.load_checkpoint instead.

Parameters:

  • output_dir (str) : Checkpoint directory where the sharded optimizer states are saved

  • num_workers_per_step (int) : Controls how many workers load in parallel.
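Limiting how many workers load at once bounds peak host memory at the cost of more sequential steps; the same trade-off applies to num_workers in save_checkpoint and load_checkpoint. The sketch below only illustrates the grouping: load_waves is hypothetical, and grouping ranks by consecutive index is an assumption.

```python
from typing import List


def load_waves(world_size: int, num_workers_per_step: int) -> List[List[int]]:
    """Group ranks into consecutive waves of at most num_workers_per_step ranks."""
    if num_workers_per_step < 1:
        raise ValueError("num_workers_per_step must be at least 1")
    return [
        list(range(start, min(start + num_workers_per_step, world_size)))
        for start in range(0, world_size, num_workers_per_step)
    ]


for step, ranks in enumerate(load_waves(world_size=20, num_workers_per_step=8)):
    print(step, ranks)
```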

Neuron PyTorch-Lightning

Neuron PyTorch-Lightning is currently based on Lightning version 2.1.0 and will eventually be upstreamed to the Lightning-AI code base.

Neuron Lightning Module

Inherited from LightningModule

class neuronx_distributed.lightning.NeuronLTModule(
    model_fn: Callable,
    nxd_config: Dict,
    opt_cls: Callable,
    scheduler_cls: Callable,
    model_args: Tuple = (),
    model_kwargs: Dict = {},
    opt_args: Tuple = (),
    opt_kwargs: Dict = {},
    scheduler_args: Tuple = (),
    scheduler_kwargs: Dict = {},
    grad_accum_steps: int = 1,
    log_rank0: bool = False,
    manual_opt: bool = True,
)

Parameters:

  • model_fn: Model function to create the actual model

  • nxd_config: Neuronx Distributed Config, output of neuronx_distributed.neuronx_distributed_config

  • opt_cls: Callable to create optimizer

  • scheduler_cls: Callable to create scheduler

  • model_args: Tuple of args fed to model callable

  • model_kwargs: Dict of keyword args fed to model callable

  • opt_args: Tuple of args fed to optimizer callable

  • opt_kwargs: Dict of keyword args fed to optimizer callable

  • scheduler_args: Tuple of args fed to scheduler callable

  • scheduler_kwargs: Dict of keyword args fed to scheduler callable

  • grad_accum_steps: Grad accumulation steps

  • log_rank0: Log at rank 0 (by default, logging happens at the last PP rank). Note that setting this to True introduces extra communication per step, which reduces performance.

  • manual_opt: Whether to use manual optimization. Note that NeuronLTModule currently doesn't support automatic optimization, so this should always be set to True.
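The interaction between grad_accum_steps and the data-parallel size follows the usual gradient-accumulation arithmetic. The sketch below is generic and assumes that NeuronLTModule steps the optimizer once every grad_accum_steps micro-batches; is_optimizer_step and global_batch_size are hypothetical helpers, not part of the Neuron Lightning API.

```python
def is_optimizer_step(micro_batch_idx: int, grad_accum_steps: int) -> bool:
    """True on the micro-batch that completes an accumulation window (0-based index)."""
    return (micro_batch_idx + 1) % grad_accum_steps == 0


def global_batch_size(micro_batch_size: int, grad_accum_steps: int, dp_size: int) -> int:
    """Samples contributing to one optimizer step across all data-parallel ranks."""
    return micro_batch_size * grad_accum_steps * dp_size


print([i for i in range(8) if is_optimizer_step(i, grad_accum_steps=4)])  # [3, 7]
print(global_batch_size(micro_batch_size=1, grad_accum_steps=4, dp_size=8))  # 32
```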

Neuron XLA Strategy

Inherited from XLAStrategy

class neuronx_distributed.lightning.NeuronXLAStrategy(
    nxd_config: Dict = None,
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    save_load_xser: bool = True,
)

Parameters:

  • nxd_config: Neuronx Distributed Config, output of neuronx_distributed.neuronx_distributed_config

  • tensor_parallel_size: Tensor parallel degree, only needed when nxd_config is not specified

  • pipeline_parallel_size: Pipeline parallel degree, only needed when nxd_config is not specified (Note that for now we only support TP with Neuron-PT-Lightning)

  • save_load_xser: Setting this to True enables save/load with XLA serialization. For more context, see Save Checkpoint.

Neuron XLA Precision Plugin

Inherited from XLAPrecisionPlugin

class neuronx_distributed.lightning.NeuronXLAPrecisionPlugin

Neuron TQDM Progress Bar

Inherited from TQDMProgressBar

class neuronx_distributed.lightning.NeuronTQDMProgressBar

Neuron TensorBoard Logger

Inherited from TensorBoardLogger

class neuronx_distributed.lightning.NeuronTensorBoardLogger(save_dir)

Parameters:

  • save_dir: Directory to save the log files
