This document is relevant for: Inf2, Trn1, Trn1n
Training APIs#
Neuronx-Distributed Training APIs:#
In Neuronx-Distributed, we provide a set of APIs directly under neuronx_distributed that help users apply the optimizations in NxD Core easily. These APIs cover configuration, model and optimizer initialization, and checkpoint saving and loading.
Initialize NxD Core config:#
def neuronx_distributed.trainer.neuronx_distributed_config(
tensor_parallel_size=1,
pipeline_parallel_size=1,
pipeline_config=None,
optimizer_config=None,
activation_checkpoint_config=None,
pad_model=False,
sequence_parallel=False,
model_init_config=None,
lora_config=None,
)
This method initializes the NxD Core training config and initializes model parallelism. The config holds all optimization options for distributed training and is global (the same for all processes).
Parameters:
tensor_parallel_size (int): Tensor model parallel size. Default: 1.
pipeline_parallel_size (int): Pipeline model parallel size. Default: 1.
pipeline_config (dict): Pipeline parallel config. For details, refer to the pipeline parallel guidance. Default: None.
optimizer_config (dict): Optimizer config. Default: {"zero_one_enabled": False, "grad_clipping": True, "max_grad_norm": 1.0}.
  Enable ZeRO-1 by setting zero_one_enabled to True.
  Enable grad clipping by setting grad_clipping to True.
  Change the maximum grad norm value by setting max_grad_norm.
activation_checkpoint_config (str or torch.nn.Module): Activation checkpoint config. Accepted values: "full", None, or any torch.nn.Module. When set to "full", regular activation checkpointing is enabled (every transformer layer is recomputed). When set to None, activation checkpointing is disabled. When set to any torch.nn.Module, selective activation checkpointing is enabled and every provided module is recomputed. Default: None.
pad_model (bool): Whether to pad the attention heads of the model. Default: False.
sequence_parallel (bool): Whether to enable sequence parallelism. Default: False.
model_init_config (dict): Model initialization config. Default: {"sequential_move_factor": 11, "meta_device_init": False, "param_init_fn": None}.
  sequential_move_factor: number of processes instantiating the model on the host at the same time. This limit is used to avoid host OOM. Range: 1-32.
  meta_device_init: whether to initialize the model on the meta device.
  param_init_fn: method that initializes the parameters of modules; should be provided when meta_device_init is True.
lora_config: LoRA configuration. Default: None, with LoRA disabled.
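The defaults and constraints listed above can be collected in a small helper before calling the config API. The following is a minimal sketch, not part of NxD: build_config_kwargs and make_nxd_config are hypothetical names, and the validation rules are only those stated in the parameter list (the 1-32 range, and the assumption that param_init_fn is needed when meta_device_init is enabled). Whether the library itself fills in missing dictionary keys is not stated here, so the helper always passes complete dictionaries.

```python
from typing import Any, Dict, Optional

# Defaults copied from the parameter list above.
DEFAULT_OPTIMIZER_CONFIG: Dict[str, Any] = {
    "zero_one_enabled": False,
    "grad_clipping": True,
    "max_grad_norm": 1.0,
}
DEFAULT_MODEL_INIT_CONFIG: Dict[str, Any] = {
    "sequential_move_factor": 11,
    "meta_device_init": False,
    "param_init_fn": None,
}


def build_config_kwargs(
    tensor_parallel_size: int = 1,
    pipeline_parallel_size: int = 1,
    optimizer_overrides: Optional[Dict[str, Any]] = None,
    model_init_overrides: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: merge overrides into the documented defaults and validate."""
    optimizer_config = {**DEFAULT_OPTIMIZER_CONFIG, **(optimizer_overrides or {})}
    model_init_config = {**DEFAULT_MODEL_INIT_CONFIG, **(model_init_overrides or {})}

    if tensor_parallel_size < 1 or pipeline_parallel_size < 1:
        raise ValueError("parallel sizes must be >= 1")
    if not 1 <= model_init_config["sequential_move_factor"] <= 32:
        raise ValueError("sequential_move_factor must be in the range 1-32")
    if model_init_config["meta_device_init"] and model_init_config["param_init_fn"] is None:
        raise ValueError("param_init_fn is required when meta_device_init is True")

    return {
        "tensor_parallel_size": tensor_parallel_size,
        "pipeline_parallel_size": pipeline_parallel_size,
        "optimizer_config": optimizer_config,
        "model_init_config": model_init_config,
    }


def make_nxd_config(**kwargs: Any):
    """Pass the validated arguments to the NxD config API."""
    # Imported lazily so build_config_kwargs can be used without the Neuron SDK installed.
    import neuronx_distributed as nxd

    return nxd.neuronx_distributed_config(**build_config_kwargs(**kwargs))
```

For example, make_nxd_config(tensor_parallel_size=8, optimizer_overrides={"zero_one_enabled": True}) would enable ZeRO-1 while keeping the documented grad clipping defaults.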
Initialize NxD Core Model Wrapper:#
def neuronx_distributed.trainer.initialize_parallel_model(nxd_config, model_fn, *model_args, **model_kwargs)
This method initializes the NxD Core model wrapper and returns a wrapped model that can be used as a regular torch.nn.Module with all model optimizations in the config applied.
The wrapper is designed to hide the complexity of optimizations such as pipeline model parallelism, so that users can use the wrapped model just like the unwrapped version.
Parameters:
nxd_config (dict): config generated by neuronx_distributed_config.
model_fn (callable): user-provided function that returns the model for training.
model_args and model_kwargs: arguments that will be passed to model_fn.
Model wrapper class and its methods:
class neuronx_distributed.trainer.model.NxDModel(torch.nn.Module):
    def local_module(self):
        # return the unwrapped local module
    def run_train(self, *args, **kwargs):
        # method to run one iteration; when pipeline parallel is enabled,
        # users must use this instead of forward + backward
    def named_parameters(self, *args, **kwargs):
        # only return parameters on the local rank.
        # same for `parameters`, `named_buffers`, `buffers`
    def named_modules(self, *args, **kwargs):
        # only return modules on the local rank.
        # same for `modules`, `named_children`, `children`
Note
As a shortcut, users can call model.config or model.dtype from the wrapped model if the original model is a Hugging Face transformers pre-trained model.
Initialize NxD Core Optimizer Wrapper:#
def neuronx_distributed.trainer.initialize_parallel_optimizer(nxd_config, optimizer_class, parameters, **defaults)
This method initializes the NxD Core optimizer wrapper and returns a wrapped optimizer that can be used as a regular torch.optim.Optimizer with all optimizer optimizations in the config applied.
The optimizer wrapper inherits from torch.optim.Optimizer. It takes the nxd_config and configures the optimizer to work with the different distributed training regimes.
The step method of the wrapped optimizer contains the necessary all-reduce operations and grad clipping. Other methods and variables work the same as in the unwrapped optimizer.
Parameters:
nxd_config (dict): config generated by neuronx_distributed_config.
optimizer_class (Type[torch.optim.Optimizer]): optimizer class used to create the optimizer.
parameters (iterable): parameters passed to the optimizer.
defaults: optimizer options that will be passed to the optimizer.
Enable LoRA fine-tuning:#
LoRA model wrapper
class LoRAModel(module, LoraConfig)
Parameters:
module (torch.nn.Module): Module to be wrapped with LoRA.
LoraConfig: The LoRA configuration defined in neuronx_distributed.modules.lora.LoraConfig.
The flags in LoraConfig to initialize the LoRA adapter:
enable_lora (bool): Enable LoRA fine-tuning.
lora_rank (int): The rank of the LoRA adapter. A small LoRA rank reduces the memory footprint during fine-tuning, but it may harm the model quality.
lora_alpha (float): The alpha parameter for LoRA scaling, i.e., scaling LoRA weights against base model weights.
lora_dropout (float): The dropout probability for LoRA layers.
bias (str): Bias type for LoRA. Can be none, all or lora_only.
target_modules (List[str]): The names of the modules that need LoRA.
use_rslora (bool): If True, uses Rank-Stabilized LoRA, which sets the adapter scaling factor to lora_alpha/math.sqrt(lora_rank).
init_lora_weights (str): Weight initialization of the LoRA adapter. Can be default (initialized with torch.nn.init.kaiming_uniform_()) or gaussian (initialized with torch.nn.init.normal_()).
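To make the effect of use_rslora concrete, the scaling factor can be computed as in the sketch below. The Rank-Stabilized formula is the one given in the flag description; the non-RS branch assumes the conventional LoRA scaling of lora_alpha / lora_rank, which this page does not state explicitly. lora_scaling is a hypothetical helper, not an NxD function.

```python
import math


def lora_scaling(lora_alpha: float, lora_rank: int, use_rslora: bool = False) -> float:
    """Return the adapter scaling factor for the given LoRA settings."""
    if lora_rank <= 0:
        raise ValueError("lora_rank must be positive")
    if use_rslora:
        # Rank-Stabilized LoRA, as described for the use_rslora flag.
        return lora_alpha / math.sqrt(lora_rank)
    # Assumption: conventional LoRA scaling when use_rslora is False.
    return lora_alpha / lora_rank
```

With lora_alpha=32 and lora_rank=16, the conventional factor is 2.0 and the Rank-Stabilized factor is 8.0, so Rank-Stabilized LoRA scales the adapter output more strongly at the same rank.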
Usage:
We first define the LoRA configuration for fine-tuning. If the target modules are [q_proj, v_proj, k_proj], LoRA is applied to every module whose name includes any of these keywords. An example is:
lora_config = neuronx_distributed.modules.lora.LoraConfig(
    enable_lora=True,
    lora_rank=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "v_proj", "k_proj"],
)
You can then enable LoRA fine-tuning as follows:
nxd_config = neuronx_distributed.neuronx_distributed_config(
    ...
    lora_config=lora_config,
)
model = neuronx_distributed.initialize_parallel_model(nxd_config, ...)
The NxD model will then be initialized with the LoRA adapter enabled.
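The keyword matching described for target_modules can be illustrated with a few lines of plain Python. This is a sketch of the stated rule (a module is selected when its name includes any keyword), not the library's implementation; select_lora_targets and the module names are hypothetical.

```python
from typing import Iterable, List


def select_lora_targets(module_names: Iterable[str], target_modules: List[str]) -> List[str]:
    """Return the module names that contain at least one target keyword."""
    return [
        name
        for name in module_names
        if any(keyword in name for keyword in target_modules)
    ]


# Hypothetical module names, in the style of a transformer decoder layer.
names = [
    "layers.0.self_attn.q_proj",
    "layers.0.self_attn.k_proj",
    "layers.0.self_attn.v_proj",
    "layers.0.self_attn.o_proj",
    "layers.0.mlp.up_proj",
]
selected = select_lora_targets(names, ["q_proj", "v_proj", "k_proj"])
```

Here selected contains the three attention projections q_proj, k_proj and v_proj, while o_proj and up_proj are left without an adapter.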
Save Checkpoint:#
Method to save a checkpoint; returns None.
This method saves checkpoints for the model, optimizer, scheduler, and user contents sequentially. Model states are saved on data parallel rank 0 only. When the ZeRO-1 optimizer is not enabled, optimizer states are saved the same way; when it is enabled, optimizer states are saved on all ranks. Scheduler and user contents are saved on the master rank only. In addition, users can set use_xser=True to speed up saving and avoid host OOM. This works by saving tensors individually, in parallel, while keeping the original data structure. However, the resulting checkpoint cannot be loaded with the load API of PyTorch. Users can also set async_save=True to further speed up saving; tensors are then saved in separate processes alongside computation. Setting async_save to True uses more host memory, which increases the risk of the application crashing because the system runs out of memory.
def neuronx_distributed.trainer.save_checkpoint(
path,
tag="",
model=None,
optimizer=None,
scheduler=None,
user_content=None,
num_workers=8,
use_xser=False,
num_kept_ckpts=None,
async_save=False,
zero1_optimizer=False
)
Parameters:
path (str): path to save the checkpoints.
tag (str): tag to save the checkpoints.
model (torch.nn.Module): model to save, optional.
optimizer (torch.optim.Optimizer): optimizer to save, optional.
scheduler: scheduler to save, optional.
user_content: user contents to save, optional.
num_workers (int): number of processes saving data on the host at the same time. This limit is used to avoid host OOM. Range: 1-32.
use_xser (bool): whether to use torch-xla serialization. When enabled, num_workers is ignored and the maximum number of workers is used. Default: False.
num_kept_ckpts (int): number of checkpoints to keep on disk, optional. Default: None.
async_save (bool): whether to use the asynchronous saving method. Default: False.
zero1_optimizer (bool): whether the optimizer state is from a ZeRO-1 optimizer; used when optimizer is a dict.
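The rules about which ranks write which part of a checkpoint can be summarized as a small decision function. This is a sketch of the behavior described in this section, not the library's code; should_save is a hypothetical name, and it assumes that the master rank is global rank 0.

```python
def should_save(component: str, dp_rank: int, global_rank: int, zero_one_enabled: bool) -> bool:
    """Decide whether this rank writes the given checkpoint component."""
    if component == "model":
        # Model states are saved on data parallel rank 0 only.
        return dp_rank == 0
    if component == "optimizer":
        # With ZeRO-1 the optimizer states are sharded, so every rank saves its part.
        return True if zero_one_enabled else dp_rank == 0
    if component in ("scheduler", "user_content"):
        # Assumption: "master rank" means global rank 0.
        return global_rank == 0
    raise ValueError(f"unknown component: {component}")
```

For a worker with dp_rank=1, should_save("optimizer", 1, 9, zero_one_enabled=True) is True, while the same call with zero_one_enabled=False is False.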
Save LoRA Checkpoint:
NxD also uses neuronx_distributed.trainer.save_checkpoint() to save LoRA models; set save_lora_base and merge_lora in LoraConfig to specify how the LoRA checkpoint is saved.
There are three modes for LoRA checkpoint saving:
save_lora_base=False, merge_lora=False: Save the LoRA adapter only.
save_lora_base=True, merge_lora=False: Save both the base model and the LoRA adapter separately.
save_lora_base=True, merge_lora=True: Merge the LoRA adapter into the base model and then save the base model.
In addition to the adapter, NxD also needs to save the LoRA configuration file for LoRA loading. The configuration can be saved into the same checkpoint as the adapter, or saved as a separate JSON file.
save_lora_config_adapter (bool): If False, save the configuration as a separate JSON file.
Note that if the LoRA configuration file is saved separately, it is named lora_adapter/adapter_config.json.
A configuration example to save the LoRA adapter only is:
lora_config = neuronx_distributed.modules.lora.LoraConfig(
...
save_lora_base=False,
merge_lora=False,
save_lora_config_adapter=True,
)
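The mapping from the two flags to the three saving modes can be written out explicitly. This is an illustrative sketch: lora_save_mode and the mode labels are hypothetical, and because the combination save_lora_base=False with merge_lora=True is not described here, the sketch rejects it rather than guessing its behavior.

```python
def lora_save_mode(save_lora_base: bool, merge_lora: bool) -> str:
    """Map the LoraConfig saving flags to one of the three documented modes."""
    if not save_lora_base and not merge_lora:
        return "adapter_only"
    if save_lora_base and not merge_lora:
        return "base_and_adapter"
    if save_lora_base and merge_lora:
        return "merged_base"
    # Not one of the three documented modes; the sketch refuses to guess.
    raise ValueError("save_lora_base=False with merge_lora=True is not a documented mode")
```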
Load Checkpoint:#
Method to load a checkpoint saved by save_checkpoint; returns the user contents if they exist, otherwise None.
If tag is not provided, the method tries to use the newest tag tracked by save_checkpoint.
Note that the checkpoint to be loaded must have the same model parallel degrees as the current run, and if the ZeRO-1 optimizer is used, it must also have the same data parallel degree.
def neuronx_distributed.trainer.load_checkpoint(
path,
tag=None,
model=None,
optimizer=None,
scheduler=None,
num_workers=8,
strict=True,
)
Parameters:
path (str): path to load the checkpoints.
tag (str): tag to load the checkpoints.
model (torch.nn.Module): model to load, optional.
optimizer (torch.optim.Optimizer): optimizer to load, optional.
scheduler: scheduler to load, optional.
num_workers (int): number of processes loading data on the host at the same time. This limit is used to avoid host OOM. Range: 1-32.
strict (bool): whether to use strict mode when loading the model checkpoint. Default: True.
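The compatibility requirement on parallel degrees can be checked before attempting a load. The sketch below encodes only the rule stated in this section; is_loadable and the dictionary keys tp, pp and dp are hypothetical, and it assumes that "model parallel degrees" refers to the tensor and pipeline parallel sizes.

```python
from typing import Dict


def is_loadable(saved: Dict[str, int], current: Dict[str, int], zero_one_enabled: bool) -> bool:
    """Check whether a checkpoint's parallel degrees match the current run."""
    # Assumption: model parallel degrees are the tensor and pipeline parallel sizes.
    if saved["tp"] != current["tp"] or saved["pp"] != current["pp"]:
        return False
    # With ZeRO-1, optimizer states are sharded across data parallel ranks,
    # so the data parallel degree must match as well.
    if zero_one_enabled and saved["dp"] != current["dp"]:
        return False
    return True
```

For example, a checkpoint saved with tp=8, pp=1, dp=4 can be loaded into a run with dp=8 only when ZeRO-1 is not used.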
Load LoRA Checkpoint:
NxD loads LoRA checkpoints by setting flags in LoraConfig.
load_lora_from_ckpt: Resume from a saved LoRA checkpoint.
lora_save_dir: Load the LoRA checkpoint from the specified folder.
lora_load_tag: Load the LoRA checkpoint with the specified tag.
An example is:
lora_config = LoraConfig(
enable_lora=True,
load_lora_from_ckpt=True,
lora_save_dir=checkpoint_dir, # checkpoint path
lora_load_tag=tag, # sub-directory under checkpoint path
)
nxd_config = nxd.neuronx_distributed_config(
...
lora_config=lora_config,
)
model = nxd.initialize_parallel_model(nxd_config, ...)
The NxD model will be initialized with LoRA enabled and the LoRA weights loaded. The LoRA-related configurations are the same as those of the LoRA adapter checkpoint.
Sample usage:
import neuronx_distributed as nxd
from torch.optim import AdamW

# create config
nxd_config = nxd.neuronx_distributed_config(
    tensor_parallel_size=8,
    optimizer_config={"zero_one_enabled": True, "grad_clipping": True, "max_grad_norm": 1.0},
)
# wrap model (get_model is the user-provided model_fn)
model = nxd.initialize_parallel_model(nxd_config, get_model)
# wrap optimizer
optimizer = nxd.initialize_parallel_optimizer(nxd_config, AdamW, model.parameters(), lr=1e-3)
...
# training loop (dataloader is user-provided)
for inputs in dataloader:
    loss = model.run_train(inputs)
    optimizer.step()
    ...
# loading checkpoint (auto-resume)
user_content = nxd.load_checkpoint(
    "ckpts",
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
)
...
# saving checkpoint
nxd.save_checkpoint(
    "ckpts",
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    user_content={"total_steps": total_steps},
)
Modules:#
GQA-QKV Linear Module:#
class neuronx_distributed.modules.qkv_linear.GQAQKVColumnParallelLinear(
    input_size, output_sizes, bias=True, gather_output=True,
    sequence_parallel_enabled=False, dtype=torch.float32, device=None, kv_size_multiplier=1)
This module parallelizes the Q, K, and V linear projections using ColumnParallelLinear layers. Instead of using three separate linear layers, you can use a single QKV module. In a GQA module, there are N times as many Q attention heads as K and V attention heads, and the K and V heads are replicated after projection to match the number of Q heads. This reduces the size of the K and V weights and is especially useful during inference. During training, however, it restricts the tensor-parallel degree that can be used, because the number of attention heads must be divisible by the tensor-parallel degree. To mitigate this bottleneck, GQAQKVColumnParallelLinear takes a kv_size_multiplier argument. The module replicates the K and V weights kv_size_multiplier times, which allows a higher tensor-parallel degree. Note: instead of replicating the projection N/tp_degree times, we end up replicating the weights kv_size_multiplier times. This produces the same result and allows a higher tp_degree, but it consumes extra memory.
Parameters:
input_size (int): First dimension of the weight matrix.
output_sizes (List[int]): A list of the second dimensions of the Q and K/V weight matrices.
bias (bool): If set to True, bias is added.
gather_output (bool): If True, call all-gather on the output and make Y available to all Neuron devices; otherwise, every Neuron device has its own output, which is Y_i = XA_i.
sequence_parallel_enabled (bool): When sequence parallelism is enabled, the module gathers the inputs from the sequence parallel region and performs the forward and backward passes.
init_method (torch.nn.init): Initialization function for the Q and K/V weights.
dtype (dtype): Datatype for the weights.
device (torch.device): Device to initialize the weights on. By default, the weights are initialized on CPU.
kv_size_multiplier (int): Factor by which the K and V weights are replicated along the first dimension.
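A suitable kv_size_multiplier can be derived from the divisibility requirement described above. The sketch below is an assumption-based illustration, not an NxD utility: it picks the smallest multiplier for which the replicated number of K/V heads is divisible by the tensor-parallel degree, and min_kv_size_multiplier is a hypothetical name.

```python
import math


def min_kv_size_multiplier(num_kv_heads: int, tp_degree: int) -> int:
    """Smallest m such that (num_kv_heads * m) is divisible by tp_degree."""
    if num_kv_heads < 1 or tp_degree < 1:
        raise ValueError("num_kv_heads and tp_degree must be positive")
    # tp_degree divides num_kv_heads * m exactly when m is a multiple of
    # tp_degree / gcd(num_kv_heads, tp_degree).
    return tp_degree // math.gcd(num_kv_heads, tp_degree)
```

For a hypothetical model with 8 K/V heads and a tensor-parallel degree of 32, the function returns 4, giving 32 replicated K/V heads (one per rank). The memory cost grows with the multiplier, since the K and V weights are stored that many times.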
Checkpointing:#
This is a set of APIs for saving and loading checkpoints. These APIs take care of saving and loading the shard that corresponds to the tensor parallel rank of the worker.
Save Checkpoint:#
def neuronx_distributed.parallel_layers.save(state_dict, save_dir, save_serially=True, save_xser: bool=False, down_cast_bf16=False)
Note
This method will be deprecated; use neuronx_distributed.trainer.save_checkpoint instead.
This API saves the model from each tensor-parallel rank in save_dir. Only workers with data parallel rank equal to 0 save checkpoints. Each tensor parallel rank creates a tp_rank_ii_pp_rank_ii folder inside save_dir and saves its shard in that folder.
If save_xser is enabled, the folder name is tp_rank_ii_pp_rank_ii.tensors and there is a ref data file named tp_rank_ii_pp_rank_ii in save_dir for each rank.
Parameters:
state_dict (dict): Model state dict. It's the same dict that you would save using torch.save.
save_dir (str): Model save directory.
save_serially (bool): This flag saves checkpoints one model-parallel rank at a time. This is particularly useful when checkpointing large models.
save_xser (bool): This flag saves the model with torch xla serialization. This can significantly reduce checkpoint saving time for large models, so enabling xser is recommended when the model is large. Note that a checkpoint saved with save_xser needs to be loaded with load_xser, and vice versa.
down_cast_bf16 (bool): This flag downcasts the state_dict to bf16 before saving.
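The resulting directory layout can be sketched as follows. This is illustrative only: shard_paths is a hypothetical helper, and it assumes that ii in tp_rank_ii_pp_rank_ii stands for a zero-padded two-digit rank index, which this page does not spell out.

```python
from pathlib import PurePosixPath
from typing import Dict, Optional


def shard_paths(save_dir: str, tp_rank: int, pp_rank: int, save_xser: bool = False) -> Dict[str, Optional[str]]:
    """Return the folder (and ref file, with xser) used by one model-parallel rank."""
    # Assumption: ranks are formatted as zero-padded two-digit numbers.
    name = f"tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:02d}"
    base = PurePosixPath(save_dir)
    if save_xser:
        # With xser, tensors go to "<name>.tensors" and a ref data file keeps the plain name.
        return {"folder": str(base / f"{name}.tensors"), "ref_file": str(base / name)}
    return {"folder": str(base / name), "ref_file": None}
```

For tp_rank=3 and pp_rank=0 under save_dir "ckpt", the sketch yields ckpt/tp_rank_03_pp_rank_00, or ckpt/tp_rank_03_pp_rank_00.tensors plus a ref file when save_xser is enabled.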
Load Checkpoint#
def neuronx_distributed.parallel_layers.load(
load_dir, model_or_optimizer=None, model_key='model', load_xser=False, sharded=True)
Note
This method will be deprecated; use neuronx_distributed.trainer.load_checkpoint instead.
This API will automatically load checkpoint depending on the tensor parallel rank. For large models, one should pass the model object to the load API to load the weights directly into the model. This could avoid host OOM, as the load API would load the checkpoints for one tensor parallel rank at a time.
Parameters:
load_dir (str): Directory where the checkpoint is saved.
model_or_optimizer (torch.nn.Module or torch.optim.Optimizer): Model or Optimizer object.
model (torch.nn.Module or torch.optim.Optimizer): Model or Optimizer object, equivalent to model_or_optimizer.
model_key (str): The model key used when saving the model in the state_dict.
load_xser (bool): Load the model with torch xla serialization. Note that a checkpoint saved with save_xser needs to be loaded with load_xser, and vice versa.
sharded (bool): If the checkpoint is not sharded, pass False. This is useful (especially during inference) when the model was trained using a different strategy and you ended up saving a single unsharded checkpoint. You can then load this unsharded checkpoint onto the sharded model. When this attribute is set to False, it is necessary to pass the model object. Note: The keys in the state dict must have the same names as in the model object; otherwise an error is raised.
Gradient Clipping:#
With tensor parallelism, gradient clipping needs special handling because the total norm must be accumulated across all tensor parallel ranks. The following API handles this:
def neuronx_distributed.parallel_layers.clip_grad_norm(
parameters, max_norm, norm_type=2)
Parameters:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized.
max_norm (float or int): max norm of the gradients.
norm_type (float or int): type of the used p-norm. Can be 'inf' for infinity norm.
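The reason the norm has to be accumulated across ranks can be seen in a plain-Python sketch. This is not the NxD implementation: global_grad_norm and clip_coefficient are hypothetical helpers, gradients are modeled as lists of floats, and the sketch assumes every rank holds a distinct shard (in a real model, gradients of replicated parameters would have to be counted only once).

```python
import math
from typing import Sequence


def global_grad_norm(per_rank_grads: Sequence[Sequence[float]], norm_type: float = 2.0) -> float:
    """Combine the gradient shards of all tensor parallel ranks into one p-norm."""
    if math.isinf(norm_type):
        # Infinity norm: the largest absolute gradient value on any rank.
        return max(abs(g) for grads in per_rank_grads for g in grads)
    # Each rank contributes the sum of |g|^p over its shard; in a distributed run
    # these partial sums would be combined with an all-reduce.
    total = sum(abs(g) ** norm_type for grads in per_rank_grads for g in grads)
    return total ** (1.0 / norm_type)


def clip_coefficient(total_norm: float, max_norm: float, eps: float = 1e-6) -> float:
    """Factor applied to every gradient so that the global norm is at most max_norm."""
    return min(1.0, max_norm / (total_norm + eps))
```

With one gradient of 3.0 on rank 0 and one of 4.0 on rank 1, each rank alone would see a norm of 3.0 or 4.0, while the global norm used for clipping is 5.0.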
Neuron Zero1 Optimizer:#
In Neuronx-Distributed, we built a wrapper on the Zero1-Optimizer present in torch-xla.
class NeuronZero1Optimizer(Zero1Optimizer)
This wrapper takes the tensor-parallel degree into account and computes the grad norm accordingly. It also provides two APIs: save_sharded_state_dict and load_sharded_state_dict. As the model grows, saving the optimizer state from a single rank can result in OOMs. The save_sharded_state_dict API therefore saves the states from each data-parallel rank. To load this sharded optimizer state, the corresponding load_sharded_state_dict API lets each rank pick its own shard from the checkpoint directory.
optimizer_grouped_parameters = [
    {
        "params": [
            p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.01,
    },
    {
        "params": [
            p for n, p in param_optimizer if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
optimizer = NeuronZero1Optimizer(
    optimizer_grouped_parameters,
    AdamW,
    lr=flags.lr,
    pin_layout=False,
    sharding_groups=parallel_state.get_data_parallel_group(as_list=True),
    grad_norm_groups=parallel_state.get_tensor_model_parallel_group(as_list=True),
)
The interface is the same as Zero1Optimizer in torch-xla.
save_sharded_state_dict(output_dir, save_serially = True)
Note
This method will be deprecated; use neuronx_distributed.trainer.save_checkpoint instead.
Parameters:
output_dir (str): Checkpoint directory where the sharded optimizer states need to be saved.
save_serially (bool): Whether to save the states one data-parallel rank at a time. This is especially useful when checkpointing large models.
load_sharded_state_dict(output_dir, num_workers_per_step = 8)
Note
This method will be deprecated; use neuronx_distributed.trainer.load_checkpoint instead.
Parameters:
output_dir (str): Checkpoint directory where the sharded optimizer states are saved.
num_workers_per_step (int): This argument controls how many workers load the model in parallel.
Neuron PyTorch-Lightning#
Neuron PyTorch-Lightning is currently based on Lightning version 2.1.0 and will eventually be upstreamed to the Lightning-AI code base.
Neuron Lightning Module#
Inherited from LightningModule
class neuronx_distributed.lightning.NeuronLTModule(
model_fn: Callable,
nxd_config: Dict,
opt_cls: Callable,
scheduler_cls: Callable,
model_args: Tuple = (),
model_kwargs: Dict = {},
opt_args: Tuple = (),
opt_kwargs: Dict = {},
scheduler_args: Tuple = (),
scheduler_kwargs: Dict = {},
grad_accum_steps: int = 1,
log_rank0: bool = False,
manual_opt: bool = True,
)
Parameters:
model_fn: Model function to create the actual model.
nxd_config: Neuronx Distributed Config, output of neuronx_distributed.neuronx_distributed_config.
opt_cls: Callable to create the optimizer.
scheduler_cls: Callable to create the scheduler.
model_args: Tuple of args fed to the model callable.
model_kwargs: Dict of keyword args fed to the model callable.
opt_args: Tuple of args fed to the optimizer callable.
opt_kwargs: Dict of keyword args fed to the optimizer callable.
scheduler_args: Tuple of args fed to the scheduler callable.
scheduler_kwargs: Dict of keyword args fed to the scheduler callable.
grad_accum_steps: Grad accumulation steps.
log_rank0: Log at rank 0 (by default it logs at the last PP rank). Note that setting this to True introduces extra communication per step and therefore reduces performance.
manual_opt: Whether to do manual optimization. Note that NeuronLTModule currently doesn't support auto optimization, so this should always be set to True.
Neuron XLA Strategy#
Inherited from XLAStrategy
class neuronx_distributed.lightning.NeuronXLAStrategy(
nxd_config: Dict = None,
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
save_load_xser: bool = True,
)
Parameters:
nxd_config: Neuronx Distributed Config, output of neuronx_distributed.neuronx_distributed_config.
tensor_parallel_size: Tensor parallel degree, only needed when nxd_config is not specified.
pipeline_parallel_size: Pipeline parallel degree, only needed when nxd_config is not specified (note that for now only TP is supported with Neuron-PT-Lightning).
save_load_xser: Setting this to True enables save/load with xla serialization; for more context, see Save Checkpoint.
Neuron XLA Precision Plugin#
Inherited from XLAPrecisionPlugin
class neuronx_distributed.lightning.NeuronXLAPrecisionPlugin
Neuron TQDM Progress Bar#
Inherited from TQDMProgressBar
class neuronx_distributed.lightning.NeuronTQDMProgressBar
Neuron TensorBoard Logger#
Inherited from TensorBoardLogger
class neuronx_distributed.lightning.NeuronTensorBoardLogger(save_dir)
Parameters:
save_dir: Directory to save the log files.