Migrating from Neuron-NeMo-Megatron to Neuronx Distributed Training
In this section, we go over the changes you need to make when migrating a training workload from Neuronx-NeMo-Megatron (NNM) to the Neuronx Distributed Training (NxDT) framework.
Config migration
NxDT is a framework built on top of NeMo and NeuronxDistributed (NxD), and it supports Megatron-style models. The Megatron model implementation is ported from NNM, so most NNM config YAMLs can be migrated to NxDT.
For modularity, we grouped certain parameters together when building NxDT. For example, distributed_strategy holds all the configuration for model parallelism, and the data config now holds all the parameters required to configure the dataset.
At a high level, the NxDT config differs from the NNM config in the following ways:
The overall config structure has changed. For simplicity and ease of understanding, the config parameters are grouped by their high-level use case. For example, all the distributed config parameters previously resided inside the model config. They have now been moved into a distributed_strategy config of their own. Similarly, the data config has been moved out of model to keep a clear separation between model and data.
Environment variables such as neuron_cc_flags and neuron_compile_cache_url can be set from the config itself, so you do not need to set them as environment variables. The rationale is to avoid having to configure training scripts from multiple places.
Activation Checkpointing: NxDT supports only selective and full activation checkpointing. Selective checkpointing is applied only to the CoreAttention block (for llama3-8K, the MLP block is recomputed as well). Full activation checkpointing is applied only at layer boundaries. NxDT doesn't support config parameters such as activations_checkpoint_method, activations_checkpoint_num_layers, num_micro_batches_with_partial_activation_checkpoints, activations_checkpoint_layers_per_pipeline, and disable_layer_norm_checkpointing. Please remove these parameters from your config.yaml file.
Note
If you plan to add more modules that need to be recomputed, override the checkpointing config inside ModelModule (refer to the build_model API in Build a Lightning Module) and add those modules there.
Tokenizer: The tokenizer, which used to reside under model, has moved to data, so that all data-related configuration resides in one place.
accumulate_grad_batches: This parameter is removed because it should always be 1. Gradient accumulation is handled by setting global_batch_size and micro_batch_size along with the data-parallel degree.
pre_process and post_process: These two parameters were added to the model to decide whether the embedding lookup is added at the start and whether a pooler layer is added at the end. Both are set by default for all decoder models, so these config parameters are no longer exposed.
Mixed precision config: NxDT no longer exposes the NeMo mixed precision parameters native_amp_init_scale, native_amp_growth_interval, hysteresis, fp32_residual_connection, and fp16_lm_cross_entropy. These parameters are either specific to the GPU mixed precision strategy, which Neuron doesn't support, or not applicable. Neuron enables mixed precision training in a different way, through master_weights and fp32_grad_accumulation.
megatron_amp_O2: This parameter is not supported.
Fusions: Neuron doesn't support fusion parameters such as grad_div_ar_fusion, gradient_accumulation_fusion, bias_activation_fusion, bias_dropout_add_fusion, and masked_softmax_fusion. All of these fusions are built for GPUs and require CUDA kernels, which cannot run on Trn1. Neuron would have its own set of kernels. Once those kernels are supported, the corresponding parameters would be enabled in the config.
Note
If you need any of these configs supported, please create a feature request describing your exact needs, and we will work on it.
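The sketch below makes two of these changes concrete: the regrouping of keys into new sections and the replacement of accumulate_grad_batches by a derived value. It operates on plain Python dicts, such as a loaded YAML file.

Some parts are assumptions, taken from the Config Mapping table rather than from NxDT itself:
- The helper names migrate_layout and grad_accum_steps are hypothetical, and so are the key lists. This is not an NxDT utility.
- NxDT may rename some keys inside the new sections, so compare the result against an NxDT example config.
- The accumulation formula assumes the usual Megatron-style relation, in which the data-parallel degree is world_size / (tp * pp).

```python
import copy

# Keys that, per the Config Mapping table, move out of the NNM "model" section.
# Illustrative lists only: NxDT may rename keys inside the new sections.
PARALLELISM_KEYS = (
    "tensor_model_parallel_size",
    "pipeline_model_parallel_size",
    "virtual_pipeline_model_parallel_size",
    "sequence_parallel",
    "wrap_with_zero",
)
DATA_KEYS = ("micro_batch_size", "global_batch_size", "tokenizer")


def migrate_layout(nnm_cfg: dict) -> dict:
    """Regroup an NNM-style config dict (hypothetical helper, not an NxDT API)."""
    cfg = copy.deepcopy(nnm_cfg)  # leave the caller's config untouched
    model = cfg.setdefault("model", {})

    # Parallelism settings move into their own distributed_strategy section.
    strategy = cfg.setdefault("distributed_strategy", {})
    for key in PARALLELISM_KEYS:
        if key in model:
            strategy[key] = model.pop(key)

    # The nested model.data block and the batch/tokenizer keys move to data.
    data = cfg.setdefault("data", {})
    data.update(model.pop("data", {}))
    for key in DATA_KEYS:
        if key in model:
            data[key] = model.pop(key)

    # accumulate_grad_batches is removed; accumulation is derived instead.
    cfg.get("trainer", {}).pop("accumulate_grad_batches", None)
    return cfg


def grad_accum_steps(global_batch_size: int, micro_batch_size: int,
                     world_size: int, tp: int, pp: int) -> int:
    """Micro-batches each data-parallel rank accumulates per optimizer step."""
    if world_size % (tp * pp):
        raise ValueError("world_size must be divisible by tp * pp")
    dp = world_size // (tp * pp)  # data-parallel degree (assumed relation)
    per_step = micro_batch_size * dp
    if global_batch_size % per_step:
        raise ValueError("global_batch_size must be divisible by micro_batch_size * dp")
    return global_batch_size // per_step


nnm = {
    "trainer": {"max_steps": 100, "accumulate_grad_batches": 1},
    "model": {"tensor_model_parallel_size": 8, "micro_batch_size": 1,
              "data": {"seq_length": 4096}, "encoder_seq_length": 4096},
}
print(migrate_layout(nnm))
print(grad_accum_steps(1024, 1, 64, 8, 4))  # 512
```

For example, with 64 workers, tp=8 and pp=4 give a data-parallel degree of 2. A global batch of 1024 with micro-batches of 1 therefore accumulates 512 micro-batches per optimizer step.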
For a detailed mapping, see Config Mapping.
Model code
The model code differs in the following ways:
NNM used Apex for all the distributed parallel layers and schedules. Since NxDT uses NxD as its base library, all the parallel layers and parallel state come from NxD. For example, apex.parallel_state is replaced with nxd.parallel_layers.parallel_state.
NNM explicitly creates a module for each pipeline-parallel (PP) rank, whereas NxDT relies on NxD, which does the partitioning under the hood. Users no longer have to create a rank-specific module. They can create a single model, and NxD's PP wrapper takes care of sharding it for each PP rank. As a result, all the pipeline-parallelism code has been removed from the model code. The model code assumes there is no PP and simply uses the TP layers from NxD.
Note
For the tracer to work efficiently, we configure the pipeline parallel config inside the BaseModelModule class in lightning_modules/model.
In NNM, the Megatron module had to explicitly handle gradient reduction for weights shared across PP ranks. In NxDT, NxD's PP wrapper handles this for the user.
For activation checkpointing, NNM had explicit recompute functions that handled the custom forward API. In NxDT, NxD's Activation Checkpoint wrapper handles recomputing the modules. Users only have to configure activation_checkpoint_config inside nxd_config.
Checkpointing Save/Load
NxDT supports all the checkpointing features that NNM supports, including async checkpointing and auto-resume. However, because NxDT uses NxD's checkpoint API, the checkpoint format differs. The key differences are listed below:
NNM combines the model weights, optimizer states, and other state_dicts into a single state_dict and dumps it to a file of the format tp_rank_0*_pp_rank_00*/model_optim_rng.ckpt. NxDT, however, saves the model state_dict and the optimizer state separately. The model state_dict is saved in a folder of the form model/dp_rank_00_tp_rank_00_pp_rank_00.pt, and the optimizer state is saved into a separate folder as optim/dp_rank_00_tp_rank_00_pp_rank_00.pt. This is mainly done so that, when zero1 is used, each DP rank can save its own optimizer shard.
In NNM, when pipeline parallelism is used, each pipeline stage creates an independent model. For example, for a model with 32 layers and PP=4, NNM creates 4 chunks, each with layers numbered 0-7. As a result, the model_state_dict on every PP rank has keys for layers 0-7. In NxDT, the model is created as a whole and then sharded, so the layer numbers are preserved.
Checkpoint conversion scripts are provided under examples/ in the NxDT repository. They convert existing NNM checkpoints to the NxDT format, in case you migrate in the middle of training.
python nnm_nxdt_ckpt_converter.py --tp 8 --pp 4 --n_layers 32 --nnm_ckpt_path {path_to_ckpt}/ckpt/nnm --nxdt_ckpt_path {path_to_ckpt}/nnm-converted-nxdt-ckpt/ --enable_parallel_processing True --num_parallel_processes 8
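The two format differences above can be sketched in a few lines of Python. This is an illustration only; use the provided nnm_nxdt_ckpt_converter.py script for real conversions, because an actual conversion also separates model and optimizer states.

The sketch rests on these assumptions:
- File names use the two-digit rank padding shown in the example names.
- State_dict keys follow a Megatron-style layout such as language_model.encoder.layers.3.mlp.weight.
- Layers are split evenly and contiguously across pipeline stages.
- The function names nxdt_shard_paths and to_global_layer_keys are hypothetical.

```python
import re

# Matches the first "layers.<index>." segment of a state_dict key.
# Assumed key layout, e.g. "language_model.encoder.layers.3.mlp.weight".
_LAYER_RE = re.compile(r"(\blayers\.)(\d+)(\.)")


def nxdt_shard_paths(dp_rank: int, tp_rank: int, pp_rank: int) -> tuple[str, str]:
    """Return (model_path, optim_path) for one rank in the NxDT layout.

    Hypothetical helper; zero-padding follows the example names (two digits).
    """
    name = f"dp_rank_{dp_rank:02d}_tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:02d}.pt"
    return f"model/{name}", f"optim/{name}"


def to_global_layer_keys(stage_state: dict, pp_rank: int,
                         n_layers: int, pp: int) -> dict:
    """Renumber one NNM pipeline stage's local layer indices to global ones.

    Hypothetical helper. Assumes an even, contiguous split: with
    k = n_layers // pp, stage pp_rank holds global layers pp_rank*k .. pp_rank*k + k - 1.
    """
    if n_layers % pp:
        raise ValueError("n_layers must be divisible by pp")
    offset = pp_rank * (n_layers // pp)

    def renumber(key: str) -> str:
        # Shift only the layer index; keys without one (e.g. embeddings) are kept.
        return _LAYER_RE.sub(
            lambda m: f"{m.group(1)}{int(m.group(2)) + offset}{m.group(3)}",
            key, count=1)

    return {renumber(k): v for k, v in stage_state.items()}


# Example: PP rank 2 of a 32-layer model split over pp=4 stages.
stage2 = {
    "language_model.encoder.layers.0.mlp.weight": "w0",
    "language_model.encoder.layers.7.mlp.weight": "w7",
}
print(to_global_layer_keys(stage2, pp_rank=2, n_layers=32, pp=4))
print(nxdt_shard_paths(dp_rank=1, tp_rank=3, pp_rank=2))
```

Take --n_layers 32 and --pp 4 from the command above. Under these assumptions, the local keys layers.0 through layers.7 on PP rank 2 become layers.16 through layers.23.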
Config Mapping
Here is a detailed mapping for all the parameters in the config file. The mapping below uses the Llama-7B example in both the NNM and NxDT frameworks. The same mapping also holds for other models.
NNM param | NxDT param mapping | Comments |
---|---|---|
name | name | |
restore_from_path | Not supported | This config was not fully supported in NNM, either. |
trainer | | |
devices | devices | |
num_nodes | num_nodes | |
accelerator | Not required | The default is TPU, which maps to Neuron internally, so users no longer have to set it. |
precision | Replaced by precision config | There is a separate precision config to control the precision of the model and optimizer. |
logger | Replaced by default | NxDT uses the NNM logger by default. |
enable_checkpointing | Separate | All checkpointing is controlled by the exp_manager config. |
replace_sampler_ddp | Not supported | This always had to be False in NNM. NxDT makes it the default, so no setting is required. |
max_epochs | max_epochs | |
max_steps | max_steps | |
log_every_n_steps | log_every_n_steps | |
val_check_interval | val_check_interval | |
limit_val_batches | limit_val_batches | |
limit_test_batches | limit_test_batches | |
accumulate_grad_batches | Removed | This is configured automatically based on global_batch_size, micro_batch_size, and the distributed config. |
gradient_clip_val | gradient_clip_val | |
benchmark | Not supported | |
enable_model_summary | Not supported | |
exp_manager | | |
log_local_rank_0_only | log_local_rank_0_only | |
create_tensorboard_logger | create_tensorboard_logger | |
explicit_log_dir | explicit_log_dir | |
exp_dir | exp_dir | |
name | name | |
create_wandb_logger | Not supported | This was not supported under NNM, either, so we have removed this argument from NxDT. |
wandb_logger_kwargs | Not supported | |
resume_if_exists | resume_if_exists | |
resume_ignore_no_checkpoint | resume_ignore_no_checkpoint | |
create_checkpoint_callback | create_checkpoint_callback | |
checkpoint_callback_params | checkpoint_callback_params | |
model | | |
tensor_model_parallel_size | Moved to distributed_strategy | All the parallelism configs are moved to the distributed_strategy config. |
pipeline_model_parallel_size | Moved to distributed_strategy | |
virtual_pipeline_model_parallel_size | Moved to distributed_strategy | |
sequence_parallel | Moved to distributed_strategy | |
wrap_with_zero | Moved to distributed_strategy | |
micro_batch_size | Moved to data | All the dataset, dataloader, and tokenizer configurations are now part of a separate config called data. |
global_batch_size | Moved to data | |
tokenizer | Moved to data | |
data | Moved to data | The entire data config is moved out of model. |
encoder_seq_length | encoder_seq_length | |
max_position_embeddings | max_position_embeddings | |
make_vocab_size_divisible_by | make_vocab_size_divisible_by | |
pre_process | Not supported | NxDT adds an embedding layer at the start of the transformer block by default. |
post_process | Not supported | NxDT adds an LM head at the end of the transformer block by default. |
persist_layer_norm | persist_layer_norm | |
share_embeddings_and_output_weights | share_embeddings_and_output_weights | |
position_embedding_type | position_embedding_type | |
rotary_percentage | rotary_percentage | |
transformer_block_type | transformer_block_type | |
has_bias | has_bias | |
native_amp_init_scale | Not required | |
native_amp_growth_interval | Not required | GPU optimizations that were not supported in NNM have been removed from NxDT. The Neuron compiler handles most of these fusion ops on its own. For Attention and Softmax, Neuron uses NKI kernels and custom ops. |
hysteresis | Not required | |
fp32_residual_connection | Not required | |
fp16_lm_cross_entropy | Not required | |
megatron_amp_O2 | Not required | |
grad_div_ar_fusion | Not required | |
gradient_accumulation_fusion | Not required | |
bias_activation_fusion | Not required | |
bias_dropout_add_fusion | Not required | |
masked_softmax_fusion | Not required | |
seed | Seed is moved out of model and at the same level as | |
resume_from_checkpoint | | |
use_cpu_initialization | use_cpu_initialization | |
onnx_safe | Not supported | This was not supported under NNM, either, so we have removed this argument from NxDT. |
apex_transformer_log_level | Not supported | |
gradient_as_bucket_view | Not supported | |
sync_batch_comm | Not supported | |
log_parameter_norm | | |
log_gradient_norm | | |
flexible_pipeline_parallel_stages | Not supported | |
activations_checkpoint_granularity | activations_checkpoint_granularity | Currently, NxDT checkpoints the attention module for selective checkpointing and a single layer for full checkpointing. |
activations_checkpoint_method | Not supported | |
activations_checkpoint_num_layers | Not supported | |
num_micro_batches_with_partial_activation_checkpoints | Not supported | |
activations_checkpoint_layers_per_pipeline | Not supported | |
disable_layer_norm_checkpointing | Not supported | |
zero_use_master_weight | Supported via precision config | |
zero_use_fp32_grad_accum | Supported via precision config | |
transformer_engine | Not supported | This is built specifically for NVIDIA GPUs. |
fp8 | Not supported | fp8 training is not supported on Neuron in either NNM or NxDT. |
fp8_e4m3 | Not supported | fp8 training is not supported on Neuron in either NNM or NxDT. |
fp8_hybrid | Not supported | fp8 training is not supported on Neuron in either NNM or NxDT. |
fp8_margin | Not supported | fp8 training is not supported on Neuron in either NNM or NxDT. |
use_emha | Not supported | fp8 training is not supported on Neuron in either NNM or NxDT. |
convert_to_hf | Supported via separate script | |
nsys_profile | Not supported | This is built specifically for NVIDIA GPUs. |
optim | optim | |
enable_recovery_time_instrumentation | | |
async_checkpointing | | |
Note
If you need a parameter that NxDT does not support, please create a feature request describing the specific use case for it.
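Before migrating, it can help to check an existing NNM config for parameters that the table above marks as Not supported, Not required, or Removed. The minimal sketch below does this using dotted paths such as model.megatron_amp_O2.

A few points about the sketch:
- It checks only a hand-copied, abridged subset of the table rows. Extend DROPPED_IN_NXDT from the full table before relying on it.
- The names DROPPED_IN_NXDT and find_dropped are hypothetical and not part of NxDT.

```python
# Dotted paths of NNM parameters that the Config Mapping table lists as
# "Not supported", "Not required", or "Removed" (abridged, hypothetical subset).
DROPPED_IN_NXDT = {
    "restore_from_path",
    "trainer.accelerator",
    "trainer.accumulate_grad_batches",
    "trainer.replace_sampler_ddp",
    "trainer.benchmark",
    "trainer.enable_model_summary",
    "exp_manager.create_wandb_logger",
    "exp_manager.wandb_logger_kwargs",
    "model.pre_process",
    "model.post_process",
    "model.megatron_amp_O2",
    "model.grad_div_ar_fusion",
    "model.gradient_accumulation_fusion",
    "model.bias_activation_fusion",
    "model.bias_dropout_add_fusion",
    "model.activations_checkpoint_method",
    "model.activations_checkpoint_num_layers",
    "model.transformer_engine",
    "model.fp8",
    "model.onnx_safe",
}


def find_dropped(cfg: dict, prefix: str = "") -> list[str]:
    """Return dotted paths in cfg that the table marks as dropped, in config order."""
    found = []
    for key, value in cfg.items():
        path = f"{prefix}{key}"
        if path in DROPPED_IN_NXDT:
            found.append(path)
        elif isinstance(value, dict):
            # Recurse into nested sections such as trainer or model.
            found.extend(find_dropped(value, path + "."))
    return found


nnm_cfg = {
    "name": "llama",
    "trainer": {"devices": 32, "accumulate_grad_batches": 1},
    "model": {"megatron_amp_O2": False, "encoder_seq_length": 4096},
}
print(find_dropped(nnm_cfg))
# ['trainer.accumulate_grad_batches', 'model.megatron_amp_O2']
```

A flat set of dotted paths keeps the check easy to extend: adding a row from the table is a single line.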