This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3
Install JAX via Deep Learning AMI#
Install JAX with Neuron support using pre-configured AWS Deep Learning AMIs.
⏱️ Estimated time: 5 minutes
Note
Want to read about Neuron’s Deep Learning machine images (DLAMIs) before diving in? Check out the Neuron DLAMI User Guide.
Prerequisites#
| Requirement | Details |
|---|---|
| Instance Type | Inf2, Trn1, Trn2, or Trn3 |
| AWS Account | With EC2 permissions |
| SSH Key Pair | For instance access |
| AWS CLI | Configured with credentials (optional) |
Installation Steps#
Step 1: Find the Latest AMI
Get the latest JAX DLAMI for Ubuntu 24.04:
aws ec2 describe-images \
--owners amazon \
--filters "Name=name,Values=Deep Learning AMI Neuron JAX * (Ubuntu 24.04)*" \
--query 'Images | sort_by(@, &CreationDate) | [-1].ImageId' \
--output text
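The `--query` expression sorts the returned images by `CreationDate` and keeps the `ImageId` of the last entry. A minimal Python sketch of that selection logic is below. The sample records are hypothetical (the IDs and dates are invented for illustration and are not real AMIs), and the function name `latest_image_id` is an assumption of this sketch.

```python
from typing import Optional


def latest_image_id(images: list[dict]) -> Optional[str]:
    """Return the ImageId of the most recently created image, or None if empty."""
    if not images:
        return None
    # CreationDate is an ISO 8601 UTC string, so sorting the strings
    # lexicographically also sorts them chronologically.
    ordered = sorted(images, key=lambda image: image["CreationDate"])
    return ordered[-1]["ImageId"]


# Hypothetical sample records shaped like entries of the "Images" list
sample_images = [
    {"ImageId": "ami-aaaaaaaaaaaaaaaaa", "CreationDate": "2025-01-10T08:00:00.000Z"},
    {"ImageId": "ami-ccccccccccccccccc", "CreationDate": "2025-03-02T12:30:00.000Z"},
    {"ImageId": "ami-bbbbbbbbbbbbbbbbb", "CreationDate": "2025-02-15T09:45:00.000Z"},
]

print(latest_image_id(sample_images))
```

Sorting on the timestamp rather than relying on the order of the API response is what makes "latest" well defined here.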
Step 2: Launch Instance
Launch a Trn1 or Inf2 instance with the AMI:
aws ec2 run-instances \
--image-id ami-xxxxxxxxxxxxxxxxx \
--instance-type trn1.2xlarge \
--key-name your-key-pair \
--security-group-ids sg-xxxxxxxxx \
--subnet-id subnet-xxxxxxxxx
Replace:
ami-xxxxxxxxxxxxxxxxx with the AMI ID from Step 1
your-key-pair with your SSH key pair name
sg-xxxxxxxxx with your security group ID
subnet-xxxxxxxxx with your subnet ID
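A common slip is launching with one of the template placeholders still in place. The sketch below assembles the same `aws ec2 run-instances` argument list in Python and refuses to build it if a placeholder remains. The helper name `build_run_instances_command`, the placeholder markers, and the example IDs are all hypothetical. The sketch only prints the command and does not call AWS.

```python
import shlex

# Substrings that indicate a template placeholder was left unchanged
PLACEHOLDER_MARKERS = ("xxxx", "your-key-pair")


def build_run_instances_command(image_id, key_name, security_group_id, subnet_id,
                                instance_type="trn1.2xlarge"):
    """Return the run-instances argument list, rejecting leftover placeholders."""
    options = {
        "--image-id": image_id,
        "--instance-type": instance_type,
        "--key-name": key_name,
        "--security-group-ids": security_group_id,
        "--subnet-id": subnet_id,
    }
    args = ["aws", "ec2", "run-instances"]
    for flag, value in options.items():
        if any(marker in value for marker in PLACEHOLDER_MARKERS):
            raise ValueError(f"{flag} still holds a placeholder value: {value!r}")
        args += [flag, value]
    return args


# Hypothetical identifiers, for illustration only
command = build_run_instances_command(
    image_id="ami-0123456789abcdef0",
    key_name="my-neuron-key",
    security_group_id="sg-0123456789abcdef0",
    subnet_id="subnet-0123456789abcdef0",
)
print(shlex.join(command))
```

Building the arguments as a list and joining them with `shlex.join` keeps quoting correct if a value ever contains spaces.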
Step 3: Connect to Instance
ssh -i your-key-pair.pem ubuntu@<instance-public-ip>
Step 4: Activate Environment
The DLAMI includes a pre-configured virtual environment:
source /opt/aws_neuronx_venv_jax/bin/activate
Step 5: Verify Installation
python3 << EOF
import jax
import jax_neuronx
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF
Expected output:
JAX version: 0.7.0
Devices: [NeuronDevice(id=0), NeuronDevice(id=1)]
+--------+--------+--------+-----------+
| DEVICE | CORES | MEMORY | CONNECTED |
+--------+--------+--------+-----------+
| 0 | 2 | 32 GB | Yes |
| 1 | 2 | 32 GB | Yes |
+--------+--------+--------+-----------+
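If the verification script stops at the first failed import, it can be hard to tell which piece is missing. The following sketch, using only the standard library, reports the presence of each module and command-line tool without importing or running them. The function name `check_install` is an assumption of this sketch. It does not confirm that Neuron devices are usable, only that the packages and the `neuron-ls` binary can be found.

```python
import importlib.util
import shutil


def check_install(modules=("jax", "jax_neuronx"), tools=("neuron-ls",)):
    """Map each module or tool name to True if it can be located."""
    report = {}
    for name in modules:
        # find_spec returns None when a top-level module cannot be found
        report[name] = importlib.util.find_spec(name) is not None
    for tool in tools:
        # shutil.which returns None when the executable is not on PATH
        report[tool] = shutil.which(tool) is not None
    return report


for name, found in check_install().items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```

A `MISSING` entry for `jax_neuronx` points at the virtual environment, while a missing `neuron-ls` points at the system-level Neuron tools.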
⚠️ Troubleshooting: Module not found
If you see ModuleNotFoundError: No module named 'jax_neuronx':
Verify virtual environment is activated:
which python # Should show: /opt/aws_neuronx_venv_jax/bin/python
Check Python version:
python --version # Should be 3.11 or higher
Reinstall jax-neuronx:
pip install --force-reinstall jax-neuronx
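The first two checks can also be run from inside the interpreter, which helps when the shell's `python` and the interpreter actually running your script differ. This is a minimal sketch under the assumptions stated in this guide: the virtual environment lives at `/opt/aws_neuronx_venv_jax` and Python 3.11 is the minimum. The function name `environment_problems` is hypothetical.

```python
import sys

EXPECTED_PREFIX = "/opt/aws_neuronx_venv_jax"  # venv path used in this guide
MIN_PYTHON = (3, 11)                           # minimum version stated above


def environment_problems(prefix=sys.prefix, version=sys.version_info[:2]):
    """Return a list of human-readable problems; an empty list means OK."""
    problems = []
    # Inside an activated venv, sys.prefix is the venv directory
    if prefix != EXPECTED_PREFIX:
        problems.append(f"interpreter prefix is {prefix}, expected {EXPECTED_PREFIX}")
    if tuple(version) < MIN_PYTHON:
        problems.append(f"Python {version[0]}.{version[1]} is older than 3.11")
    return problems


for problem in environment_problems():
    print("Problem:", problem)
```

Running it outside the DLAMI environment prints at least the prefix mismatch, which is the expected result there.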
⚠️ Troubleshooting: No Neuron devices found
If neuron-ls shows no devices:
Verify instance type:
curl http://169.254.169.254/latest/meta-data/instance-type # Should show trn1.*, trn2.*, trn3.*, or inf2.*
Check Neuron driver:
lsmod | grep neuron # Should show neuron driver loaded
Restart Neuron runtime:
sudo systemctl restart neuron-monitor
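The instance-type check above compares the metadata value against the patterns `trn1.*`, `trn2.*`, `trn3.*`, and `inf2.*`. A small sketch of that comparison in Python follows. It mirrors only those four patterns, so any instance type they do not cover is reported as non-matching. The function name `matches_neuron_pattern` is an assumption, and the sketch takes the instance type as a string rather than querying the metadata service.

```python
from fnmatch import fnmatchcase

# Patterns taken from the comment on the instance-type check in this guide
INSTANCE_PATTERNS = ("trn1.*", "trn2.*", "trn3.*", "inf2.*")


def matches_neuron_pattern(instance_type: str) -> bool:
    """Return True if the instance type matches one of the listed patterns."""
    normalized = instance_type.strip().lower()
    return any(fnmatchcase(normalized, pattern) for pattern in INSTANCE_PATTERNS)


for candidate in ("trn1.2xlarge", "inf2.xlarge", "m5.large"):
    print(candidate, matches_neuron_pattern(candidate))
```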
Step 1: Find the Latest AMI
Important
Ubuntu 22.04 has reached end-of-support on Neuron. Neuron no longer provides Ubuntu 22.04 DLAMIs or container images. New deployments should use Ubuntu 24.04. See Neuron no longer includes Ubuntu 22.04 DLAMIs and DLCs starting this release.
If you still need an existing Ubuntu 22.04 image, look up the most recent JAX DLAMI that was published for Ubuntu 22.04:
aws ec2 describe-images \
--owners amazon \
--filters "Name=name,Values=Deep Learning AMI Neuron JAX * (Ubuntu 22.04)*" \
--query 'Images | sort_by(@, &CreationDate) | [-1].ImageId' \
--output text
Step 2: Launch Instance
aws ec2 run-instances \
--image-id ami-xxxxxxxxxxxxxxxxx \
--instance-type trn1.2xlarge \
--key-name your-key-pair \
--security-group-ids sg-xxxxxxxxx \
--subnet-id subnet-xxxxxxxxx
Step 3: Connect to Instance
ssh -i your-key-pair.pem ubuntu@<instance-public-ip>
Step 4: Activate Environment
source /opt/aws_neuronx_venv_jax/bin/activate
Step 5: Verify Installation
python3 << EOF
import jax
import jax_neuronx
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF
⚠️ Troubleshooting: Module not found
If you see ModuleNotFoundError: No module named 'jax_neuronx':
Verify virtual environment is activated
Check Python version:
python --version # Should be 3.11 or higher
Reinstall jax-neuronx:
pip install --force-reinstall jax-neuronx
⚠️ Troubleshooting: No Neuron devices found
If neuron-ls shows no devices:
Verify instance type
Check Neuron driver:
lsmod | grep neuron # Should show neuron driver loaded
Restart Neuron runtime:
sudo systemctl restart neuron-monitor
Step 1: Find the Latest AMI
Get the latest JAX DLAMI for Amazon Linux 2023:
aws ec2 describe-images \
--owners amazon \
--filters "Name=name,Values=Deep Learning AMI Neuron JAX * (Amazon Linux 2023)*" \
--query 'Images | sort_by(@, &CreationDate) | [-1].ImageId' \
--output text
Step 2: Launch Instance
aws ec2 run-instances \
--image-id ami-xxxxxxxxxxxxxxxxx \
--instance-type trn1.2xlarge \
--key-name your-key-pair \
--security-group-ids sg-xxxxxxxxx \
--subnet-id subnet-xxxxxxxxx
Step 3: Connect to Instance
ssh -i your-key-pair.pem ec2-user@<instance-public-ip>
Note
Amazon Linux 2023 uses ec2-user instead of ubuntu.
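When scripting connections to several instances, the login user is the only part of the SSH command that changes between the two image families. The sketch below captures that mapping. The dictionary keys (`"ubuntu"`, `"amazon-linux-2023"`), the function name `ssh_command`, and the example address (taken from a documentation-reserved range) are hypothetical labels chosen for this illustration.

```python
import shlex

# Default login user per DLAMI operating system, as described in this guide
SSH_USERS = {
    "ubuntu": "ubuntu",
    "amazon-linux-2023": "ec2-user",
}


def ssh_command(os_name: str, key_path: str, host: str) -> list[str]:
    """Build the ssh argument list for the given operating system label."""
    user = SSH_USERS[os_name]  # raises KeyError for an unknown label
    return ["ssh", "-i", key_path, f"{user}@{host}"]


print(shlex.join(ssh_command("amazon-linux-2023", "your-key-pair.pem", "203.0.113.10")))
```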
Step 4: Activate Environment
source /opt/aws_neuronx_venv_jax/bin/activate
Step 5: Verify Installation
python3 << EOF
import jax
import jax_neuronx
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF
⚠️ Troubleshooting: Module not found
If you see ModuleNotFoundError: No module named 'jax_neuronx':
Verify virtual environment is activated
Check Python version:
python --version # Should be 3.11 or higher
Reinstall jax-neuronx:
pip install --force-reinstall jax-neuronx
⚠️ Troubleshooting: No Neuron devices found
If neuron-ls shows no devices:
Verify instance type
Check Neuron driver:
lsmod | grep neuron # Should show neuron driver loaded
Restart Neuron runtime:
sudo systemctl restart neuron-monitor
Next Steps#
Now that JAX is installed:
Try a Quick Example:
import jax
import jax.numpy as jnp
# Simple element-wise operation on Neuron
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
result = jnp.multiply(x, y)
print(result)
Read Documentation:
Explore Setup Guide:
Additional Resources#
Neuron DLAMI User Guide - DLAMI documentation
Deploy on AWS - Container-based deployment
Installation Troubleshooting - Common issues and solutions
AWS Neuron SDK Release Notes - Version compatibility information