
The NeuronX Distributed Training framework is built on top of NeuronxDistributed (NxD), the NeMo libraries, and PyTorch Lightning. The guide below provides step-by-step instructions for setting up the environment to run training with the NeuronX Distributed Training framework.

Set Up a Python Virtual Environment

First, set up a virtual environment for development using the commands below:

python3 -m venv env
source env/bin/activate
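Before installing anything, you can confirm that the new environment is actually active. This is a minimal sketch, not part of the official setup; it relies only on the standard-library behavior that, inside a venv, `sys.prefix` points at the venv while `sys.base_prefix` still points at the base Python installation.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when the running interpreter belongs to a venv.

    Inside a venv created with `python3 -m venv`, sys.prefix points at the
    venv directory, while sys.base_prefix points at the base installation.
    """
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    if in_virtualenv():
        print(f"Active virtual environment: {sys.prefix}")
    else:
        print("No virtual environment is active; run `source env/bin/activate` first.")
```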

Installing Neuron Dependencies

Install the Neuron packages using the following commands:

pip install -U pip
pip install --upgrade neuronx-cc==2.* torch-neuronx torchvision neuronx_distributed --extra-index-url
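After the install, you may want to confirm which versions pip actually resolved. The sketch below is illustrative rather than an official Neuron tool: it reads installed distribution versions with the standard-library `importlib.metadata`, uses the distribution names from the command above, and checks the `neuronx-cc` major version against the `neuronx-cc==2.*` pin. The helper names `installed_versions` and `matches_major` are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as passed to pip in the command above.
PACKAGES = ["neuronx-cc", "torch-neuronx", "torchvision", "neuronx_distributed"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


def matches_major(ver, major):
    """True if a version string such as '2.14.227.0' has the given major version."""
    return ver is not None and ver.split(".")[0] == str(major)


if __name__ == "__main__":
    versions = installed_versions(PACKAGES)
    for name, ver in versions.items():
        print(f"{name}: {ver or 'NOT INSTALLED'}")
    if not matches_major(versions["neuronx-cc"], 2):
        print("Warning: expected neuronx-cc 2.x (the install command pins neuronx-cc==2.*).")
```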

Building Apex

NxD Training uses the NeMo toolkit, which requires you to install additional dependencies. One of these dependencies is the Apex library. The NeMo toolkit uses this library for several fused module implementations.


NeMo originally relied on Apex for all of its distributed training APIs. Because this framework uses NxD for that purpose, it uses Apex only minimally. Apex remains a dependency because some imports inside NeMo fail without it. We therefore build a slim, CPU-only version of Apex using the steps below:

  1. Clone the Apex repository:

git clone
cd apex
git checkout 23.05
  2. Replace the contents of the setup script with the following:

import sys
import warnings
import os
from packaging.version import parse, Version

from setuptools import setup, find_packages
import subprocess

import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, load

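# Slim, CPU-only build: no C++/CUDA extensions are compiled, so the
# result is a pure-Python wheel (apex-0.1-py3-none-any.whl).
setup(
    name="apex",
    version="0.1",
    packages=find_packages(
        exclude=("build", "csrc", "include", "tests", "dist", "docs", "tests", "examples", "apex.egg-info",)
    ),
    description="PyTorch Extensions written by NVIDIA",
)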
  3. Install the Python build dependencies:

pip install packaging wheel
  4. Build the wheel using the command:

python bdist_wheel
  5. After the build completes, the wheel is in dist/. You will use it for installation in the next section.

  6. Return to the parent directory using ``cd ..``.
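Before moving on, you can check that the build produced a pure-Python wheel, which is what a slim, CPU-only build should yield (the `apex-0.1-py3-none-any.whl` name used in the next section carries the `none` ABI and `any` platform tags). This is a hypothetical helper, not part of Apex: it parses wheel file names using the standard `{name}-{version}[-{build}]-{python}-{abi}-{platform}.whl` layout and is meant to be run from the directory that contains the apex checkout, after leaving the apex directory.

```python
from pathlib import Path


def wheel_tags(filename):
    """Return the (python, abi, platform) tags from a wheel file name.

    Wheel names follow {name}-{version}[-{build}]-{python}-{abi}-{platform}.whl,
    so the last three dash-separated fields are always the tags.
    """
    stem = filename[: -len(".whl")]
    python_tag, abi_tag, platform_tag = stem.split("-")[-3:]
    return python_tag, abi_tag, platform_tag


def find_pure_python_wheels(dist_dir):
    """List wheels in dist_dir that contain no compiled code (abi 'none', platform 'any')."""
    pure = []
    for path in sorted(Path(dist_dir).glob("*.whl")):
        _, abi_tag, platform_tag = wheel_tags(path.name)
        if abi_tag == "none" and platform_tag == "any":
            pure.append(path)
    return pure


if __name__ == "__main__":
    wheels = find_pure_python_wheels(Path("apex") / "dist")
    if wheels:
        print("Built wheel(s):", ", ".join(str(w) for w in wheels))
    else:
        print("No pure-Python wheel found in apex/dist; check the build output.")
```

A compiled (non-slim) build would instead carry interpreter-specific tags such as `cp310-cp310-linux_x86_64`, which this check filters out.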

Installing the requirements

Download the requirements.txt using the command:


We can now install the dependencies of the library using the following command:

pip install -r requirements.txt ~/apex/dist/apex-0.1-py3-none-any.whl

Installing NeuronX Distributed Training framework

To install the library, run the following command:

pip install neuronx_distributed_training --extra-index-url
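A quick smoke test after installation is to confirm that the interpreter can locate the main packages. The sketch below uses the standard-library `importlib.util.find_spec`, which finds a top-level module without executing it. The import names in `MODULES` are assumptions derived from the package names in this guide (for example, that `neuronx_distributed_training` is also the import name); adjust them if your installation differs.

```python
import importlib.util

# Assumed top-level import names for the packages installed in this guide.
MODULES = [
    "torch",
    "torch_xla",
    "neuronx_distributed",
    "nemo",
    "apex",
    "neuronx_distributed_training",
]


def missing_modules(names):
    """Return the top-level module names that the interpreter cannot locate."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All expected modules were found.")
```

Because `find_spec` does not execute the module, it will not surface load-time failures such as the undefined-symbol error described under the common failures; an actual `import` is needed to catch those.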

Common failures during installation

This section describes common failures you may encounter during setup and how to resolve them.

  1. ``ModuleNotFoundError: No module named 'Cython'``

    Install Cython explicitly using ``pip install Cython``.

  2. Error while building ``youtokentome``

    If you get an error saying that the file Python.h was not found, install the Python development headers and then recreate the virtual environment. On Debian/Ubuntu, you can install them with ``sudo apt-get install python3-dev``.

  3. Mismatched torch and torch-xla versions

    When you see an error that looks like:

 ImportError: env/lib/python3.10/site-packages/ undefined symbol: _ZN3c109TupleTypeC1ESt6vectorINS_4Type24SingletonOrSharedTypePtrIS2_EESaIS4_EENS_8optionalINS_13QualifiedNameEEESt10shared_ptrINS_14FunctionSchemaEE

It indicates that the installed torch and torch-xla versions do not match; torch-xla must be built for the same torch release.


If you reinstall torch, make sure to also install the matching torchvision version; otherwise the two will conflict.

  4. Torchvision version error

    The error below indicates an incorrect torchvision version. If you install torch 2.1, install torchvision 0.16. The torchvision compatibility table lists which torchvision version matches each torch version.

ValueError: Could not find the operator torchvision::nms. Please make sure you have already registered the operator
and (if registered from C++) loaded it via torch.ops.load_library.
  5. Matplotlib lock error

    If you see the following error:

 TimeoutError: Lock error: Matplotlib failed to acquire the following lock file

This error means that processes on the compute/worker nodes are contending for access to the matplotlib cache, which
causes the lock timeout. To resolve it, run ``python -c 'import matplotlib.pyplot as plt'`` as part of your setup.
This creates the matplotlib cache up front and avoids the race condition.
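The torch, torch-xla, and torchvision mismatches described above can be caught before they show up as import errors. The sketch below compares version strings under two stated assumptions: torch-xla must share torch's major.minor version, and for torch 2.x the matching torchvision is 0.(minor + 15), a heuristic that reproduces the torch 2.1 / torchvision 0.16 pairing above. Treat the official torchvision compatibility table as authoritative; the helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def major_minor(ver):
    """Parse '2.1.2+cu121' or '2.1.0' into (2, 1), ignoring any local suffix."""
    public = ver.split("+", 1)[0]
    major, minor = public.split(".")[:2]
    return int(major), int(minor)


def torch_xla_matches(torch_ver, torch_xla_ver):
    """Assumption: torch and torch-xla must share the same major.minor version."""
    return major_minor(torch_ver) == major_minor(torch_xla_ver)


def expected_torchvision(torch_ver):
    """Heuristic torchvision (major, minor) for a torch 2.x release, e.g. 2.1 -> 0.16."""
    major, minor = major_minor(torch_ver)
    if major != 2:
        raise ValueError(f"heuristic only covers torch 2.x, got {torch_ver}")
    return 0, minor + 15


def torchvision_matches(torch_ver, torchvision_ver):
    """True if torchvision's major.minor equals the heuristic expectation."""
    return major_minor(torchvision_ver) == expected_torchvision(torch_ver)


if __name__ == "__main__":
    try:
        torch_ver = version("torch")
        xla_ver = version("torch-xla")
        vision_ver = version("torchvision")
    except PackageNotFoundError as exc:
        print(f"Package not installed: {exc}")
    else:
        print("torch / torch-xla match:", torch_xla_matches(torch_ver, xla_ver))
        print("torch / torchvision match:", torchvision_matches(torch_ver, vision_ver))
```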