This document is relevant for: Trn1, Trn1n

ZeRO-1 Tutorial#

What is ZeRO-1?#

ZeRO-1 (Zero Redundancy Optimizer Stage 1) is an optimization technique for large-scale deep learning models. It is a memory-efficient variation of data parallelism. ZeRO leverages the aggregate computation and memory resources of data parallelism to reduce the memory and compute requirements of each accelerator used for model training. ZeRO reduces the memory consumption of each accelerator by partitioning the model training states (weights, gradients, and optimizer states) across the available devices in the distributed training hardware. ZeRO is implemented as incremental stages of optimizations. In stage 1, the optimizer states (e.g., for the Adam optimizer, the 32-bit weights and the first and second moment estimates) are partitioned across the processes, so that each process updates only its own partition.

Image: zero1.jpg
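As a rough back-of-the-envelope illustration (not part of the tutorial code), consider the Adam optimizer with float32 states: each parameter carries a float32 master weight plus two float32 moment estimates, and stage 1 divides that footprint by the data-parallel degree.

# Illustrative estimate of per-device Adam optimizer-state memory with and without ZeRO-1.
# Assumes bfloat16 parameters with a float32 master copy and two float32 moments per parameter.
def optimizer_state_bytes_per_device(num_params, world_size, zero1=True):
    bytes_per_param = 4 + 4 + 4  # master weight + exp_avg + exp_avg_sq, all float32
    total_bytes = num_params * bytes_per_param
    return total_bytes / world_size if zero1 else total_bytes

num_params = 1_500_000_000  # roughly GPT2-XL scale
print(optimizer_state_bytes_per_device(num_params, 32) / 2**30)               # ~0.5 GiB per device
print(optimizer_state_bytes_per_device(num_params, 32, zero1=False) / 2**30)  # ~16.8 GiB per device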

We implemented an XLA-friendly version of ZeRO-1, and it has been merged into the open-source PyTorch/XLA project. Users can enable the ZeRO-1 algorithm by simply wrapping the original optimizer as shown below.

# Before:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# After:
optimizer = ZeroRedundancyOptimizer(model.parameters(), torch.optim.Adam, lr=0.0001)

Then just call optimizer.step() directly; the wrapped optimizer handles the distributed operations automatically.

The above code snippet illustrates the basic usage. Generally, the ZeRO-1 optimizer can be used like a normal optimizer. In addition, ZeroRedundancyOptimizer provides other features, such as gradient clipping and using a different data type for the wrapped optimizer. Note that although most optimizers can be used with ZeRO-1, optimizers that compute norms over parameters (e.g., LAMB) might show accuracy disparities compared to the original local optimizer, because under ZeRO-1 they see only parameter shards rather than the full parameters.


To enable the ZeRO-1 optimizer, just import it and replace the original optimizer with the ZeRO-1 wrapped version:

import torch_xla.core.xla_model as xm
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
from torch.optim import AdamW

device = xm.xla_device()
model = ...  # build your model here and move it to the XLA device, e.g. model.to(device)

optimizer = ZeroRedundancyOptimizer(model.parameters(), AdamW, lr=0.001)

Then in the training loop, just call optimizer.step(). Note that you should not use xm.reduce_gradients() or xm.optimizer_step(), as gradient reduction is handled by ZeRO-1.
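A minimal training-loop sketch under these assumptions (train_loader and the loss computation are placeholders, not files from this tutorial) looks like:

for batch in train_loader:        # placeholder data loader
    optimizer.zero_grad()
    loss = model(**batch).loss    # any forward pass / loss computation
    loss.backward()
    optimizer.step()              # the ZeRO-1 wrapper reduces and shards gradients internally
    xm.mark_step()                # cut the XLA graph; do not call xm.optimizer_step()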


The ZeRO-1 optimizer also provides some additional features; users can pass the following arguments to the wrapper constructor (a combined example is shown after this list):

  • Change optimizer_dtype to choose the data type used by the optimizer; the default is torch.float32. For example, when the parameter data type is bfloat16, set optimizer_dtype to float32 to enable 'master weights'.

  • Change grad_clipping to enable or disable gradient clipping; the default is True.

  • Change max_norm to set the maximum norm value used by gradient clipping; the default is 1.0.

  • Change use_grad_acc_hook to enable using buffers to store gradients; the buffers use the same data type as optimizer_dtype to accumulate gradients. (Added in the Neuron 2.19.0 release.)

  • Change higher_cc_precision to force the reduce-scatter operator to use the same data type as optimizer_dtype; the default is False. When use_grad_acc_hook is True, this has no effect. (Added in the Neuron 2.19.0 release.)

Note: the ZeRO-1 optimizer now forces the all-gather operator to use the same data type as the parameters. (Changed in the Neuron 2.19.0 release.)
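Putting these options together, a wrapper constructed with all of the knobs above might look like the following sketch (continuing the snippet above; the argument values are illustrative only):

optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    AdamW,
    lr=0.001,
    optimizer_dtype=torch.float32,  # float32 master weights for bfloat16 parameters
    grad_clipping=True,             # gradient clipping is enabled by default
    max_norm=1.0,                   # maximum norm used by gradient clipping
    use_grad_acc_hook=True,         # accumulate gradients in optimizer_dtype buffers
    higher_cc_precision=False,      # ignored when use_grad_acc_hook is True
)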

GPT2-XL Pretraining Tutorial#


We use a single trn1.32xlarge instance. Follow Install PyTorch Neuron on Trn1 to set up the environment first. For all the commands below, make sure you are in the virtual environment that you created above before you run them:

requirements.txt: pins the Hugging Face library versions necessary for the tutorial.

source ~/aws_neuron_venv_pytorch/bin/activate
git clone https://github.com/aws-neuron/aws-neuron-samples.git
cd aws-neuron-samples/torch-neuronx/training/zero1_gpt2
python3 -m pip install -r requirements.txt

The specific files you need for this tutorial:

  • config_1p5B_gpt2.json: the model configuration used in the tutorial (1.5B-parameter GPT-2)

  • the utility functions and logging tools

  • the main training script that runs the actual training

  • the shell script that launches the training job


For the dataset, we use the wikitext dataset, specifically wikitext-103-raw-v1, provided by Hugging Face. The data is preprocessed the first time the training script runs, and the preprocessed data is then cached in the Hugging Face cache directory for any future training runs.
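For reference only (the training script does this for you), loading and caching the dataset with the Hugging Face datasets library looks roughly like:

from datasets import load_dataset

# Downloads on first use; subsequent runs reuse the copy cached under the Hugging Face cache directory
raw_datasets = load_dataset("wikitext", "wikitext-103-raw-v1")
print(raw_datasets["train"])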

If the main process downloads the dataset, tokenizes the data, and groups it together successfully, you should see output like the following at the beginning of training:

***** Running training *****
  Num examples = 114248
  Num Epochs = 29
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 100000


The GPT2 Python fine-tuning script is adapted from a Hugging Face causal language modeling example and incorporates the Hugging Face Accelerate library. Given its beta stage, some modifications are needed, along with bridge code to XLA. In particular, the workarounds needed to support Accelerate in the training script are listed in "Known Issues, Work-arounds and Limitations" below.

In this example, we use GPT2-XL and show the training steps with mixed precision (bfloat16 and float32).

  • single-node training:

# precompile the graphs first (use the launch shell script from the file list above in place of <launch_script.sh>)
neuron_parallel_compile bash <launch_script.sh> MIXED wikitext-103-raw-v1
# then run the actual training
bash <launch_script.sh> MIXED wikitext-103-raw-v1
  • multi-node training: first precompile the graphs:

sbatch run_clm_compile.slurm

then launch the actual training:

sbatch run_clm.slurm

Known Issues, Work-arounds and Limitations#

  1. Error message: ValueError: invalid literal for int() with base 10: ''. Simply re-running the script solves this issue. The issue is already fixed in newer versions of transformers.

  2. Accelerator API workarounds:

    • Error message: “Gradient accumulation is not supported on TPU. Please set gradient_accumulation_steps to 1 and don’t pass in a GradientAccumulationPlugin object.” The training still works if you comment out the assertion and avoid using the accumulation wrapper accelerator.accumulate(model) (a sketch of this workaround appears after this list).

    • Accelerator.prepare call: We have noticed that the optimizer returned by this API is not directly reusable, due to gaps in configuring the Accelerate API for XLA devices.
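A sketch of the gradient-accumulation workaround mentioned above (valid here because gradient_accumulation_steps is 1) replaces the accumulation wrapper with a plain forward/backward/step:

# Instead of wrapping the step in `with accelerator.accumulate(model):`,
# run the forward/backward/step directly (gradient_accumulation_steps is 1 in this tutorial):
outputs = model(**batch)
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()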

This document is relevant for: Trn1, Trn1n