This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3
Install JAX via Deep Learning Container#
Install JAX with Neuron support using pre-configured AWS Deep Learning Containers (DLCs).
⏱️ Estimated time: ~10 minutes
Prerequisites#
| Requirement | Details |
|---|---|
| Instance Type | Inf2, Trn1, Trn2, or Trn3 |
| Neuron Driver on Host | aws-neuronx-dkms installed on the host (see Step 1) |
| Docker Installed | Docker engine running on the host instance |
| AWS Account | With EC2 permissions |
Available container images#
| Image | ECR URI |
|---|---|
| JAX Training | public.ecr.aws/neuron/jax-training-neuronx |
Note
JAX DLCs are currently available for training workloads. For the full list of available images and tags, see JAX Training Containers.
For more information, see Neuron Deep Learning Containers.
Installation steps#
Ubuntu (apt-based hosts)
Step 1: Install Neuron driver on host
Configure the Neuron repository and install the driver:
. /etc/os-release
sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <<EOF
deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update
sudo apt-get install -y linux-headers-$(uname -r)
sudo apt-get install -y aws-neuronx-dkms
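Before moving on to Docker, a quick host-side check can confirm that the driver installed above is active. The sketch below is a hypothetical helper; check_neuron_host is not part of the Neuron SDK. It assumes two things: the driver's kernel module is listed as neuron in /proc/modules (the same source lsmod reads), and devices appear as /dev/neuron0, /dev/neuron1, and so on.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the Neuron SDK): sanity-check the host driver.
# Assumes the kernel module is named "neuron" and devices appear as /dev/neuronN.
# The optional first argument overrides /dev so the check can be exercised in tests.
check_neuron_host() {
  local dev_dir="${1:-/dev}"
  local status=0
  local devices=()
  local dev

  # /proc/modules lists one loaded module per line, with the module name first
  if grep -q '^neuron ' /proc/modules 2>/dev/null; then
    echo "OK: neuron kernel module is loaded"
  else
    echo "MISSING: neuron kernel module is not loaded" >&2
    status=1
  fi

  # Collect device nodes such as /dev/neuron0 ... /dev/neuron15.
  # An unmatched glob stays literal, so only existing paths are kept.
  for dev in "$dev_dir"/neuron[0-9]*; do
    if [ -e "$dev" ]; then
      devices+=("$dev")
    fi
  done

  if [ "${#devices[@]}" -gt 0 ]; then
    echo "OK: found ${#devices[@]} Neuron device(s) under $dev_dir"
  else
    echo "MISSING: no Neuron devices found under $dev_dir" >&2
    status=1
  fi

  return "$status"
}
```

Source the function and run check_neuron_host on the host. A non-zero exit status means Step 1 needs another look before you continue.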
Step 2: Install and verify Docker
Install Docker and add your user to the docker group:
sudo apt-get install -y docker.io
sudo usermod -aG docker $USER
Log out and log back in to refresh group membership, then verify:
docker run hello-world
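If hello-world fails, you can narrow down the cause with two checks. One confirms that the current login session has picked up the docker group. The other confirms that the Docker daemon is reachable. The functions below are hypothetical helpers written for this guide, not Docker tooling. The group check relies on id -nG, which reports the groups of the running session rather than the contents of /etc/group.

```bash
#!/usr/bin/env bash
# Hypothetical helper: is the current session a member of the given group?
# `id -nG` with no user argument lists this process's groups, so a group added
# with `usermod -aG` only shows up after you log out and back in.
user_in_group() {
  local group="${1:-}"
  id -nG | tr ' ' '\n' | grep -qxF -- "$group"
}

# Hypothetical helper: report common reasons `docker run` fails for non-root users.
check_docker_access() {
  if ! command -v docker >/dev/null 2>&1; then
    echo "MISSING: docker CLI is not installed" >&2
    return 1
  fi
  if ! user_in_group docker; then
    echo "WARN: this session is not in the docker group; log out and back in, or use sudo" >&2
  fi
  # `docker info` needs a running daemon and permission to use its socket
  if ! docker info >/dev/null 2>&1; then
    echo "FAIL: cannot reach the Docker daemon (not running, or permission denied)" >&2
    return 1
  fi
  echo "OK: Docker daemon is reachable"
}
```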
Step 3: Pull the DLC image from ECR
Pull the JAX Training DLC image:
docker pull public.ecr.aws/neuron/jax-training-neuronx:<image_tag>
Replace <image_tag> with the desired tag from the JAX Training Containers repository.
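You cannot paste the literal <image_tag> placeholder into a shell, because bash parses < as input redirection. A small wrapper that takes the tag as an argument avoids this. The sketch below is a hypothetical helper: pull_jax_dlc is not an AWS tool. It uses the repository URI from the command above and confirms afterwards that the image is present locally.

```bash
#!/usr/bin/env bash
# Repository URI from Step 3; the tag is supplied by the caller.
JAX_DLC_REPO="public.ecr.aws/neuron/jax-training-neuronx"

# Hypothetical helper: pull the JAX Training DLC for a given tag.
pull_jax_dlc() {
  local tag="${1:-}"
  if [ -z "$tag" ]; then
    echo "usage: pull_jax_dlc <image_tag>" >&2
    return 2
  fi
  docker pull "${JAX_DLC_REPO}:${tag}" || return 1
  # Confirm the image is now available locally
  docker image inspect "${JAX_DLC_REPO}:${tag}" >/dev/null
}
```

For example, set IMAGE_TAG to a tag listed in the JAX Training Containers repository and run pull_jax_dlc "$IMAGE_TAG".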
Step 4: Run the container
Launch the container with access to Neuron devices:
docker run -it \
--device=/dev/neuron0 \
--cap-add SYS_ADMIN \
--cap-add IPC_LOCK \
public.ecr.aws/neuron/jax-training-neuronx:<image_tag> \
bash
Note
Adjust the --device flags based on your instance type. Use ls /dev/neuron* on the host to list available devices. For example, a trn1.32xlarge has 16 devices (/dev/neuron0 through /dev/neuron15).
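Instead of typing one --device flag per device, you can generate the flags from whatever /dev/neuron* nodes the host exposes. The sketch below defines two hypothetical helpers, neuron_device_flags and run_jax_dlc, which are not part of Docker or the Neuron SDK. It assumes bash 4 or later for mapfile. The docker run options are the same ones shown in Step 4.

```bash
#!/usr/bin/env bash
# Hypothetical helper: print one --device flag per Neuron device node, so the
# docker run command matches however many devices this instance has.
# The optional argument overrides /dev for testing.
neuron_device_flags() {
  local dev_dir="${1:-/dev}"
  local dev
  for dev in "$dev_dir"/neuron[0-9]*; do
    # An unmatched glob stays literal, so skip anything that does not exist
    if [ -e "$dev" ]; then
      printf '%s\n' "--device=$dev"
    fi
  done
}

# Hypothetical wrapper: start the JAX Training DLC with every Neuron device attached.
run_jax_dlc() {
  local tag="${1:-}"
  local flags=()
  if [ -z "$tag" ]; then
    echo "usage: run_jax_dlc <image_tag>" >&2
    return 2
  fi
  mapfile -t flags < <(neuron_device_flags)
  if [ "${#flags[@]}" -eq 0 ]; then
    echo "no /dev/neuron* devices found on this host" >&2
    return 1
  fi
  docker run -it "${flags[@]}" \
    --cap-add SYS_ADMIN \
    --cap-add IPC_LOCK \
    "public.ecr.aws/neuron/jax-training-neuronx:${tag}" \
    bash
}
```

On a trn1.32xlarge, neuron_device_flags would emit 16 flags, one for each of /dev/neuron0 through /dev/neuron15.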
Step 5: Verify inside the container
Run the following commands inside the container to confirm Neuron devices are visible and JAX is installed:
neuron-ls
python3 -c "import jax; print(f'JAX version: {jax.__version__}'); print(f'Devices: {jax.devices()}')"
Example output (device count, memory, and version numbers vary by instance type and image tag):
+--------+--------+--------+-----------+
| DEVICE | CORES | MEMORY | CONNECTED |
+--------+--------+--------+-----------+
| 0 | 2 | 32 GB | Yes |
+--------+--------+--------+-----------+
JAX version: 0.7.0
Devices: [NeuronDevice(id=0), NeuronDevice(id=1)]
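The one-liner above only confirms that JAX imports and lists devices. A slightly stronger smoke test also compiles and runs a small jitted function on the default backend. The sketch below wraps that in a hypothetical bash function, jax_smoke_test, meant to be run inside the container. It uses only standard JAX APIs (jax.jit, jax.numpy, jax.default_backend, jax.devices). It prints the backend name rather than assuming what it will be.

```bash
#!/usr/bin/env bash
# Hypothetical helper: compile and run a tiny jitted function, which exercises
# the backend compiler and runtime, not just the Python import.
jax_smoke_test() {
  python3 - <<'EOF'
import jax
import jax.numpy as jnp

@jax.jit
def scale_and_shift(x):
    # Elementwise op, compiled by whichever backend JAX selects by default
    return 2.0 * x + 1.0

x = jnp.arange(8, dtype=jnp.float32)
y = scale_and_shift(x)
expected = jnp.arange(8, dtype=jnp.float32) * 2.0 + 1.0
assert bool(jnp.allclose(y, expected)), "jitted result does not match expected values"
print("Backend:", jax.default_backend())
print("Devices:", jax.devices())
print("Result OK:", y)
EOF
}
```

If jax_smoke_test exits non-zero while the import check succeeds, the problem is likely in compilation or execution on the device rather than in the Python environment.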
⚠️ Troubleshooting: Device not found in container
If neuron-ls shows no devices inside the container:
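Verify the Neuron driver is installed on the host. Run the following on the host, not inside the container:
neuron-ls
If neuron-ls is not available on the host (it ships in the aws-neuronx-tools package, which Step 1 does not install), confirm the driver module is loaded instead:
lsmod | grep neuron
Confirm you passed the correct --device flag by listing the devices on the host:
ls /dev/neuron*
Restart the container with the correct device path:
docker run -it --device=/dev/neuron0 \
  --cap-add SYS_ADMIN --cap-add IPC_LOCK \
  public.ecr.aws/neuron/jax-training-neuronx:<image_tag> bash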
⚠️ Troubleshooting: Permission denied
If you see permission denied errors when running Docker commands:
Verify your user is in the docker group:
groups  # Should include "docker"
If not, add yourself, then log out and log back in:
sudo usermod -aG docker $USER
Alternatively, run Docker with sudo:
sudo docker run -it --device=/dev/neuron0 \
  --cap-add SYS_ADMIN --cap-add IPC_LOCK \
  public.ecr.aws/neuron/jax-training-neuronx:<image_tag> bash
⚠️ Troubleshooting: Image pull failure
If docker pull fails with a network or authentication error:
Verify internet connectivity:
curl -s https://public.ecr.aws/v2/ | head -1
Check that the image tag exists by browsing the ECR Public Gallery.
If you are behind a proxy, configure Docker proxy settings:
sudo mkdir -p /etc/systemd/system/docker.service.d
sudo tee /etc/systemd/system/docker.service.d/proxy.conf > /dev/null <<EOF
[Service]
Environment="HTTP_PROXY=http://proxy:port"
Environment="HTTPS_PROXY=http://proxy:port"
EOF
sudo systemctl daemon-reload
sudo systemctl restart docker
Replace http://proxy:port with your proxy's address and port.
Amazon Linux 2023 (dnf-based hosts)
Step 1: Install Neuron driver on host
Configure the Neuron repository and install the driver:
sudo tee /etc/yum.repos.d/neuron.repo > /dev/null <<EOF
[neuron]
name=Neuron YUM Repository
baseurl=https://yum.repos.neuron.amazonaws.com
enabled=1
metadata_expire=0
EOF
sudo rpm --import https://yum.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB
sudo dnf update -y
sudo dnf install -y "kernel-devel-uname-r == $(uname -r)"
sudo dnf install -y aws-neuronx-dkms
Step 2: Install and verify Docker
Install Docker, start the Docker service, and add your user to the docker group:
sudo dnf install -y docker
sudo systemctl enable --now docker
sudo usermod -aG docker $USER
Log out and log back in to refresh group membership, then verify:
docker run hello-world
Step 3: Pull the DLC image from ECR
Pull the JAX Training DLC image:
docker pull public.ecr.aws/neuron/jax-training-neuronx:<image_tag>
Replace <image_tag> with the desired tag from the JAX Training Containers repository.
Step 4: Run the container
Launch the container with access to Neuron devices:
docker run -it \
--device=/dev/neuron0 \
--cap-add SYS_ADMIN \
--cap-add IPC_LOCK \
public.ecr.aws/neuron/jax-training-neuronx:<image_tag> \
bash
Note
Adjust the --device flags based on your instance type. Use ls /dev/neuron* on the host to list available devices. For example, a trn1.32xlarge has 16 devices (/dev/neuron0 through /dev/neuron15).
Step 5: Verify inside the container
Run the following commands inside the container to confirm Neuron devices are visible and JAX is installed:
neuron-ls
python3 -c "import jax; print(f'JAX version: {jax.__version__}'); print(f'Devices: {jax.devices()}')"
⚠️ Troubleshooting: Device not found in container
If neuron-ls shows no devices inside the container:
Verify the Neuron driver is installed on the host.
Confirm you passed the correct --device flag by listing the devices on the host:
ls /dev/neuron*
Restart the container with the correct device path.
⚠️ Troubleshooting: Permission denied
If you see permission denied errors when running Docker commands:
Verify your user is in the docker group:
groups
If not, add yourself, then log out and log back in:
sudo usermod -aG docker $USER
Alternatively, run Docker with sudo.
⚠️ Troubleshooting: Image pull failure
If docker pull fails with a network or authentication error:
Verify internet connectivity:
curl -s https://public.ecr.aws/v2/ | head -1
Check that the image tag exists in the ECR Public Gallery.
If you are behind a proxy, configure Docker proxy settings.
Next steps#
Now that JAX is running in a container:
Find more container images: Browse the full list of available Neuron DLC images at Neuron Deep Learning Containers.
Customize your container: Learn how to extend a DLC with additional packages at Customize Neuron DLC.
Read the JAX documentation: Explore the JAX Support on Neuron for JAX framework documentation and tutorials.
Additional resources#
Neuron Deep Learning Containers - Full DLC image list
Neuron Containers - Container documentation overview
Installation Troubleshooting - Common issues and solutions
AWS Neuron SDK Release Notes - Version compatibility information