This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3

Install JAX via Deep Learning AMI

Install JAX with Neuron support using pre-configured AWS Deep Learning AMIs.

⏱️ Estimated time: 5 minutes

Note

Want to read about Neuron’s Deep Learning machine images (DLAMIs) before diving in? Check out the Neuron DLAMI User Guide.


Prerequisites

Requirement     Details
Instance Type   Inf2, Trn1, Trn2, or Trn3
AWS Account     With EC2 permissions
SSH Key Pair    For instance access
AWS CLI         Configured with credentials (optional)

Installation Steps

Ubuntu 24.04

Step 1: Find the Latest AMI

Get the latest JAX DLAMI for Ubuntu 24.04:

aws ec2 describe-images \
  --owners amazon \
  --filters "Name=name,Values=Deep Learning AMI Neuron JAX * (Ubuntu 24.04)*" \
  --query 'Images | sort_by(@, &CreationDate) | [-1].ImageId' \
  --output text
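The --query expression sorts the matching images by CreationDate and keeps the ImageId of the last one. A minimal Python sketch of that selection logic follows. The sample records and their IDs are made up for illustration; real records would come from describe-images.

```python
def latest_image_id(images):
    """Return the ImageId of the most recently created image, or None."""
    if not images:
        return None
    # CreationDate is an ISO 8601 UTC timestamp, so string order is time order.
    ordered = sorted(images, key=lambda image: image["CreationDate"])
    return ordered[-1]["ImageId"]


# Hypothetical records shaped like the "Images" list returned by describe-images.
sample_images = [
    {"ImageId": "ami-00000000000000001", "CreationDate": "2025-01-10T08:00:00.000Z"},
    {"ImageId": "ami-00000000000000003", "CreationDate": "2025-03-02T08:00:00.000Z"},
    {"ImageId": "ami-00000000000000002", "CreationDate": "2025-02-15T08:00:00.000Z"},
]

print(latest_image_id(sample_images))
```

If the filter matches nothing, the function returns None. That is a useful signal that the name pattern or the Region is wrong.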

Step 2: Launch Instance

Launch a Trn1 or Inf2 instance with the AMI:

aws ec2 run-instances \
  --image-id ami-xxxxxxxxxxxxxxxxx \
  --instance-type trn1.2xlarge \
  --key-name your-key-pair \
  --security-group-ids sg-xxxxxxxxx \
  --subnet-id subnet-xxxxxxxxx

Replace:

  • ami-xxxxxxxxxxxxxxxxx with the AMI ID from Step 1

  • your-key-pair with your SSH key pair name

  • sg-xxxxxxxxx with your security group ID

  • subnet-xxxxxxxxx with your subnet ID
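A launch that still contains one of the placeholders above fails only after the request reaches EC2. As a rough sketch, the substitutions can be checked locally before the command is run. The function name and the example IDs below are hypothetical and are not part of the AWS CLI.

```python
import shlex


def build_run_instances_command(image_id, instance_type, key_name,
                                security_group_id, subnet_id):
    """Return the aws ec2 run-instances call as an argument list."""
    values = {
        "--image-id": image_id,
        "--instance-type": instance_type,
        "--key-name": key_name,
        "--security-group-ids": security_group_id,
        "--subnet-id": subnet_id,
    }
    for flag, value in values.items():
        # The placeholders in this guide contain a run of x characters
        # or the literal text "your-key-pair".
        if not value or "xxxx" in value or value == "your-key-pair":
            raise ValueError(f"{flag} still holds a placeholder: {value!r}")
    command = ["aws", "ec2", "run-instances"]
    for flag, value in values.items():
        command += [flag, value]
    return command


# Hypothetical identifiers; substitute the values from your own account.
command = build_run_instances_command(
    image_id="ami-0123456789abcdef0",
    instance_type="trn1.2xlarge",
    key_name="my-key",
    security_group_id="sg-0123456789abcdef0",
    subnet_id="subnet-0123456789abcdef0",
)
print(shlex.join(command))
```

The returned list can be passed to subprocess.run directly, which avoids shell quoting problems.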

Step 3: Connect to Instance

ssh -i your-key-pair.pem ubuntu@<instance-public-ip>

Step 4: Activate Environment

The DLAMI includes a pre-configured virtual environment:

source /opt/aws_neuronx_venv_jax/bin/activate
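To confirm from inside Python that the activation took effect, one option is to compare the interpreter prefix with the environment path used above. This is a small sketch and the helper name is hypothetical.

```python
import sys
from pathlib import Path

# Path of the virtual environment activated in this step.
EXPECTED_VENV = Path("/opt/aws_neuronx_venv_jax")


def running_in_venv(prefix, expected=EXPECTED_VENV):
    """True when the interpreter prefix is the expected virtual environment."""
    return Path(prefix) == expected


print(f"Interpreter prefix: {sys.prefix}")
print(f"Neuron JAX venv active: {running_in_venv(sys.prefix)}")
```

Inside an activated virtual environment, sys.prefix points at the environment directory, so the second line should print True on the DLAMI after Step 4.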

Step 5: Verify Installation

python3 << EOF
import jax
import jax_neuronx

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF

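Expected output (illustrative; the JAX version and the devices listed depend on the DLAMI release and the instance size):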

JAX version: 0.7.0
Devices: [NeuronDevice(id=0), NeuronDevice(id=1)]

+--------+--------+--------+-----------+
| DEVICE | CORES  | MEMORY | CONNECTED |
+--------+--------+--------+-----------+
| 0      | 2      | 32 GB  | Yes       |
| 1      | 2      | 32 GB  | Yes       |
+--------+--------+--------+-----------+
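The verification script assumes that neuron-ls is on the PATH. If it is not, subprocess.run raises FileNotFoundError. A slightly more defensive sketch checks for the binary first. The run_tool helper is hypothetical.

```python
import shutil
import subprocess


def run_tool(command):
    """Run a command and return its stdout, or None if the binary is missing."""
    if shutil.which(command[0]) is None:
        return None
    result = subprocess.run(command, capture_output=True, text=True)
    return result.stdout


output = run_tool(["neuron-ls"])
if output is None:
    print("neuron-ls not found on PATH; check that the Neuron tools are installed.")
else:
    print(output)
```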
⚠️ Troubleshooting: Module not found

If you see ModuleNotFoundError: No module named 'jax_neuronx':

  1. Verify virtual environment is activated:

    which python
    # Should show: /opt/aws_neuronx_venv_jax/bin/python
    
  2. Check Python version:

    python --version
    # Should be 3.11 or higher
    
  3. Reinstall jax-neuronx:

    pip install --force-reinstall jax-neuronx
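The first two checks above can also be run from Python in one pass, together with a test that the modules can be located at all. This is a sketch under the assumptions in this guide (Python 3.11 or higher, modules named jax and jax_neuronx). The function names are hypothetical.

```python
import importlib.util
import sys


def module_available(name):
    """True if the module can be located without importing it."""
    return importlib.util.find_spec(name) is not None


def diagnose(required=("jax", "jax_neuronx"), minimum=(3, 11)):
    """Report the interpreter path, the version check, and module availability."""
    report = {
        "interpreter": sys.executable,
        "python_ok": sys.version_info[:2] >= minimum,
    }
    for name in required:
        report[name] = module_available(name)
    return report


for key, value in diagnose().items():
    print(f"{key}: {value}")
```

If interpreter does not point into /opt/aws_neuronx_venv_jax, the environment is not active in the shell that started Python.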
    
⚠️ Troubleshooting: No Neuron devices found

If neuron-ls shows no devices:

  1. Verify instance type:

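    # Request a session token first so the call also works when IMDSv2 is required
    TOKEN=$(curl -s -X PUT http://169.254.169.254/latest/api/token -H "X-aws-ec2-metadata-token-ttl-seconds: 60")
    curl -s -H "X-aws-ec2-metadata-token: $TOKEN" http://169.254.169.254/latest/meta-data/instance-type
    # Should show trn1.*, trn2.*, trn3.*, or inf2.*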
    
  2. Check Neuron driver:

    lsmod | grep neuron
    # Should show neuron driver loaded
    
  3. Restart Neuron runtime:

    sudo systemctl restart neuron-monitor
    neuron-ls
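The first two checks reduce to simple string tests, sketched below. The sketch assumes the kernel module is listed under the name neuron, as the grep above implies, and it allows a letter suffix on the family name so that types such as trn1n also match. The sample lsmod text is made up.

```python
import re

# Families named in this guide, with an optional letter suffix (for example trn1n).
SUPPORTED = re.compile(r"^(trn1|trn2|trn3|inf2)[a-z]*\.")


def is_neuron_instance(instance_type):
    """True if the instance type belongs to a family listed in this guide."""
    return SUPPORTED.match(instance_type) is not None


def driver_loaded(lsmod_output):
    """True if a module named 'neuron' appears in lsmod-style output."""
    for line in lsmod_output.splitlines():
        fields = line.split()
        if fields and fields[0] == "neuron":
            return True
    return False


# Hypothetical lsmod output for illustration only.
sample_lsmod = """Module                  Size  Used by
neuron                434176  0
nvme                   49152  2
"""

print(is_neuron_instance("trn1.2xlarge"))
print(driver_loaded(sample_lsmod))
```

Matching on the first column avoids false positives from other modules whose names merely contain the word neuron.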
    

Ubuntu 22.04

Step 1: Find the Latest AMI

Important

Ubuntu 22.04 has reached end-of-support on Neuron. Neuron no longer provides Ubuntu 22.04 DLAMIs or container images, and new deployments should use Ubuntu 24.04. See the announcement "Neuron no longer includes Ubuntu 22.04 DLAMIs and DLCs starting this release."

To look up the most recent previously published JAX DLAMI for Ubuntu 22.04:

aws ec2 describe-images \
  --owners amazon \
  --filters "Name=name,Values=Deep Learning AMI Neuron JAX * (Ubuntu 22.04)*" \
  --query 'Images | sort_by(@, &CreationDate) | [-1].ImageId' \
  --output text

Step 2: Launch Instance

aws ec2 run-instances \
  --image-id ami-xxxxxxxxxxxxxxxxx \
  --instance-type trn1.2xlarge \
  --key-name your-key-pair \
  --security-group-ids sg-xxxxxxxxx \
  --subnet-id subnet-xxxxxxxxx

Step 3: Connect to Instance

ssh -i your-key-pair.pem ubuntu@<instance-public-ip>

Step 4: Activate Environment

source /opt/aws_neuronx_venv_jax/bin/activate

Step 5: Verify Installation

python3 << EOF
import jax
import jax_neuronx

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF
⚠️ Troubleshooting: Module not found

If you see ModuleNotFoundError: No module named 'jax_neuronx':

  1. Verify virtual environment is activated

  2. Check Python version: python --version (should be 3.11+)

  3. Reinstall: pip install --force-reinstall jax-neuronx

⚠️ Troubleshooting: No Neuron devices found

If neuron-ls shows no devices:

  1. Verify instance type

  2. Check Neuron driver: lsmod | grep neuron

  3. Restart runtime: sudo systemctl restart neuron-monitor

Amazon Linux 2023

Step 1: Find the Latest AMI

Get the latest JAX DLAMI for Amazon Linux 2023:

aws ec2 describe-images \
  --owners amazon \
  --filters "Name=name,Values=Deep Learning AMI Neuron JAX * (Amazon Linux 2023)*" \
  --query 'Images | sort_by(@, &CreationDate) | [-1].ImageId' \
  --output text

Step 2: Launch Instance

aws ec2 run-instances \
  --image-id ami-xxxxxxxxxxxxxxxxx \
  --instance-type trn1.2xlarge \
  --key-name your-key-pair \
  --security-group-ids sg-xxxxxxxxx \
  --subnet-id subnet-xxxxxxxxx

Step 3: Connect to Instance

ssh -i your-key-pair.pem ec2-user@<instance-public-ip>

Note

Amazon Linux 2023 uses ec2-user instead of ubuntu.
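When scripting connections to both kinds of image, the login name can be derived from the OS in the DLAMI name. This is a sketch only. The helper names and the sample AMI names are hypothetical, and the address is a reserved documentation IP.

```python
def ssh_user_for(ami_name):
    """Pick the default SSH login from the OS named in the DLAMI name."""
    if "Amazon Linux" in ami_name:
        return "ec2-user"
    if "Ubuntu" in ami_name:
        return "ubuntu"
    raise ValueError(f"Unrecognized OS in AMI name: {ami_name!r}")


def ssh_command(key_path, ami_name, public_ip):
    """Build the ssh call from Step 3 as an argument list."""
    return ["ssh", "-i", key_path, f"{ssh_user_for(ami_name)}@{public_ip}"]


# Hypothetical AMI names that follow the filter patterns used in Step 1.
al2023 = "Deep Learning AMI Neuron JAX 0.0 (Amazon Linux 2023) 20250101"
ubuntu = "Deep Learning AMI Neuron JAX 0.0 (Ubuntu 24.04) 20250101"

print(" ".join(ssh_command("your-key-pair.pem", al2023, "203.0.113.10")))
print(" ".join(ssh_command("your-key-pair.pem", ubuntu, "203.0.113.10")))
```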

Step 4: Activate Environment

source /opt/aws_neuronx_venv_jax/bin/activate

Step 5: Verify Installation

python3 << EOF
import jax
import jax_neuronx

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF
⚠️ Troubleshooting: Module not found

If you see ModuleNotFoundError: No module named 'jax_neuronx':

  1. Verify virtual environment is activated

  2. Check Python version: python --version (should be 3.11+)

  3. Reinstall: pip install --force-reinstall jax-neuronx

⚠️ Troubleshooting: No Neuron devices found

If neuron-ls shows no devices:

  1. Verify instance type

  2. Check Neuron driver: lsmod | grep neuron

  3. Restart runtime: sudo systemctl restart neuron-monitor

Next Steps

Now that JAX is installed:

  1. Try a Quick Example:

    import jax
    import jax.numpy as jnp
    
    # Simple operation on Neuron
    x = jnp.array([1.0, 2.0, 3.0])
    y = jnp.array([4.0, 5.0, 6.0])
    result = jnp.multiply(x, y)
    print(result)
    
  2. Read Documentation:

  3. Explore Setup Guide:

Additional Resources
