This document is relevant for: Inf2, Trn1, Trn1n

Training Llama2 7B with Tensor Parallelism and ZeRO-1 Optimizer (neuronx-distributed)#

In this section, we showcase how to pre-train a Llama2 7B model on four Trn1.32xlarge instances using the Neuron Distributed library. We will use AWS ParallelCluster to orchestrate the training jobs. To train the Llama2 7B model in this example, we will apply the following optimizations using the Neuron Distributed library:

  1. Tensor Parallelism

  2. Sequence Parallelism

  3. Selective activation checkpointing

  4. ZeRO-1

Setting up the environment:#

For this experiment, we will use AWS ParallelCluster with at least four Trn1.32xlarge compute nodes. The Train your model on ParallelCluster guide introduces how to set up and use a ParallelCluster. To set up the packages on the head node of the ParallelCluster, follow the instructions here: Install PyTorch Neuron on Trn1.

We also need to install the neuronx-distributed package inside the virtual environment using the following command:

python -m pip install neuronx_distributed --extra-index-url https://pip.repos.neuron.amazonaws.com

Let’s download the scripts for pretraining:

  1. Creating a directory to hold our experiments

mkdir -p ~/examples/tp_zero1_llama2_7b_hf_pretrain
cd ~/examples/tp_zero1_llama2_7b_hf_pretrain

  2. Downloading the training scripts for our experiments

wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama2/tp_zero1_llama2_7b_hf_pretrain/tp_zero1_llama2_7b_hf_pretrain.py
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama2/tp_zero1_llama2_7b_hf_pretrain/logger.py
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama2/tp_zero1_llama2_7b_hf_pretrain/tp_zero1_llama2_7b_hf_pretrain.sh
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama2/training_utils.py
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama2/modeling_llama_nxd.py
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama2/get_dataset.py
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama2/requirements.txt
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama2/tp_zero1_llama2_7b_hf_pretrain/config.json

  3. Installing the additional requirements and giving the right permissions to our shell script

python3 -m pip install -r requirements.txt
chmod +x tp_zero1_llama2_7b_hf_pretrain.sh

Next, we tokenize our dataset. Note: To tokenize the data, we must request the tokenizer from HuggingFace and Meta by following the instructions at the following link: HuggingFace Llama 2 7B Model. Use of the Llama 2 model is governed by the Meta license. In order to download the model weights and tokenizer, please visit the above website and accept the license before requesting access. After access has been granted, you may use the download scripts provided by Meta to download the model weights and tokenizer to your cluster.

Once you have downloaded the tokenizer and model weights, you can copy the tokenizer.model to the ~/examples/tp_zero1_llama2_7b_hf_pretrain directory.

Next let’s download and pre-process the dataset:

cd ~/examples/tp_zero1_llama2_7b_hf_pretrain
python3 get_dataset.py

Note: If you see an error of the following form when downloading data: huggingface_hub.utils._validators.HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/home/ubuntu/examples/tp_zero1_llama2_7b_hf_pretrain'. Use `repo_type` argument if needed. This could be because of a stale cache. Try deleting the cache using:

sudo rm -rf /home/ubuntu/.cache/

At this point, you are all set to start training.

Running training#

By this step, the ParallelCluster is all set up for running experiments. Before we run training, we first pre-compile the graphs using neuron_parallel_compile. Let’s run the command below:

sbatch --exclusive \
--nodes 4 \
--cpus-per-task 128 \
--wrap="srun neuron_parallel_compile bash $(pwd)/tp_zero1_llama2_7b_hf_pretrain.sh"

This script uses a tensor-parallel size of 8. This will automatically set the ZeRO-1 sharding degree to 16 (4 nodes * 32 workers per node / tensor-parallel size of 8).
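
As a quick sanity check, the data-parallel (and hence ZeRO-1 sharding) degree follows from the cluster shape. A minimal sketch of the arithmetic, assuming 32 Neuron workers per Trn1.32xlarge node (the training script derives these values itself):

nodes = 4                      # Trn1.32xlarge instances
workers_per_node = 32          # NeuronCore workers per node
tensor_parallel_size = 8       # set in tp_zero1_llama2_7b_hf_pretrain.sh

world_size = nodes * workers_per_node                        # 128
zero1_sharding_degree = world_size // tensor_parallel_size   # 16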

Note: You can use any number of nodes here; you would just need to adjust the number of nodes in the above slurm command accordingly. Also, the number of nodes used in the parallel_compile command should be the same as in the actual training run. This is because, as the number of nodes changes, the data-parallel degree changes too, so more (or fewer) workers participate in operations like gradient all-reduce, which results in new graphs getting created.

Once the graphs are compiled, we can run training and observe the loss go down. To run training, we run the above command without neuron_parallel_compile.

sbatch --exclusive \
--nodes 4 \
--cpus-per-task 128 \
--wrap="srun bash $(pwd)/tp_zero1_llama2_7b_hf_pretrain.sh"

Performance:#

To achieve better performance, the script applies a few techniques:

Sequence Parallelism and Selective Activation Checkpointing

As explained in the Activation Memory Recomputation doc, both Sequence Parallelism and Selective Activation Checkpointing help reduce activation memory, thereby allowing us to fit bigger models on fewer devices. Please refer to the Activation Memory Reduction Developer Guide for how to enable sequence parallelism and selective activation checkpointing.
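
For intuition, selective activation checkpointing means recomputing only the most memory-hungry sub-modules (typically core attention) during the backward pass, rather than checkpointing entire layers. Below is a minimal, library-agnostic sketch using torch.utils.checkpoint; it is illustrative only and not the implementation used in modeling_llama_nxd.py:

import torch
from torch.utils.checkpoint import checkpoint

class SelectivelyCheckpointedBlock(torch.nn.Module):
    """Hypothetical transformer block: only attention is recomputed in backward."""
    def __init__(self, attn: torch.nn.Module, mlp: torch.nn.Module):
        super().__init__()
        self.attn = attn
        self.mlp = mlp

    def forward(self, x):
        # Attention activations are recomputed during backward instead of stored.
        x = x + checkpoint(self.attn, x, use_reentrant=False)
        # MLP activations are stored as usual (not checkpointed).
        x = x + self.mlp(x)
        return x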

Coalescing Q, K, V layers:

We coalesced parallel matrix multiplies to improve throughput:

  • We coalesced query, key and value into one matrix multiply

  • We coalesced gate_proj and up_proj into one matrix multiply

Please check modeling_llama_nxd.py and tp_zero1_llama2_7b_hf_pretrain.py for details. Note: Because we coalesced the layers above, the pretrained checkpoint provided here cannot be loaded out of the box for fine-tuning and requires preprocessing: the Q, K, V layers and the gate_proj and up_proj layers need to be coalesced in the checkpoint before loading.
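
For intuition, coalescing fuses several projections into a single, wider matrix multiply whose output is split afterwards. A minimal sketch with plain torch.nn.Linear (the actual model uses the library’s tensor-parallel linear layers; see modeling_llama_nxd.py):

import torch

hidden = 4096  # Llama2 7B hidden size

# Separate projections: three matrix multiplies per attention block.
q_proj = torch.nn.Linear(hidden, hidden, bias=False)
k_proj = torch.nn.Linear(hidden, hidden, bias=False)
v_proj = torch.nn.Linear(hidden, hidden, bias=False)

# Coalesced projection: one wider matrix multiply, then a split.
qkv_proj = torch.nn.Linear(hidden, 3 * hidden, bias=False)

# To reuse an uncoalesced checkpoint, the separate weights would have to be
# concatenated along the output dimension before loading:
qkv_proj.weight.data.copy_(
    torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

x = torch.randn(2, 128, hidden)              # (batch, seq, hidden)
q, k, v = qkv_proj(x).split(hidden, dim=-1)  # one matmul, three outputs

The same idea applies to coalescing gate_proj and up_proj in the MLP.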

Logging:

Currently, for better performance, we log the loss value every 10 steps. Frequent logging results in frequent syncs between the device and CPU, which are expensive. Hence, it is recommended to log less frequently if possible.
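
A minimal sketch of this pattern (the step counter, interval, and print statement are illustrative; see logger.py and the training script for the actual implementation):

import torch_xla.core.xla_model as xm

LOG_EVERY_N_STEPS = 10  # less frequent logging -> fewer device-to-CPU syncs

def maybe_log(global_step, loss):
    if global_step % LOG_EVERY_N_STEPS != 0:
        return
    # add_step_closure defers materializing the loss until the step finishes,
    # instead of forcing an immediate device-to-CPU transfer here.
    xm.add_step_closure(
        lambda l: print(f"step {global_step} loss {l.item()}"), (loss.detach(),))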

Checkpointing:#

Currently, by default, the checkpoint is saved at the end of training. You can modify that behaviour by saving a checkpoint every N steps inside the training loop:

from neuronx_distributed.parallel_layers import checkpointing
if global_step % every_n_steps_checkpoint == 0:
   state_dict = {
      "model": model.state_dict(),
      "global_step": global_step,
      "epoch": epoch,
      "scheduler": scheduler.state_dict()
   }
   checkpointing.save(state_dict, flags.output_dir)
   optimizer.save_sharded_state_dict(flags.output_dir)

Here we have to save the model state_dict using the checkpointing.save API and the optimizer state_dict using optimizer.save_sharded_state_dict. This is because, currently, the checkpointing.save API only saves on data-parallel rank 0, while with the ZeRO-1 optimizer the optimizer states are distributed across all data-parallel ranks. Hence, we use the ZeRO-1 optimizer’s save API to save the optimizer states.

Time to save a checkpoint:

Checkpoint save time can vary depending on where the checkpoint is saved. If the checkpoint is saved in the home directory, checkpointing can take much longer; saving the checkpoint to an FSx file system instead can reduce the save time by 4x.

By default, the checkpointing.save API allows one tensor-parallel rank at a time to save the checkpoint. This is done to avoid host OOM: if all tensor-parallel ranks saved at the same time, they would all copy weights to the CPU at the same time, which can result in host OOM. Note: Since we use the XLA_DOWNCAST_BF16 flag for BF16 training, even though the weights on device are in BF16, the weights are copied to the CPU in FP32 format. If you want to avoid this typecasting from BF16 to FP32 when copying weights from device to CPU for checkpoint saving, you can pass down_cast_bf16=True to the checkpointing.save API as follows:

from neuronx_distributed.parallel_layers import checkpointing
if global_step % every_n_steps_checkpoint == 0:
   state_dict = {
      "model": model.state_dict(),
      "global_step": global_step,
      "epoch": epoch,
      "scheduler": scheduler.state_dict()
   }
   checkpointing.save(state_dict, flags.output_dir, down_cast_bf16=True)

This should not only reduce host memory pressure when saving weights, but also cut model checkpointing time in half. Note: We are saving the checkpoint in a sharded format, wherein each tensor-parallel rank saves one shard. To deploy these pretrained models, one would have to combine these shards by loading them and concatenating the tensor-parallel layers together. (We are working on a checkpoint conversion script that combines the shards into a single checkpoint.)
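
Until that conversion script is available, the rough idea is to load every tensor-parallel shard and concatenate each partitioned weight along its sharded dimension. The sketch below is illustrative only: it assumes the default (non-save_xser) checkpoint format, Megatron-style sharding conventions (column-parallel weights split along the output dimension, row-parallel along the input dimension), and hypothetical file paths and parameter-name patterns:

import torch

tp_degree = 8

# Hypothetical per-rank checkpoint files written by checkpointing.save.
shards = [torch.load(f"ckpt/tp_rank_{r:02d}/checkpoint.pt", map_location="cpu")["model"]
          for r in range(tp_degree)]

merged = {}
for name, tensor in shards[0].items():
    if "qkv_proj" in name or "gate_up_proj" in name:   # column-parallel: output dim
        merged[name] = torch.cat([s[name] for s in shards], dim=0)
    elif "o_proj" in name or "down_proj" in name:      # row-parallel: input dim
        merged[name] = torch.cat([s[name] for s in shards], dim=1)
    else:                                              # replicated (norms, etc.)
        merged[name] = tensor

torch.save(merged, "combined_checkpoint.pt")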

In addition to the above method, if we want to speed up checkpoint saving for the model even further, we can do so as follows:

from neuronx_distributed.parallel_layers import checkpointing
if global_step % every_n_steps_checkpoint == 0:
   state_dict = {
      "model": model.state_dict(),
      "global_step": global_step,
      "epoch": epoch,
      "scheduler": scheduler.state_dict()
   }
   checkpointing.save(state_dict, flags.output_dir, down_cast_bf16=True, save_xser=True)

The save_xser option uses torch-xla’s xser.save to save the tensors serially: it copies one tensor at a time to disk, which allows all ranks to save the checkpoint at the same time. This speeds up checkpoint saving, especially for large models, since all ranks save simultaneously. Moreover, the risk of host OOM is completely eliminated because only one tensor is copied to the CPU at a time.

Note: If we use save_xser to save the checkpoint, we have to pass load_xser to the checkpointing.load API. Also, if you use save_xser, the checkpoint folder will contain a .pt file for each tensor instead of a single .pt for the entire state_dict. To read such a checkpoint in your checkpoint conversion script, use the xser.load API instead of torch.load; xser.load loads the serialized checkpoint and returns the full state_dict.
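
A minimal sketch of reading an xser-serialized checkpoint outside the training loop (the checkpoint path is a placeholder):

import torch_xla.utils.serialization as xser

# xser.load follows the reference file written by xser.save, reads the
# per-tensor .pt files it points to, and reassembles the full state_dict.
state_dict = xser.load("ckpt/tp_rank_00/checkpoint.pt")  # placeholder path
print(list(state_dict.keys()))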

Finally, to speed up optimizer saving time, you can increase the number of workers saving at the same time. This can be done as follows:

if global_step % every_n_steps_checkpoint == 0:
   ...
   optimizer.save_sharded_state_dict(flags.output_dir, num_workers_per_step=32)

By default, num_workers_per_step is set to 8.

This document is relevant for: Inf2, Trn1, Trn1n