.. _pytorch-neuronx-programming-guide:

Developer Guide for Training with PyTorch NeuronX
===================================================

.. contents:: Table of Contents
   :local:
   :depth: 2

Trainium is designed to speed up model training and reduce training cost. It is available on Trn1 instances. Each Trainium accelerator has two NeuronCores, which are the main neural network compute units.

PyTorch NeuronX enables PyTorch users to train their models on Trainium's NeuronCores with few changes to their training code. It is based on the `PyTorch/XLA software package `__.

This guide helps you get started with single-worker training and distributed training using PyTorch Neuron.

PyTorch NeuronX
----------------

Neuron XLA device
~~~~~~~~~~~~~~~~~

With PyTorch NeuronX the default XLA device is mapped to a NeuronCore. By default, one NeuronCore is configured. To use the Neuron XLA device, specify the device as ``xm.xla_device()`` or ``'xla'``:

.. code:: python

   import torch_xla.core.xla_model as xm
   device = xm.xla_device()

or

.. code:: python

   device = 'xla'

PyTorch models and tensors can be mapped to the device as usual:

.. code:: python

   model.to(device)
   # Tensor.to() returns a new tensor on the device; it does not move the tensor in place
   tensor = tensor.to(device)

To move a tensor back to the CPU:

.. code:: python

   tensor.cpu()

or

.. code:: python

   tensor.to('cpu')

PyTorch NeuronX single-worker training/evaluation quick-start
--------------------------------------------------------------

PyTorch NeuronX uses XLA to enable conversion of PyTorch operations to Trainium instructions. To get started on PyTorch NeuronX, first modify your :ref:`training script ` to use XLA in the same manner as described in `PyTorch/XLA documentation `__ and use the XLA device:

.. code:: python

   import torch_xla.core.xla_model as xm
   device = xm.xla_device()
   # or
   device = 'xla'

The NeuronCore is mapped to an XLA device. On a Trainium instance, the XLA device is automatically mapped to the first available NeuronCore.
By default, the above steps enable the training or evaluation script to run on one NeuronCore.

NOTE: Each process is mapped to one NeuronCore.

Finally, add ``mark_step`` at the end of the training or evaluation step to compile and execute that step:

.. code:: python

   xm.mark_step()

These changes can be placed in control flow in order to keep the script the same between PyTorch Neuron and CPU/GPU. For example, you can use an environment variable to disable XLA, which causes the script to run in PyTorch native mode (using CPU on Trainium instances and GPU on GPU instances):

.. code:: python

   import os

   device = 'cpu'
   if not os.environ.get("DISABLE_XLA", None):
       device = 'xla'

   ...

   # end of training step
   if not os.environ.get("DISABLE_XLA", None):
       xm.mark_step()

For more on why ``mark_step`` is needed, see `Understand the lazy mode in PyTorch Neuron <#understand-the-lazy-mode-in-pytorch-neuron>`__.

For a full runnable example, please see the :ref:`Single-worker MLP training on Trainium tutorial `.

PyTorch NeuronX multi-worker data parallel training using torchrun
------------------------------------------------------------------

Data parallel training allows you to replicate your script across multiple workers, with each worker processing a proportional portion of the dataset, in order to train faster.
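
To make "a proportional portion of the dataset" concrete, the following is a minimal pure-Python sketch of one common way to split sample indices across workers: a strided split, similar in spirit to what ``torch.utils.data.DistributedSampler`` does. The helper name ``shard_indices`` is hypothetical and is not part of PyTorch NeuronX or PyTorch/XLA, and the sketch assumes no shuffling and no padding of uneven shards.

.. code:: python

   def shard_indices(num_samples, world_size, rank):
       """Return the sample indices assigned to one worker (hypothetical helper)."""
       if not 0 <= rank < world_size:
           raise ValueError("rank must be in [0, world_size)")
       # Strided split: worker r takes samples r, r + world_size, r + 2 * world_size, ...
       return list(range(rank, num_samples, world_size))

   # Example: 10 samples split across 4 workers.
   shards = [shard_indices(10, 4, rank) for rank in range(4)]

Every sample lands in exactly one shard, so each worker sees roughly ``num_samples / world_size`` samples per epoch. In a real training script this split is typically delegated to a sampler rather than written by hand.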
To run multiple workers in a data parallel configuration, with each worker using one NeuronCore, first add additional imports for parallel dataloader and multi-processing utilities:

::

   import torch_xla.distributed.parallel_loader as pl

Next, initialize the Neuron distributed context using the XLA backend for ``torch.distributed``:

::

   # Importing this module registers the 'xla' backend with torch.distributed
   import torch_xla.distributed.xla_backend
   torch.distributed.init_process_group('xla')

Next, replace the ``optimizer.step()`` function call with ``xm.optimizer_step(optimizer)``, which adds gradient synchronization across workers before taking the optimizer step:

::

   xm.optimizer_step(optimizer)

If you're using a distributed dataloader, wrap your dataloader in PyTorch/XLA's ``MpDeviceLoader`` class, which provides buffering to hide CPU-to-device data load latency:

::

   parallel_loader = pl.MpDeviceLoader(dataloader, device)

Within the training code, use ``xm.xrt_world_size()`` to get the world size, and ``xm.get_ordinal()`` to get the global rank of the current process.

Then use the `PyTorch torchrun `__ utility to run the script. For example, to run 32-worker data parallel training: ``torchrun --nproc_per_node=32