This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3

Install JAX via Deep Learning AMI

Install JAX with Neuron support using pre-configured AWS Deep Learning AMIs.

⏱️ Estimated time: 5 minutes

Note

Want to read about Neuron’s Deep Learning AMIs (DLAMIs) before diving in? Check out the Neuron DLAMI User Guide.


Prerequisites

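+---------------+----------------------------------------+
| Requirement   | Details                                |
+---------------+----------------------------------------+
| Instance Type | Inf2, Trn1, Trn2, or Trn3              |
| AWS Account   | With EC2 permissions                   |
| SSH Key Pair  | For instance access                    |
| AWS CLI       | Configured with credentials (optional) |
+---------------+----------------------------------------+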

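Installation Steps

The steps are the same for each supported operating system except for the AMI name filter and the SSH user. Follow the section for your operating system: Ubuntu 24.04, Ubuntu 22.04, or Amazon Linux 2023.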

Ubuntu 24.04

Step 1: Find the Latest AMI

Get the latest JAX DLAMI for Ubuntu 24.04:

aws ec2 describe-images \
  --owners amazon \
  --filters "Name=name,Values=Deep Learning AMI Neuron JAX * (Ubuntu 24.04)*" \
  --query 'Images | sort_by(@, &CreationDate) | [-1].ImageId' \
  --output text
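The --query expression sorts the matching images by CreationDate and returns the ImageId of the last (newest) one. Below is a minimal Python sketch of the same selection. It assumes a dict shaped like the describe-images response, meaning an Images list whose entries carry ImageId and CreationDate. This is also the shape returned by boto3's EC2 client describe_images(Owners=[...], Filters=[...]). The function name latest_image_id and the sample AMI IDs are hypothetical.

```python
from typing import Any, Optional


def latest_image_id(response: dict[str, Any]) -> Optional[str]:
    """Return the ImageId of the newest image, or None if there are none.

    Mirrors the JMESPath query 'Images | sort_by(@, &CreationDate) | [-1].ImageId'.
    CreationDate values share one fixed-width ISO 8601 format, so sorting
    them as strings also sorts them chronologically.
    """
    images = response.get("Images", [])
    if not images:
        return None
    newest = sorted(images, key=lambda image: image["CreationDate"])[-1]
    return newest["ImageId"]


# Hypothetical sample shaped like a describe-images response (IDs are made up)
sample = {
    "Images": [
        {"ImageId": "ami-0aaaaaaaaaaaaaaaa", "CreationDate": "2025-01-10T08:00:00.000Z"},
        {"ImageId": "ami-0bbbbbbbbbbbbbbbb", "CreationDate": "2025-03-02T08:00:00.000Z"},
        {"ImageId": "ami-0cccccccccccccccc", "CreationDate": "2025-02-15T08:00:00.000Z"},
    ]
}
print(latest_image_id(sample))
```

For the sample, the newest entry is the one created on 2025-03-02, so the sketch prints ami-0bbbbbbbbbbbbbbbb.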

Step 2: Launch Instance

Launch a supported instance (Inf2, Trn1, Trn2, or Trn3) from the AMI. This example uses trn1.2xlarge:

aws ec2 run-instances \
  --image-id ami-xxxxxxxxxxxxxxxxx \
  --instance-type trn1.2xlarge \
  --key-name your-key-pair \
  --security-group-ids sg-xxxxxxxxx \
  --subnet-id subnet-xxxxxxxxx

Replace:

  • ami-xxxxxxxxxxxxxxxxx with the AMI ID from Step 1

  • your-key-pair with your SSH key pair name

  • sg-xxxxxxxxx with your security group ID

  • subnet-xxxxxxxxx with your subnet ID
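Before you run the launch command, it can help to catch placeholders that were never replaced. The sketch below uses a hypothetical helper, build_run_instances_cmd. It assembles the same aws ec2 run-instances arguments as a list, which is suitable for subprocess.run, and raises ValueError if any value still looks like a placeholder. The placeholder pattern (a run of eight or more x characters, or a leading "your-") is an assumption based on the placeholders on this page.

```python
import re
import shlex

# Assumption: placeholders look like "ami-xxxxxxxxxxxxxxxxx" or "your-key-pair"
PLACEHOLDER = re.compile(r"x{8,}|^your-")


def build_run_instances_cmd(image_id, instance_type, key_name, sg_id, subnet_id):
    """Return the run-instances command as an argument list, rejecting placeholders."""
    values = {
        "--image-id": image_id,
        "--instance-type": instance_type,
        "--key-name": key_name,
        "--security-group-ids": sg_id,
        "--subnet-id": subnet_id,
    }
    leftovers = [flag for flag, value in values.items() if PLACEHOLDER.search(value)]
    if leftovers:
        raise ValueError(f"placeholder values still present for: {', '.join(leftovers)}")
    cmd = ["aws", "ec2", "run-instances"]
    for flag, value in values.items():
        cmd += [flag, value]
    return cmd


# Hypothetical resource IDs, for illustration only
cmd = build_run_instances_cmd(
    "ami-0123456789abcdef0", "trn1.2xlarge", "my-key",
    "sg-0123456789abcdef0", "subnet-0123456789abcdef0",
)
print(shlex.join(cmd))
```

With the page's placeholder values passed unchanged, the helper raises ValueError instead of returning a command.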

Step 3: Connect to Instance

ssh -i your-key-pair.pem ubuntu@<instance-public-ip>
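OpenSSH refuses to use a private key file that group or other users can access and reports that its permissions are too open. The sketch below checks for those permission bits and, if any are set, applies the equivalent of chmod 400 (read-only for the owner). The helper names are hypothetical, and the key path is the placeholder from the command above.

```python
import os
import stat


def key_permissions_ok(path: str) -> bool:
    """True if the file grants no permissions to group or others (mode & 0o077 == 0)."""
    mode = stat.S_IMODE(os.stat(path).st_mode)
    return mode & 0o077 == 0


def restrict_key(path: str) -> None:
    """Same effect as `chmod 400 path`: read-only for the owner, nothing for others."""
    os.chmod(path, 0o400)


key_path = "your-key-pair.pem"  # placeholder from the ssh command
if os.path.exists(key_path) and not key_permissions_ok(key_path):
    restrict_key(key_path)
    print(f"Restricted permissions on {key_path}")
```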

Step 4: Activate Environment

The DLAMI includes a pre-configured virtual environment:

source /opt/aws_neuronx_venv_jax/bin/activate
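After activation, sys.prefix inside Python points at the virtual environment directory, while sys.base_prefix still points at the base installation the environment was created from. The sketch below checks that the interpreter is running from the DLAMI environment at /opt/aws_neuronx_venv_jax. The function name is hypothetical.

```python
import os
import sys

EXPECTED_VENV = "/opt/aws_neuronx_venv_jax"  # environment path from Step 4


def in_expected_venv(prefix: str, base_prefix: str, expected: str = EXPECTED_VENV) -> bool:
    """True if running inside a virtual environment rooted at `expected`."""
    inside_venv = prefix != base_prefix  # the two differ only inside a venv
    return inside_venv and os.path.normpath(prefix) == os.path.normpath(expected)


print(in_expected_venv(sys.prefix, sys.base_prefix))
```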

Step 5: Verify Installation

python3 << 'EOF'
import jax
import jax_neuronx

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF

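Expected output (illustrative only; the JAX version, device list, and neuron-ls table vary with the DLAMI release and instance type):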

JAX version: 0.7.0
Devices: [NeuronDevice(id=0), NeuronDevice(id=1)]

+--------+--------+--------+-----------+
| DEVICE | CORES  | MEMORY | CONNECTED |
+--------+--------+--------+-----------+
| 0      | 2      | 32 GB  | Yes       |
| 1      | 2      | 32 GB  | Yes       |
+--------+--------+--------+-----------+
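In the verification script, subprocess.run(['neuron-ls'], ...) raises FileNotFoundError when neuron-ls is not on PATH, so the script ends in a traceback rather than a clear message. The sketch below is a more defensive call that reports a missing tool instead of raising. The run_tool helper is hypothetical.

```python
import shutil
import subprocess
from typing import Optional


def run_tool(name: str, *args: str, timeout: float = 30.0) -> Optional[str]:
    """Run a command-line tool and return its stdout, or None if it is not on PATH."""
    path = shutil.which(name)
    if path is None:
        return None  # e.g. neuron-ls missing from PATH
    result = subprocess.run(
        [path, *args], capture_output=True, text=True, timeout=timeout
    )
    if result.returncode != 0:
        raise RuntimeError(f"{name} exited with {result.returncode}: {result.stderr.strip()}")
    return result.stdout


output = run_tool("neuron-ls")
print(output if output is not None else "neuron-ls not found on PATH")
```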

⚠️ Troubleshooting: Module not found

If you see ModuleNotFoundError: No module named 'jax_neuronx':

  1. Verify virtual environment is activated:

    which python
    # Should show: /opt/aws_neuronx_venv_jax/bin/python
    
  2. Check Python version:

    python --version
    # Should be 3.10 or higher
    
  3. Reinstall jax-neuronx:

    pip install --force-reinstall jax-neuronx
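The interpreter and version checks, plus an import check, can also be run from Python in one pass. The sketch below assumes the module names jax and jax_neuronx used in the verification script. The diagnose function is hypothetical. importlib.util.find_spec returns None for a top-level module that cannot be found, without executing that module.

```python
import importlib.util
import sys


def diagnose(modules=("jax", "jax_neuronx"), min_version=(3, 10)):
    """Gather the troubleshooting checks into one dict."""
    return {
        "python": sys.executable,  # compare with the output of `which python`
        "python_ok": sys.version_info >= min_version,
        "missing": [m for m in modules if importlib.util.find_spec(m) is None],
    }


print(diagnose())
```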
    
⚠️ Troubleshooting: No Neuron devices found

If neuron-ls shows no devices:

  1. Verify instance type:

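    TOKEN=$(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 21600")
    curl -s -H "X-aws-ec2-metadata-token: $TOKEN" http://169.254.169.254/latest/meta-data/instance-type
    # Should show trn1.*, trn2.*, trn3.*, or inf2.*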
    
  2. Check Neuron driver:

    lsmod | grep neuron
    # Should show neuron driver loaded
    
  3. If the driver is not loaded, load it and list devices again:

    sudo modprobe neuron
    neuron-ls
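The instance type and the lsmod output can also be checked programmatically. In the sketch below, the family pattern is an assumption that generalizes the trn1.*, trn2.*, trn3.*, and inf2.* types listed above, and the function names are hypothetical. Matching the exact module name in /proc/modules (the file lsmod reads) avoids grep also matching modules whose names merely contain "neuron".

```python
import re

# Assumption: Neuron-capable families are inf2 and trn1/trn2/trn3, including
# suffixed variants such as trn1n (for example trn1n.32xlarge).
SUPPORTED_TYPE = re.compile(r"^(inf2|trn[123])[a-z]*\.")


def is_supported_instance(instance_type: str) -> bool:
    """True for instance types such as trn1.2xlarge or inf2.xlarge."""
    return SUPPORTED_TYPE.match(instance_type.strip().lower()) is not None


def module_loaded(name: str, proc_modules_text: str) -> bool:
    """lsmod reads /proc/modules, where the first field of each line is a module name."""
    return any(
        line.split()[0] == name
        for line in proc_modules_text.splitlines()
        if line.strip()
    )


if __name__ == "__main__":
    print(is_supported_instance("trn1.2xlarge"))
    try:
        with open("/proc/modules") as f:
            print("neuron driver loaded:", module_loaded("neuron", f.read()))
    except FileNotFoundError:
        print("/proc/modules not available (not a Linux host?)")
```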
    

Ubuntu 22.04

Step 1: Find the Latest AMI

Get the latest JAX DLAMI for Ubuntu 22.04:

aws ec2 describe-images \
  --owners amazon \
  --filters "Name=name,Values=Deep Learning AMI Neuron JAX * (Ubuntu 22.04)*" \
  --query 'Images | sort_by(@, &CreationDate) | [-1].ImageId' \
  --output text

Step 2: Launch Instance

aws ec2 run-instances \
  --image-id ami-xxxxxxxxxxxxxxxxx \
  --instance-type trn1.2xlarge \
  --key-name your-key-pair \
  --security-group-ids sg-xxxxxxxxx \
  --subnet-id subnet-xxxxxxxxx

Step 3: Connect to Instance

ssh -i your-key-pair.pem ubuntu@<instance-public-ip>

Step 4: Activate Environment

source /opt/aws_neuronx_venv_jax/bin/activate

Step 5: Verify Installation

python3 << 'EOF'
import jax
import jax_neuronx

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF

⚠️ Troubleshooting: Module not found

If you see ModuleNotFoundError: No module named 'jax_neuronx':

  1. Verify the virtual environment is activated: which python should show /opt/aws_neuronx_venv_jax/bin/python

  2. Check the Python version: python --version (should be 3.10 or higher)

  3. Reinstall: pip install --force-reinstall jax-neuronx

⚠️ Troubleshooting: No Neuron devices found

If neuron-ls shows no devices:

  1. Verify the instance type is trn1.*, trn2.*, trn3.*, or inf2.*

  2. Check Neuron driver: lsmod | grep neuron

  3. If the driver is not loaded, load it: sudo modprobe neuron

Amazon Linux 2023

Step 1: Find the Latest AMI

Get the latest JAX DLAMI for Amazon Linux 2023:

aws ec2 describe-images \
  --owners amazon \
  --filters "Name=name,Values=Deep Learning AMI Neuron JAX * (Amazon Linux 2023)*" \
  --query 'Images | sort_by(@, &CreationDate) | [-1].ImageId' \
  --output text

Step 2: Launch Instance

aws ec2 run-instances \
  --image-id ami-xxxxxxxxxxxxxxxxx \
  --instance-type trn1.2xlarge \
  --key-name your-key-pair \
  --security-group-ids sg-xxxxxxxxx \
  --subnet-id subnet-xxxxxxxxx

Step 3: Connect to Instance

ssh -i your-key-pair.pem ec2-user@<instance-public-ip>

Note

Amazon Linux 2023 uses ec2-user instead of ubuntu.
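The login user depends on the AMI's operating system: ubuntu for the Ubuntu DLAMIs and ec2-user for Amazon Linux 2023. The small sketch below picks the user and builds the ssh command. The mapping covers only the three operating systems on this page, the function name is hypothetical, and 203.0.113.10 is an address from a range reserved for documentation.

```python
# Default login users for the DLAMI operating systems covered on this page
DEFAULT_SSH_USER = {
    "Ubuntu 24.04": "ubuntu",
    "Ubuntu 22.04": "ubuntu",
    "Amazon Linux 2023": "ec2-user",
}


def ssh_command(os_name: str, key_path: str, host: str) -> list[str]:
    """Build the ssh argument list for the given AMI operating system."""
    user = DEFAULT_SSH_USER[os_name]  # KeyError for operating systems not listed
    return ["ssh", "-i", key_path, f"{user}@{host}"]


print(" ".join(ssh_command("Amazon Linux 2023", "your-key-pair.pem", "203.0.113.10")))
```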

Step 4: Activate Environment

source /opt/aws_neuronx_venv_jax/bin/activate

Step 5: Verify Installation

python3 << 'EOF'
import jax
import jax_neuronx

print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")

# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF

⚠️ Troubleshooting: Module not found

If you see ModuleNotFoundError: No module named 'jax_neuronx':

  1. Verify the virtual environment is activated: which python should show /opt/aws_neuronx_venv_jax/bin/python

  2. Check the Python version: python --version (should be 3.10 or higher)

  3. Reinstall: pip install --force-reinstall jax-neuronx

⚠️ Troubleshooting: No Neuron devices found

If neuron-ls shows no devices:

  1. Verify the instance type is trn1.*, trn2.*, trn3.*, or inf2.*

  2. Check Neuron driver: lsmod | grep neuron

  3. If the driver is not loaded, load it: sudo modprobe neuron

Next Steps

Now that JAX is installed:

  1. Try a Quick Example:

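    import jax
    import jax.numpy as jnp
    
    # Element-wise multiply on JAX's default backend;
    # jax.devices() shows whether that backend is Neuron
    x = jnp.array([1.0, 2.0, 3.0])
    y = jnp.array([4.0, 5.0, 6.0])
    result = jnp.multiply(x, y)
    print(result)
    print(result.devices())  # device(s) holding the result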
    
  2. Read Documentation:

  3. Explore Setup Guide:

Additional Resources
