This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3

Install JAX Manually#

Install JAX with Neuron support on existing systems using pip.

⏱️ Estimated time: 15 minutes

Prerequisites#

Requirement        Details
Instance Type      Inf2, Trn1, Trn2, or Trn3
Operating System   Ubuntu 24.04, Ubuntu 22.04, or Amazon Linux 2023
Python             Python 3.10, 3.11, or 3.12
Sudo Access        Required for driver installation
Internet Access    For downloading packages
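Before starting, it can help to confirm that the host matches the operating system and Python requirements listed above. The following is a minimal sketch, not part of the Neuron tooling: the helper names are hypothetical, and it assumes the standard `ID` and `VERSION_ID` fields of /etc/os-release (`ubuntu` for Ubuntu, `amzn` for Amazon Linux).

```python
import sys

# Supported combinations, copied from the prerequisites listed above.
SUPPORTED_PYTHON = {(3, 10), (3, 11), (3, 12)}
SUPPORTED_OS = {("ubuntu", "24.04"), ("ubuntu", "22.04"), ("amzn", "2023")}


def python_is_supported(version_info=sys.version_info):
    """Return True if the interpreter's major.minor version is supported."""
    return (version_info[0], version_info[1]) in SUPPORTED_PYTHON


def parse_os_release(text):
    """Parse the KEY=value format of /etc/os-release into a dict."""
    fields = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        fields[key] = value.strip().strip('"').strip("'")
    return fields


def os_is_supported(fields):
    """Return True if the parsed ID/VERSION_ID pair is in the supported set."""
    return (fields.get("ID"), fields.get("VERSION_ID")) in SUPPORTED_OS


# Typical use on the target host:
#   with open("/etc/os-release") as fh:
#       print("OS supported:", os_is_supported(parse_os_release(fh.read())))
print("Python supported:", python_is_supported())
```

The checks take plain values as arguments so they can be exercised without touching the host; only the commented usage reads /etc/os-release.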

Installation Steps#

Ubuntu 24.04

Step 1: Update System Packages

sudo apt-get update
sudo apt-get install -y python3-pip python3-venv

Step 2: Configure Neuron Repository

# Add Neuron repository
. /etc/os-release
sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <<EOF
deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main
EOF

# Add repository key
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -

# Update package list
sudo apt-get update
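The repository entry written in Step 2 depends on `VERSION_CODENAME` from /etc/os-release (`noble` on Ubuntu 24.04, `jammy` on Ubuntu 22.04). As an illustration of what the shell heredoc expands to, here is a small sketch; the function names are hypothetical and the code only builds the string, it does not write any file.

```python
NEURON_APT_URL = "https://apt.repos.neuron.amazonaws.com"


def read_codename(os_release_text):
    """Return the VERSION_CODENAME value from os-release text, or None."""
    for line in os_release_text.splitlines():
        if line.startswith("VERSION_CODENAME="):
            return line.split("=", 1)[1].strip().strip('"')
    return None


def neuron_source_line(codename):
    """Build the apt source line that Step 2 writes to neuron.list."""
    if not codename:
        raise ValueError("VERSION_CODENAME is missing; cannot build the source line")
    return f"deb {NEURON_APT_URL} {codename} main"


# Example with the fields an Ubuntu 22.04 host would report.
sample = 'ID=ubuntu\nVERSION_ID="22.04"\nVERSION_CODENAME=jammy\n'
print(neuron_source_line(read_codename(sample)))
```

If the printed line differs from the content of /etc/apt/sources.list.d/neuron.list, the heredoc was probably run before sourcing /etc/os-release.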

Step 3: Install Neuron Driver and Runtime

sudo apt-get install -y aws-neuronx-dkms
sudo apt-get install -y aws-neuronx-runtime-lib
sudo apt-get install -y aws-neuronx-collectives
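After the driver package is installed, one quick sanity check is to look for Neuron device nodes. The sketch below assumes the driver exposes devices as `/dev/neuron0`, `/dev/neuron1`, and so on; treat that naming as an assumption and confirm with `neuron-ls` in Step 6. The function name is hypothetical.

```python
import re
from pathlib import Path


def find_neuron_devices(dev_dir="/dev"):
    """Return paths named neuron<N> under dev_dir, sorted by device index."""
    pattern = re.compile(r"^neuron(\d+)$")
    root = Path(dev_dir)
    if not root.is_dir():
        return []
    found = []
    for entry in root.iterdir():
        match = pattern.match(entry.name)
        if match:
            found.append((int(match.group(1)), str(entry)))
    # Sort numerically so neuron10 comes after neuron2.
    return [path for _, path in sorted(found)]


devices = find_neuron_devices()
print(devices if devices else "No /dev/neuron* nodes found")
```

An empty result on a supported instance usually means the kernel module did not build or load, which is what the driver troubleshooting notes address.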

Step 4: Create Virtual Environment

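# Assumes python3.10 is installed; substitute another supported interpreter (3.11 or 3.12) if it is not
python3.10 -m venv ~/neuron_venv_jax
source ~/neuron_venv_jax/bin/activate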
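To confirm that the environment is active before installing packages, you can compare `sys.prefix` with `sys.base_prefix`; they differ inside a venv. This is a minimal sketch with a hypothetical helper name.

```python
import sys


def in_virtualenv(prefix=None, base_prefix=None):
    """Return True when the interpreter is running inside a virtual environment."""
    prefix = sys.prefix if prefix is None else prefix
    base_prefix = sys.base_prefix if base_prefix is None else base_prefix
    return prefix != base_prefix


print(f"Interpreter: {sys.executable}")
print(f"Inside a virtual environment: {in_virtualenv()}")
```

If this reports False after `source ~/neuron_venv_jax/bin/activate`, the `python3` on your PATH is probably not the one from the venv.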

Step 5: Install JAX and Neuron Packages

pip install -U pip
pip install "jax-neuronx[stable]" --extra-index-url=https://pip.repos.neuron.amazonaws.com

Step 6: Verify Installation

python3 << EOF
import jax
import jax_neuronx

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF
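
The verification script above raises `FileNotFoundError` if `neuron-ls` is not on the PATH, and `jax.devices()` may list only CPU devices if the Neuron plugin did not load. The following sketch shows one way to make both cases explicit. The helper names are hypothetical, and filtering on `platform != "cpu"` is an assumption about how to tell accelerator devices apart, not documented Neuron behavior.

```python
import shutil
import subprocess


def run_tool(name, timeout=30):
    """Run a CLI tool; return its stdout, or None if it is missing or fails."""
    path = shutil.which(name)
    if path is None:
        return None
    try:
        result = subprocess.run(
            [path], capture_output=True, text=True, timeout=timeout
        )
    except (OSError, subprocess.TimeoutExpired):
        return None
    return result.stdout if result.returncode == 0 else None


def non_cpu_devices(devices):
    """Keep only devices whose .platform attribute is not 'cpu'."""
    return [d for d in devices if getattr(d, "platform", "cpu") != "cpu"]


# On a configured instance (requires jax to be installed):
#   import jax
#   print(non_cpu_devices(jax.devices()))
output = run_tool("neuron-ls")
print(output if output is not None else "neuron-ls not found or returned an error")
```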

⚠️ Troubleshooting: GPG key error

If you see an “EXPKEYSIG” error during apt-get update, re-import the repository key and refresh the package list:

wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
sudo apt-get update -y

⚠️ Troubleshooting: Driver installation failed

If driver installation fails:

  1. Check that the kernel headers for the running kernel are installed:

    sudo apt-get install -y linux-headers-$(uname -r)

  2. Retry the driver installation:

    sudo apt-get install --reinstall aws-neuronx-dkms


Ubuntu 22.04

Step 1: Update System Packages

sudo apt-get update
sudo apt-get install -y python3-pip python3-venv

Step 2: Configure Neuron Repository

# Add Neuron repository
. /etc/os-release
sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <<EOF
deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main
EOF

# Add repository key
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -

# Update package list
sudo apt-get update

Step 3: Install Neuron Driver and Runtime

sudo apt-get install -y aws-neuronx-dkms
sudo apt-get install -y aws-neuronx-runtime-lib
sudo apt-get install -y aws-neuronx-collectives

Step 4: Create Virtual Environment

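# Assumes python3.10 is installed; substitute another supported interpreter (3.11 or 3.12) if it is not
python3.10 -m venv ~/neuron_venv_jax
source ~/neuron_venv_jax/bin/activate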

Step 5: Install JAX and Neuron Packages

pip install -U pip
pip install "jax-neuronx[stable]" --extra-index-url=https://pip.repos.neuron.amazonaws.com

Step 6: Verify Installation

python3 << EOF
import jax
import jax_neuronx

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF
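
⚠️ Troubleshooting: GPG key error

If you see an “EXPKEYSIG” error during apt-get update, re-import the repository key (repeat the "Add repository key" command from Step 2) and run sudo apt-get update again.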

⚠️ Troubleshooting: Driver installation failed

Ensure the kernel headers for the running kernel are installed (sudo apt-get install -y linux-headers-$(uname -r)) before retrying the driver installation.

Amazon Linux 2023

Step 1: Update System Packages

sudo yum update -y
sudo yum install -y python3-pip python3-devel

Step 2: Configure Neuron Repository

# Add Neuron repository
sudo tee /etc/yum.repos.d/neuron.repo > /dev/null <<EOF
[neuron]
name=Neuron YUM Repository
baseurl=https://yum.repos.neuron.amazonaws.com
enabled=1
metadata_expire=0
EOF

# Import GPG key
sudo rpm --import https://yum.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB
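
A typo in the repository file is a common cause of later install failures. The file is INI-style, so it can be sanity-checked with Python's `configparser`. This is a rough sketch with a hypothetical `check_repo` helper. It validates only the fields shown in Step 2 and assumes that a missing `enabled` key means the repository is enabled.

```python
import configparser

# Same content that Step 2 writes to /etc/yum.repos.d/neuron.repo.
REPO_TEXT = """\
[neuron]
name=Neuron YUM Repository
baseurl=https://yum.repos.neuron.amazonaws.com
enabled=1
metadata_expire=0
"""


def check_repo(text, section="neuron"):
    """Return a list of problems found in the repo definition (empty if none)."""
    parser = configparser.ConfigParser(interpolation=None)
    parser.read_string(text)
    if not parser.has_section(section):
        return [f"missing [{section}] section"]
    repo = parser[section]
    problems = []
    if not repo.get("baseurl", "").startswith("https://"):
        problems.append("baseurl is missing or not HTTPS")
    if repo.get("enabled", "1") != "1":
        problems.append("repository is disabled")
    return problems


# To check the real file, read /etc/yum.repos.d/neuron.repo and pass its text.
print(check_repo(REPO_TEXT))
```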

Step 3: Install Neuron Driver and Runtime

sudo yum install -y aws-neuronx-dkms
sudo yum install -y aws-neuronx-runtime-lib
sudo yum install -y aws-neuronx-collectives

Step 4: Create Virtual Environment

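# Assumes python3.10 is installed; substitute another supported interpreter (3.11 or 3.12) if it is not
python3.10 -m venv ~/neuron_venv_jax
source ~/neuron_venv_jax/bin/activate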

Step 5: Install JAX and Neuron Packages

pip install -U pip
pip install "jax-neuronx[stable]" --extra-index-url=https://pip.repos.neuron.amazonaws.com

Step 6: Verify Installation

python3 << EOF
import jax
import jax_neuronx

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF

⚠️ Troubleshooting: Repository access error

If you cannot access the Neuron repository:

  1. Verify network connectivity to the repository host.

  2. Check your proxy settings if you are behind a corporate firewall.

  3. Ensure the GPG key was imported correctly.
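
The first two checks above can be scripted. The sketch below uses only the standard library. The helper names are hypothetical, and `can_reach` counts any HTTP response (including an error status) as "reachable", since the goal is only to separate network or proxy failures from server-side ones.

```python
import os
import urllib.error
import urllib.request

PROXY_VARS = ("http_proxy", "https_proxy", "no_proxy")


def proxy_settings(env=None):
    """Return the proxy variables that are set, checking lower and upper case."""
    env = os.environ if env is None else env
    settings = {}
    for name in PROXY_VARS:
        value = env.get(name) or env.get(name.upper())
        if value:
            settings[name] = value
    return settings


def can_reach(url, timeout=5.0):
    """Return True if the server answers at all, even with an HTTP error."""
    try:
        with urllib.request.urlopen(url, timeout=timeout):
            return True
    except urllib.error.HTTPError:
        return True  # The server responded, so the network path works.
    except (urllib.error.URLError, OSError, ValueError):
        return False


print("Proxy settings:", proxy_settings())
# On the instance, test the repository host from Step 2:
#   print(can_reach("https://yum.repos.neuron.amazonaws.com"))
```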

⚠️ Troubleshooting: Driver installation failed

Ensure the kernel-devel package for the running kernel is installed, then retry the driver installation:

sudo yum install -y kernel-devel-$(uname -r)

Next Steps#

Now that JAX is installed:

  1. Try a Quick Example:

    import jax
    import jax.numpy as jnp
    
    # Simple operation on Neuron
    x = jnp.array([1.0, 2.0, 3.0])
    y = jnp.array([4.0, 5.0, 6.0])
    result = jnp.multiply(x, y)
    print(result)
    
  2. Read Documentation:

  3. Explore Setup Guide:

Additional Resources#
