This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3
Install JAX Manually
Install JAX with Neuron support on existing systems using pip.
⏱️ Estimated time: 15 minutes
Prerequisites
| Requirement | Details |
|---|---|
| Instance Type | Inf2, Trn1, Trn2, or Trn3 |
| Operating System | Ubuntu 24.04, Ubuntu 22.04, or Amazon Linux 2023 |
| Python | Python 3.11 or 3.12 |
| Sudo Access | Required for driver installation |
| Internet Access | Required for downloading packages |
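Before you start, you can confirm that the interpreter you plan to use for the virtual environment matches the Python row of the table. This is a minimal sketch: the supported set is copied from the table above and may change between Neuron releases, and the function name `python_is_supported` is hypothetical.

```python
import sys

# Supported (major, minor) versions, copied from the Prerequisites table.
SUPPORTED_PYTHON = {(3, 11), (3, 12)}


def python_is_supported(version_info=sys.version_info) -> bool:
    """Return True if the (major, minor) part of version_info is supported."""
    return (version_info[0], version_info[1]) in SUPPORTED_PYTHON


if __name__ == "__main__":
    current = f"{sys.version_info[0]}.{sys.version_info[1]}"
    if python_is_supported():
        print(f"Python {current}: OK")
    else:
        print(f"Python {current}: not in the supported list (3.11, 3.12)")
```

Run it with the exact interpreter you will pass to `-m venv` in Step 4, for example `python3.11 check_python.py` (the file name is arbitrary).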
Installation Steps
Ubuntu 24.04
Step 1: Update System Packages
sudo apt-get update
sudo apt-get install -y python3-pip python3-venv
Step 2: Configure Neuron Repository
# Add Neuron repository
. /etc/os-release
sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <<EOF
deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main
EOF
# Add repository key
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
# Update package list
sudo apt-get update
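The `. /etc/os-release` line sources the OS description file so that `${VERSION_CODENAME}` expands to the release codename (`noble` on Ubuntu 24.04, `jammy` on Ubuntu 22.04) inside the repository line. The sketch below reproduces that substitution in Python, which can help when you template this step in provisioning scripts. The helper names are hypothetical, and the parser only handles the simple `KEY=value` lines that `/etc/os-release` normally contains.

```python
import shlex

NEURON_APT_URL = "https://apt.repos.neuron.amazonaws.com"


def parse_os_release(text: str) -> dict:
    """Parse KEY=value lines from /etc/os-release, removing shell-style quotes."""
    fields = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # shlex.split strips single or double quotes similarly to the shell
        parts = shlex.split(value)
        fields[key] = parts[0] if parts else ""
    return fields


def neuron_apt_line(os_release_text: str) -> str:
    """Build the sources.list line that Step 2 writes via ${VERSION_CODENAME}."""
    codename = parse_os_release(os_release_text)["VERSION_CODENAME"]
    return f"deb {NEURON_APT_URL} {codename} main"


if __name__ == "__main__":
    try:
        with open("/etc/os-release") as f:
            print(neuron_apt_line(f.read()))
    except (OSError, KeyError, ValueError):
        print("Could not determine VERSION_CODENAME from /etc/os-release")
```

Comparing the printed line with the contents of `/etc/apt/sources.list.d/neuron.list` is a quick way to check that the heredoc expanded as intended.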
Step 3: Install Neuron Driver and Runtime
sudo apt-get install -y aws-neuronx-dkms
sudo apt-get install -y aws-neuronx-runtime-lib
sudo apt-get install -y aws-neuronx-collectives
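Once `aws-neuronx-dkms` has built and loaded the driver, the Neuron devices are expected to appear as `/dev/neuron0`, `/dev/neuron1`, and so on. The sketch below lists them as a quick sanity check before moving on. The `/dev/neuronN` naming is an assumption based on typical driver behavior, `neuron_device_files` is a hypothetical helper, and `neuron-ls` remains the authoritative check.

```python
import os
import re

# Assumed device naming scheme: neuron followed only by a device index.
_DEVICE_NAME = re.compile(r"neuron(\d+)")


def neuron_device_files(dev_dir: str = "/dev") -> list:
    """Return neuronN paths in dev_dir, sorted numerically by N."""
    if not os.path.isdir(dev_dir):
        return []
    found = []
    for name in os.listdir(dev_dir):
        match = _DEVICE_NAME.fullmatch(name)
        if match:
            found.append((int(match.group(1)), os.path.join(dev_dir, name)))
    # Numeric sort so neuron10 comes after neuron9 rather than after neuron1
    return [path for _, path in sorted(found)]


if __name__ == "__main__":
    devices = neuron_device_files()
    if devices:
        print("Neuron device files:", ", ".join(devices))
    else:
        print("No /dev/neuronN files found; the driver may not be loaded yet")
```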
Step 4: Create Virtual Environment
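# Any Python 3.11 or 3.12 interpreter works (see Prerequisites); python3.11 must already be installed
python3.11 -m venv ~/neuron_venv_jax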
source ~/neuron_venv_jax/bin/activate
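After activation, `python3` and `pip` should resolve inside `~/neuron_venv_jax`, so the packages from Step 5 stay isolated from the system Python. A minimal way to confirm this: a venv interpreter reports a `sys.prefix` that differs from `sys.base_prefix`. The helper name `in_virtualenv` is hypothetical.

```python
import sys


def in_virtualenv() -> bool:
    """True inside a venv: sys.prefix points at the venv, base_prefix at the base install."""
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print(f"Interpreter: {sys.executable}")
    print("Inside a virtual environment" if in_virtualenv() else "NOT inside a virtual environment")
```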
Step 5: Install JAX and Neuron Packages
pip install -U pip
pip install jax-neuronx[stable] --extra-index-url=https://pip.repos.neuron.amazonaws.com
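To record exactly what Step 5 installed (useful when comparing against the release notes), you can query package metadata from inside the virtual environment, similar to `pip show jax-neuronx`. This is a sketch: only `jax-neuronx` is named on this page, `jax` and `jaxlib` are the core JAX distributions, and the full set of packages pulled in by `[stable]` is not listed here. `installed_versions` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names to report; extend this tuple with any other packages you care about.
PACKAGES = ("jax-neuronx", "jax", "jaxlib")


def installed_versions(names=PACKAGES) -> dict:
    """Map each distribution name to its installed version, or None if it is absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```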
Step 6: Verify Installation
python3 << EOF
import jax
import jax_neuronx
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF
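The verification script calls `neuron-ls` directly, which raises `FileNotFoundError` if the tool is not on `PATH`. In typical Neuron installations `neuron-ls` ships in the `aws-neuronx-tools` package (not installed in Step 3) and lives under `/opt/aws/neuron/bin`, which is often not on `PATH`. The sketch below looks in both places and degrades gracefully; the fallback directory is an assumption and `neuron_ls_output` is a hypothetical helper.

```python
import shutil
import subprocess
from typing import Optional

NEURON_BIN_DIR = "/opt/aws/neuron/bin"  # assumed install location of the Neuron tools


def neuron_ls_output(binary: str = "neuron-ls", extra_dir: str = NEURON_BIN_DIR) -> Optional[str]:
    """Run neuron-ls if found on PATH or in extra_dir; return stdout, or None if missing or failing."""
    path = shutil.which(binary) or shutil.which(binary, path=extra_dir)
    if path is None:
        return None
    result = subprocess.run([path], capture_output=True, text=True)
    if result.returncode != 0:
        return None
    return result.stdout


if __name__ == "__main__":
    out = neuron_ls_output()
    print(out if out is not None else "neuron-ls not found or returned an error")
```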
⚠️ Troubleshooting: GPG key error
If apt-get update reports an “EXPKEYSIG” (expired key) error, re-import the Neuron GPG key and update again:
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update -y
⚠️ Troubleshooting: Driver installation failed
If the driver installation fails:
Check that the kernel headers for the running kernel are installed:
sudo apt-get install -y linux-headers-$(uname -r)
Then retry the driver installation:
sudo apt-get install --reinstall aws-neuronx-dkms
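DKMS compiles the driver against the headers for the running kernel, which are conventionally reachable at `/lib/modules/$(uname -r)/build`. If that path is missing (or is a dangling symlink), the headers for the current kernel are not installed, which is a common reason for the failure above. A minimal check, assuming that conventional layout; the helper names are hypothetical.

```python
import os
import platform
from typing import Optional


def kernel_build_dir(release: Optional[str] = None, modules_root: str = "/lib/modules") -> str:
    """Conventional DKMS build path: <modules_root>/<uname -r>/build."""
    return os.path.join(modules_root, release or platform.release(), "build")


def kernel_headers_present(release: Optional[str] = None, modules_root: str = "/lib/modules") -> bool:
    """True if the build directory exists; isdir follows symlinks, so dangling links count as missing."""
    return os.path.isdir(kernel_build_dir(release, modules_root))


if __name__ == "__main__":
    path = kernel_build_dir()
    status = "found" if kernel_headers_present() else "missing; install the headers package for this kernel"
    print(f"{path}: {status}")
```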
Ubuntu 22.04
Step 1: Update System Packages
Important
Ubuntu 22.04 has reached end of support in Neuron. Neuron no longer provides Ubuntu 22.04 DLAMIs or container images, so new deployments should use Ubuntu 24.04. For details, see the announcement “Neuron no longer includes Ubuntu 22.04 DLAMIs and DLCs starting this release.”
sudo apt-get update
sudo apt-get install -y python3-pip python3-venv
Step 2: Configure Neuron Repository
# Add Neuron repository
. /etc/os-release
sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <<EOF
deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main
EOF
# Add repository key
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
# Update package list
sudo apt-get update
Step 3: Install Neuron Driver and Runtime
sudo apt-get install -y aws-neuronx-dkms
sudo apt-get install -y aws-neuronx-runtime-lib
sudo apt-get install -y aws-neuronx-collectives
Step 4: Create Virtual Environment
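# python3 must be Python 3.11 or 3.12 (see Prerequisites); otherwise call that interpreter directly, for example python3.11
python3 -m venv ~/neuron_venv_jax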
source ~/neuron_venv_jax/bin/activate
Step 5: Install JAX and Neuron Packages
pip install -U pip
pip install jax-neuronx[stable] --extra-index-url=https://pip.repos.neuron.amazonaws.com
Step 6: Verify Installation
python3 << EOF
import jax
import jax_neuronx
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF
⚠️ Troubleshooting: GPG key error
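If apt-get update reports an “EXPKEYSIG” (expired key) error, re-import the Neuron GPG key with the wget command from Step 2, then run sudo apt-get update again.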
⚠️ Troubleshooting: Driver installation failed
Ensure the kernel headers for the running kernel are installed (sudo apt-get install -y linux-headers-$(uname -r)) before retrying the driver installation.
Amazon Linux 2023
Step 1: Update System Packages
sudo yum update -y
sudo yum install -y python3-pip python3-devel
Step 2: Configure Neuron Repository
# Add Neuron repository
sudo tee /etc/yum.repos.d/neuron.repo > /dev/null <<EOF
[neuron]
name=Neuron YUM Repository
baseurl=https://yum.repos.neuron.amazonaws.com
enabled=1
metadata_expire=0
EOF
# Import GPG key
sudo rpm --import https://yum.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB
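The heredoc in Step 2 writes an INI-style `.repo` file with a single `[neuron]` section. If you template this step in provisioning code, Python's `configparser` can produce the same keys and read them back for validation. This is a sketch: the helper names are hypothetical, and the values are copied from the heredoc above.

```python
import configparser
import io


def neuron_repo_file() -> str:
    """Render the same [neuron] section that Step 2 writes to /etc/yum.repos.d/neuron.repo."""
    config = configparser.ConfigParser()
    config["neuron"] = {
        "name": "Neuron YUM Repository",
        "baseurl": "https://yum.repos.neuron.amazonaws.com",
        "enabled": "1",
        "metadata_expire": "0",
    }
    buf = io.StringIO()
    # key=value without spaces, matching the heredoc formatting
    config.write(buf, space_around_delimiters=False)
    return buf.getvalue()


def read_repo(text: str) -> dict:
    """Parse .repo file text into {section: {key: value}}."""
    parser = configparser.ConfigParser()
    parser.read_string(text)
    return {section: dict(parser[section]) for section in parser.sections()}


if __name__ == "__main__":
    print(neuron_repo_file(), end="")
```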
Step 3: Install Neuron Driver and Runtime
sudo yum install -y aws-neuronx-dkms
sudo yum install -y aws-neuronx-runtime-lib
sudo yum install -y aws-neuronx-collectives
Step 4: Create Virtual Environment
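# python3 must be Python 3.11 or 3.12 (see Prerequisites); otherwise call that interpreter directly, for example python3.11
python3 -m venv ~/neuron_venv_jax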
source ~/neuron_venv_jax/bin/activate
Step 5: Install JAX and Neuron Packages
pip install -U pip
pip install jax-neuronx[stable] --extra-index-url=https://pip.repos.neuron.amazonaws.com
Step 6: Verify Installation
python3 << EOF
import jax
import jax_neuronx
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF
⚠️ Troubleshooting: Repository access error
If you cannot access the Neuron repository:
Verify network connectivity
Check your proxy settings if you are behind a corporate firewall
Ensure the GPG key was imported correctly
⚠️ Troubleshooting: Driver installation failed
Ensure the kernel-devel package for the running kernel is installed, then retry the driver installation:
sudo yum install -y kernel-devel-$(uname -r)
Next Steps
Now that JAX is installed:
Try a Quick Example:
import jax
import jax.numpy as jnp

# Simple operation on Neuron
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
result = jnp.multiply(x, y)
print(result)
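The quick example runs operations one at a time. Most JAX code instead wraps a computation in `jax.jit`, so the whole function is traced and compiled once per input shape and dtype, and later calls reuse the compiled version. On a Neuron instance with this setup, compilation is expected to target the Neuron backend; on other machines the same code runs on CPU. A minimal sketch; `weighted_sum` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def weighted_sum(x, w):
    # Elementwise multiply followed by a reduction, compiled as a single program
    return jnp.sum(x * w)


x = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([4.0, 5.0, 6.0])

# The first call traces and compiles; later calls with the same shapes reuse the result
result = weighted_sum(x, w)
print(jax.default_backend(), float(result))
```

The first call can take noticeably longer than later ones because it includes compilation, so time the second call if you are measuring performance.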
Read Documentation:
Explore Setup Guide:
Additional Resources
Install JAX via Deep Learning AMI - Use pre-configured DLAMI instead
Install JAX via Deep Learning Container - Use pre-configured Docker containers
Deploy on AWS - Container-based deployment
Installation Troubleshooting - Common issues and solutions
AWS Neuron SDK Release Notes - Version compatibility information