This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3
Install JAX Manually
Install JAX with Neuron support on existing systems using pip.
⏱️ Estimated time: 15 minutes
Prerequisites
| Requirement | Details |
|---|---|
| Instance Type | Inf2, Trn1, Trn2, or Trn3 |
| Operating System | Ubuntu 24.04, Ubuntu 22.04, or Amazon Linux 2023 |
| Python | Python 3.10, 3.11, or 3.12 |
| Sudo Access | Required for driver installation |
| Internet Access | Required for downloading packages |
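Before you start, you can check the Python and OS prerequisites with a script. The following is a minimal sketch that uses only the Python standard library. The helper names `parse_os_release` and `check_prerequisites` are hypothetical. The sketch assumes the `ID`/`VERSION_ID` values these distributions write to `/etc/os-release`: `ubuntu` with `24.04` or `22.04`, and `amzn` with `2023`. It does not check the instance type.

```python
import sys

# Supported combinations, copied from the prerequisites table.
SUPPORTED_PYTHON = {(3, 10), (3, 11), (3, 12)}
SUPPORTED_OS = {("ubuntu", "24.04"), ("ubuntu", "22.04"), ("amzn", "2023")}


def parse_os_release(text):
    """Parse KEY=value lines in the /etc/os-release format into a dict."""
    info = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        info[key] = value.strip().strip('"')
    return info


def check_prerequisites(os_info, py_version):
    """Return a list of problems; an empty list means both checks passed."""
    problems = []
    if tuple(py_version[:2]) not in SUPPORTED_PYTHON:
        problems.append(f"unsupported Python {py_version[0]}.{py_version[1]}")
    os_key = (os_info.get("ID", ""), os_info.get("VERSION_ID", ""))
    if os_key not in SUPPORTED_OS:
        problems.append(f"unsupported OS {os_key[0]} {os_key[1]}")
    return problems


if __name__ == "__main__":
    try:
        with open("/etc/os-release") as f:
            os_info = parse_os_release(f.read())
    except FileNotFoundError:
        os_info = {}  # Not a Linux host, so the OS check will report a problem.
    problems = check_prerequisites(os_info, sys.version_info)
    print("\n".join(problems) or "Prerequisite checks passed")
```

Run this with the interpreter you plan to use for the virtual environment, because the Python check inspects the running interpreter.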
Installation Steps
On Ubuntu 24.04:
Step 1: Update System Packages
sudo apt-get update
sudo apt-get install -y python3-pip python3-venv
Step 2: Configure Neuron Repository
# Add Neuron repository
. /etc/os-release
sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <<EOF
deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main
EOF
# Add repository key
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
# Update package list
sudo apt-get update
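The . /etc/os-release line loads the distribution's VERSION_CODENAME, for example `noble` on Ubuntu 24.04 and `jammy` on Ubuntu 22.04. The heredoc then substitutes that value into the repository line. The sketch below reproduces the substitution in Python for illustration only. The function names are hypothetical and are not part of any Neuron tool.

```python
NEURON_APT_URL = "https://apt.repos.neuron.amazonaws.com"


def read_codename(os_release_text):
    """Return VERSION_CODENAME from /etc/os-release content, as sourcing the file would."""
    for line in os_release_text.splitlines():
        key, sep, value = line.strip().partition("=")
        if sep and key == "VERSION_CODENAME":
            return value.strip('"')
    raise ValueError("VERSION_CODENAME not found")


def neuron_sources_line(codename):
    """Build the line that Step 2 writes to /etc/apt/sources.list.d/neuron.list."""
    return f"deb {NEURON_APT_URL} {codename} main"


if __name__ == "__main__":
    print(neuron_sources_line(read_codename("ID=ubuntu\nVERSION_CODENAME=jammy\n")))
```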
Step 3: Install Neuron Driver and Runtime
sudo apt-get install -y aws-neuronx-dkms
sudo apt-get install -y aws-neuronx-runtime-lib
sudo apt-get install -y aws-neuronx-collectives
# Neuron tools, including neuron-ls (used in Step 6)
sudo apt-get install -y aws-neuronx-tools
Step 4: Create Virtual Environment
# python3 is Python 3.12 on Ubuntu 24.04 and Python 3.10 on Ubuntu 22.04; both are supported
python3 -m venv ~/neuron_venv_jax
source ~/neuron_venv_jax/bin/activate
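The python3 -m venv command is backed by the standard library venv module, so you can also create the same environment from Python, for example in a provisioning script. This is a minimal sketch, and create_env and in_virtualenv are hypothetical helper names. The check relies on sys.prefix differing from sys.base_prefix when the interpreter runs from a virtual environment.

```python
import sys
import venv
from pathlib import Path


def create_env(env_dir, with_pip=True):
    """Create a virtual environment like `python3 -m venv env_dir`; return its interpreter path."""
    # symlinks=True matches the venv command-line default on Linux.
    venv.create(env_dir, with_pip=with_pip, symlinks=True)
    return Path(env_dir) / "bin" / "python"


def in_virtualenv():
    """True when the running interpreter comes from a venv (e.g. after sourcing bin/activate)."""
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("Running inside a virtual environment:", in_virtualenv())
```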
Step 5: Install JAX and Neuron Packages
pip install -U pip
pip install jax-neuronx[stable] --extra-index-url=https://pip.repos.neuron.amazonaws.com
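To confirm which versions pip actually resolved, you can query the installed distribution metadata with the standard library importlib.metadata module. This is a minimal sketch that assumes it runs inside the activated neuron_venv_jax environment. The installed_versions helper is hypothetical. A value of None means the distribution is not installed in that environment.

```python
from importlib import metadata

# Distributions to look up: JAX, its compiled backend jaxlib, and the Neuron package from Step 5.
PACKAGES = ("jax", "jaxlib", "jax-neuronx")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version string, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```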
Step 6: Verify Installation
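python3 << 'EOF'
import os
import shutil
import subprocess
import jax
import jax_neuronx
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
# Check Neuron devices with neuron-ls (installed under /opt/aws/neuron/bin by aws-neuronx-tools)
neuron_ls = shutil.which("neuron-ls") or "/opt/aws/neuron/bin/neuron-ls"
if os.path.exists(neuron_ls):
    result = subprocess.run([neuron_ls], capture_output=True, text=True)
    print(result.stdout)
else:
    print("neuron-ls not found; it is provided by the aws-neuronx-tools package")
EOF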
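Both jax.devices() and neuron-ls depend on the driver installed in Step 3. If they report no Neuron devices, a lower-level check is to see whether the driver created any device nodes. The sketch below assumes the driver exposes these nodes as /dev/neuron0, /dev/neuron1, and so on. The neuron_device_nodes helper is hypothetical.

```python
from pathlib import Path

PREFIX = "neuron"


def neuron_device_nodes(dev_dir="/dev"):
    """Return device nodes named neuron<N> under dev_dir, sorted by numeric index."""
    nodes = [p for p in Path(dev_dir).glob(PREFIX + "*") if p.name[len(PREFIX):].isdigit()]
    return sorted(nodes, key=lambda p: int(p.name[len(PREFIX):]))


if __name__ == "__main__":
    nodes = neuron_device_nodes()
    if nodes:
        print(f"Found {len(nodes)} Neuron device node(s):", ", ".join(p.name for p in nodes))
    else:
        print("No /dev/neuron* nodes found; check that aws-neuronx-dkms installed correctly")
```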
⚠️ Troubleshooting: GPG key error
If apt-get update reports an “EXPKEYSIG” error, re-import the Neuron GPG key and update again:
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update -y
⚠️ Troubleshooting: Driver installation failed
If driver installation fails:
Check that the kernel headers for the running kernel are installed:
sudo apt-get install -y linux-headers-$(uname -r)
Retry driver installation:
sudo apt-get install --reinstall aws-neuronx-dkms
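DKMS builds the aws-neuronx-dkms module against the header tree at /lib/modules/<kernel release>/build. The headers package for the running kernel provides this tree. The following is a minimal sketch of checking for it before you retry. The helper names are hypothetical, and the default path assumes the standard DKMS layout.

```python
import os
from pathlib import Path


def kernel_build_dir(release, modules_root="/lib/modules"):
    """Header tree DKMS builds against for a kernel release: <modules_root>/<release>/build."""
    return Path(modules_root) / release / "build"


def headers_installed(release=None, modules_root="/lib/modules"):
    """True if the header tree exists; release defaults to the running kernel, like `uname -r`."""
    release = release or os.uname().release
    return kernel_build_dir(release, modules_root).is_dir()


if __name__ == "__main__":
    print("Kernel headers present:", headers_installed())
```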
On Ubuntu 22.04:
Step 1: Update System Packages
sudo apt-get update
sudo apt-get install -y python3-pip python3-venv
Step 2: Configure Neuron Repository
# Add Neuron repository
. /etc/os-release
sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <<EOF
deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main
EOF
# Add repository key
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
# Update package list
sudo apt-get update
Step 3: Install Neuron Driver and Runtime
sudo apt-get install -y aws-neuronx-dkms
sudo apt-get install -y aws-neuronx-runtime-lib
sudo apt-get install -y aws-neuronx-collectives
# Neuron tools, including neuron-ls (used in Step 6)
sudo apt-get install -y aws-neuronx-tools
Step 4: Create Virtual Environment
# python3 is Python 3.12 on Ubuntu 24.04 and Python 3.10 on Ubuntu 22.04; both are supported
python3 -m venv ~/neuron_venv_jax
source ~/neuron_venv_jax/bin/activate
Step 5: Install JAX and Neuron Packages
pip install -U pip
pip install jax-neuronx[stable] --extra-index-url=https://pip.repos.neuron.amazonaws.com
Step 6: Verify Installation
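python3 << 'EOF'
import os
import shutil
import subprocess
import jax
import jax_neuronx
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
# Check Neuron devices with neuron-ls (installed under /opt/aws/neuron/bin by aws-neuronx-tools)
neuron_ls = shutil.which("neuron-ls") or "/opt/aws/neuron/bin/neuron-ls"
if os.path.exists(neuron_ls):
    result = subprocess.run([neuron_ls], capture_output=True, text=True)
    print(result.stdout)
else:
    print("neuron-ls not found; it is provided by the aws-neuronx-tools package")
EOF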
⚠️ Troubleshooting: GPG key error
If apt-get update reports an “EXPKEYSIG” error, re-import the Neuron GPG key and retry the update.
⚠️ Troubleshooting: Driver installation failed
Ensure the kernel headers for the running kernel (linux-headers-$(uname -r)) are installed before retrying the driver installation.
On Amazon Linux 2023:
Step 1: Update System Packages
sudo yum update -y
sudo yum install -y python3-pip python3-devel
Step 2: Configure Neuron Repository
# Add Neuron repository
sudo tee /etc/yum.repos.d/neuron.repo > /dev/null <<EOF
[neuron]
name=Neuron YUM Repository
baseurl=https://yum.repos.neuron.amazonaws.com
enabled=1
metadata_expire=0
EOF
# Import GPG key
sudo rpm --import https://yum.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB
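The repository file is plain INI, so you can also generate it with Python's configparser, for example from a configuration-management script. This sketch only reproduces the values shown above, and render_repo_file is a hypothetical helper. Writing the result to /etc/yum.repos.d/neuron.repo still requires root.

```python
import configparser
import io

# Values copied from the repository definition in Step 2.
NEURON_REPO = {
    "name": "Neuron YUM Repository",
    "baseurl": "https://yum.repos.neuron.amazonaws.com",
    "enabled": "1",
    "metadata_expire": "0",
}


def render_repo_file(section="neuron", options=NEURON_REPO):
    """Render the body of a .repo file in the INI format read from /etc/yum.repos.d/."""
    parser = configparser.ConfigParser()
    parser[section] = options
    buf = io.StringIO()
    # Write key=value without spaces, matching the layout used in Step 2.
    parser.write(buf, space_around_delimiters=False)
    return buf.getvalue()


if __name__ == "__main__":
    print(render_repo_file())
```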
Step 3: Install Neuron Driver and Runtime
sudo yum install -y aws-neuronx-dkms
sudo yum install -y aws-neuronx-runtime-lib
sudo yum install -y aws-neuronx-collectives
# Neuron tools, including neuron-ls (used in Step 6)
sudo yum install -y aws-neuronx-tools
Step 4: Create Virtual Environment
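# The default python3 on Amazon Linux 2023 is Python 3.9, which is not supported;
# python3.10 (or another supported 3.11/3.12 interpreter) must already be installed
python3.10 -m venv ~/neuron_venv_jax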
source ~/neuron_venv_jax/bin/activate
Step 5: Install JAX and Neuron Packages
pip install -U pip
pip install jax-neuronx[stable] --extra-index-url=https://pip.repos.neuron.amazonaws.com
Step 6: Verify Installation
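python3 << 'EOF'
import os
import shutil
import subprocess
import jax
import jax_neuronx
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
# Check Neuron devices with neuron-ls (installed under /opt/aws/neuron/bin by aws-neuronx-tools)
neuron_ls = shutil.which("neuron-ls") or "/opt/aws/neuron/bin/neuron-ls"
if os.path.exists(neuron_ls):
    result = subprocess.run([neuron_ls], capture_output=True, text=True)
    print(result.stdout)
else:
    print("neuron-ls not found; it is provided by the aws-neuronx-tools package")
EOF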
⚠️ Troubleshooting: Repository access error
If you cannot access the Neuron repository:
Verify network connectivity
Check your proxy settings if you are behind a corporate firewall
Ensure the GPG key was imported correctly
⚠️ Troubleshooting: Driver installation failed
Ensure the kernel-devel package for the running kernel is installed, then retry the driver installation:
sudo yum install -y kernel-devel-$(uname -r)
Next Steps
Now that JAX is installed:
Try a Quick Example:
import jax
import jax.numpy as jnp
# Simple operation on Neuron
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
result = jnp.multiply(x, y)
print(result)
Read Documentation:
Explore Setup Guide:
Additional Resources
Install JAX via Deep Learning AMI - Use pre-configured DLAMI instead
Install JAX via Deep Learning Container - Use pre-configured Docker containers
Neuron Containers - Container-based deployment
Installation Troubleshooting - Common issues and solutions
AWS Neuron SDK Release Notes - Version compatibility information