This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3

Install JAX via Deep Learning Container

Install JAX with Neuron support using pre-configured AWS Deep Learning Containers (DLCs).

⏱️ Estimated time: ~10 minutes

Prerequisites

Requirement | Details
Instance Type | Inf2, Trn1, Trn2, or Trn3
Neuron Driver on Host | aws-neuronx-dkms installed on the host instance
Docker Installed | Docker engine running on the host instance
AWS Account | With EC2 permissions

Available container images

Image | ECR URI
JAX Training | public.ecr.aws/neuron/jax-training-neuronx

Note

JAX DLCs are currently available for training workloads. For the full list of available images and tags, see JAX Training Containers.

For more information, see Neuron Deep Learning Containers.

Installation steps

Step 1: Install Neuron driver on host

Configure the Neuron repository and install the driver. These commands use apt, so they apply to Debian-family hosts such as Ubuntu:

. /etc/os-release
sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <<EOF
deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update
sudo apt-get install -y aws-neuronx-dkms
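
The heredoc in this step is unquoted, so the shell expands ${VERSION_CODENAME} (set by sourcing /etc/os-release) before the line is written to neuron.list. A minimal sketch for previewing that line without touching /etc/apt follows; the function name neuron_apt_line is hypothetical and not part of any Neuron tooling.

```shell
# Hypothetical helper: print the apt source line that Step 1 would write.
# $1: path to an os-release style file (defaults to /etc/os-release).
neuron_apt_line() {
  (
    # Source the file in a subshell so its variables do not leak out.
    . "${1:-/etc/os-release}" || exit 1
    printf 'deb https://apt.repos.neuron.amazonaws.com %s main\n' "$VERSION_CODENAME"
  )
}

# Usage on the host:
#   neuron_apt_line
```

If the printed line has an empty codename, the host's os-release file does not define VERSION_CODENAME and the repository entry would need to be written by hand.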

Step 2: Install and verify Docker

Install Docker and add your user to the docker group:

sudo apt-get install -y docker.io
sudo usermod -aG docker $USER

Log out and log back in to refresh group membership, then verify:

docker run hello-world

Step 3: Pull the DLC image from ECR

Pull the JAX Training DLC image:

docker pull public.ecr.aws/neuron/jax-training-neuronx:<image_tag>

Replace <image_tag> with the desired tag from the JAX Training Containers repository.
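
Because the same image reference is used again in Step 4, it can help to build it once and refuse to continue while the placeholder is still in place. The sketch below assumes nothing beyond the repository URI shown above; image_ref and IMAGE_REPO are hypothetical names chosen for illustration.

```shell
IMAGE_REPO="public.ecr.aws/neuron/jax-training-neuronx"

# Hypothetical helper: print "<repo>:<tag>", or fail if the tag is empty
# or still looks like the literal <image_tag> placeholder.
image_ref() {
  tag="${1:-}"
  case "$tag" in
    ''|*'<'*|*'>'*)
      echo "error: set a real image tag first (got '$tag')" >&2
      return 1
      ;;
  esac
  printf '%s:%s\n' "$IMAGE_REPO" "$tag"
}

# Usage, once you have picked a tag from the repository listing:
#   docker pull "$(image_ref "$IMAGE_TAG")"
```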

Step 4: Run the container

Launch the container with access to Neuron devices:

docker run -it \
  --device=/dev/neuron0 \
  --cap-add SYS_ADMIN \
  --cap-add IPC_LOCK \
  public.ecr.aws/neuron/jax-training-neuronx:<image_tag> \
  bash

Note

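The command above exposes only /dev/neuron0. Add one --device flag for each additional Neuron device the container should use, based on your instance type. Use ls /dev/neuron* on the host to list available devices. For example, a trn1.32xlarge has 16 devices (/dev/neuron0 through /dev/neuron15).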
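
On instances with many devices, typing each flag by hand is error-prone. One possible approach, sketched here under the assumption that every /dev/neuron* node should be exposed, is to generate the flags from the device nodes on the host. The helper name neuron_device_flags is hypothetical.

```shell
# Hypothetical helper: print one --device flag per Neuron device node
# found under the given directory (defaults to /dev).
neuron_device_flags() {
  dev_dir="${1:-/dev}"
  for dev in "$dev_dir"/neuron*; do
    # Skip the literal pattern when the glob matches nothing.
    [ -e "$dev" ] || continue
    printf '%s ' "--device=$dev"
  done
}

# Print the assembled command for review instead of running it.
# The command substitution is left unquoted so each flag is a separate word.
echo docker run -it $(neuron_device_flags) \
  --cap-add SYS_ADMIN --cap-add IPC_LOCK \
  "public.ecr.aws/neuron/jax-training-neuronx:<image_tag>" bash
```

Review the printed command, substitute a real image tag, and then run it on the host.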

Step 5: Verify inside the container

Run the following commands inside the container to confirm Neuron devices are visible and JAX is installed:

neuron-ls
python3 -c "import jax; print(f'JAX version: {jax.__version__}'); print(f'Devices: {jax.devices()}')"

Example output (device count, memory size, and JAX version depend on your instance type and image tag):

+--------+--------+--------+-----------+
| DEVICE | CORES  | MEMORY | CONNECTED |
+--------+--------+--------+-----------+
| 0      | 2      | 32 GB  | Yes       |
+--------+--------+--------+-----------+

JAX version: 0.7.0
Devices: [NeuronDevice(id=0), NeuronDevice(id=1)]
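
To compare what the container sees against what you passed with --device, you can count the device rows in the neuron-ls table. This is a rough sketch that assumes the table layout shown in the example output above (one row per device, with the device index in the first cell); count_device_rows is a hypothetical name, and other neuron-ls versions may format the table differently.

```shell
# Hypothetical helper: read a neuron-ls style table on stdin and print the
# number of rows whose first cell is a decimal device index.
count_device_rows() {
  # grep -c exits non-zero when the count is 0, so mask that status.
  grep -cE '^\| *[0-9]+ *\|' || true
}

# Usage inside the container:
#   neuron-ls | count_device_rows
```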

⚠️ Troubleshooting: Device not found in container

If neuron-ls shows no devices inside the container:

  1. Verify the Neuron driver is installed on the host:

    # Run on the host (not inside the container)
    neuron-ls
    
  2. List the Neuron devices on the host and confirm that the --device flag you passed matches one of them:

    ls /dev/neuron*
    
  3. Restart the container with the correct device path:

    docker run -it --device=/dev/neuron0 \
      --cap-add SYS_ADMIN --cap-add IPC_LOCK \
      public.ecr.aws/neuron/jax-training-neuronx:<image_tag> bash
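
If neuron-ls also fails on the host, one further check is whether the driver's kernel module is loaded. The sketch below assumes the module installed by aws-neuronx-dkms is listed by lsmod under the name neuron; treat that name as an assumption and adjust it if your host reports something different. neuron_module_loaded is a hypothetical helper.

```shell
# Hypothetical helper: read lsmod-style text on stdin and succeed only if a
# module named "neuron" (assumed name) appears in the first column.
neuron_module_loaded() {
  awk '$1 == "neuron" { found = 1 } END { exit !found }'
}

# Usage on the host (not inside the container):
#   lsmod | neuron_module_loaded && echo "driver module loaded" \
#     || echo "driver module not loaded; reinstall aws-neuronx-dkms"
```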
    
⚠️ Troubleshooting: Permission denied

If you see permission denied errors when running Docker commands:

  1. Verify your user is in the docker group:

    groups
    # Should include "docker"
    
  2. If not, add yourself and re-login:

    sudo usermod -aG docker $USER
    # Log out and log back in
    
  3. Alternatively, run Docker with sudo:

    sudo docker run -it --device=/dev/neuron0 \
      --cap-add SYS_ADMIN --cap-add IPC_LOCK \
      public.ecr.aws/neuron/jax-training-neuronx:<image_tag> bash
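
Group changes made with usermod only apply to new login sessions, so the current shell can still lack the docker group after the command succeeds. A small sketch that checks the groups of the current session follows; has_group is a hypothetical helper name.

```shell
# Hypothetical helper: succeed if group $1 appears in the space-separated
# list $2 (defaults to the groups of the current session from `id -nG`).
has_group() {
  group_list="${2:-$(id -nG)}"
  case " $group_list " in
    *" $1 "*) return 0 ;;
    *) return 1 ;;
  esac
}

if has_group docker; then
  echo "docker group: present in this session"
else
  echo "docker group: missing; run usermod, then log out and back in"
fi
```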
    
⚠️ Troubleshooting: Image pull failure

If docker pull fails with a network or authentication error:

  1. Verify that the host can reach the ECR Public registry (any HTTP response confirms connectivity):

    curl -s https://public.ecr.aws/v2/ | head -1
    
  2. Check that the image tag exists by browsing the ECR Public Gallery.

  3. If you are behind a proxy, configure Docker proxy settings:

    sudo mkdir -p /etc/systemd/system/docker.service.d
    sudo tee /etc/systemd/system/docker.service.d/proxy.conf > /dev/null <<EOF
    [Service]
    Environment="HTTP_PROXY=http://proxy:port"
    Environment="HTTPS_PROXY=http://proxy:port"
    EOF
    sudo systemctl daemon-reload
    sudo systemctl restart docker
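
The connectivity check can be made less ambiguous by looking at the HTTP status code instead of the response body. In this sketch, any HTTP status (including an authorization error) is treated as proof that the registry is reachable, while curl's 000 code means no response was received. registry_status is a hypothetical helper name.

```shell
# Hypothetical helper: classify an HTTP status code as printed by
# curl -w '%{http_code}'.
registry_status() {
  case "${1:-}" in
    000|'') echo "unreachable" ;;
    [1-5][0-9][0-9]) echo "reachable" ;;
    *) echo "unknown" ;;
  esac
}

# Usage (requires network access):
#   code="$(curl -s -o /dev/null -w '%{http_code}' --max-time 10 https://public.ecr.aws/v2/)"
#   registry_status "$code"
```

An "unreachable" result points at DNS, routing, or proxy configuration; a "reachable" result with a failing pull points at the image tag instead.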
    


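Steps 1 and 2 on dnf-based hosts

The Step 1 and Step 2 commands shown earlier use apt. On a host that uses dnf (for example, Amazon Linux 2023), use the following commands for those two steps instead. Steps 3 through 5 and the troubleshooting guidance are unchanged.

Step 1: Install Neuron driver on host

Configure the Neuron repository and install the driver: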

sudo tee /etc/yum.repos.d/neuron.repo > /dev/null <<EOF
[neuron]
name=Neuron YUM Repository
baseurl=https://yum.repos.neuron.amazonaws.com
enabled=1
metadata_expire=0
EOF
sudo rpm --import https://yum.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB
sudo dnf update -y
sudo dnf install -y "kernel-devel-uname-r == $(uname -r)"
sudo dnf install -y aws-neuronx-dkms

Step 2: Install and verify Docker

Install Docker and add your user to the docker group:

sudo dnf install -y docker
sudo usermod -aG docker $USER

Log out and log back in to refresh group membership, then verify:

docker run hello-world
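
If docker run hello-world reports that it cannot connect to the Docker daemon, the service may not be running yet, since some distributions do not start it at install time. The sketch below turns the output of systemctl is-active docker into a hint; docker_start_hint is a hypothetical helper name.

```shell
# Hypothetical helper: $1 is the output of `systemctl is-active docker`
# (for example "active", "inactive", or "failed").
docker_start_hint() {
  if [ "${1:-}" = "active" ]; then
    echo "docker daemon is running"
  else
    echo "docker daemon is ${1:-unknown}; try: sudo systemctl enable --now docker"
  fi
}

# Usage on the host:
#   docker_start_hint "$(systemctl is-active docker)"
```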


Next steps

Now that JAX is running in a container:

  1. Find more container images: Browse the full list of available Neuron DLC images at Neuron Deep Learning Containers.

  2. Customize your container: Learn how to extend a DLC with additional packages at Customize Neuron DLC.

  3. Read the JAX documentation: See JAX Support on Neuron for framework documentation and tutorials.

Additional resources
