This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3
Install JAX via Deep Learning AMI
Install JAX with Neuron support using pre-configured AWS Deep Learning AMIs.
⏱️ Estimated time: 5 minutes
Note
Want to read about Neuron’s Deep Learning machine images (DLAMIs) before diving in? Check out the Neuron DLAMI User Guide.
Prerequisites
| Requirement | Details |
|---|---|
| Instance Type | Inf2, Trn1, Trn2, or Trn3 |
| AWS Account | With EC2 permissions |
| SSH Key Pair | For instance access |
| AWS CLI | Configured with credentials (optional) |
Installation Steps
Ubuntu 24.04
Step 1: Find the Latest AMI
Get the latest JAX DLAMI for Ubuntu 24.04:
aws ec2 describe-images \
--owners amazon \
--filters "Name=name,Values=Deep Learning AMI Neuron JAX * (Ubuntu 24.04)*" \
--query 'Images | sort_by(@, &CreationDate) | [-1].ImageId' \
--output text
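The `--query` expression sorts the matching images by `CreationDate` and keeps the last (newest) one. Below is a minimal Python sketch of the same selection logic. It assumes each entry in `Images` is a dictionary with `ImageId` and `CreationDate` keys, as returned by `describe-images`. The helper name `latest_image_id` and the sample IDs and dates are made up for illustration.

```python
from typing import Any, Dict, List, Optional


def latest_image_id(images: List[Dict[str, Any]]) -> Optional[str]:
    """Return the ImageId of the newest image, or None if the list is empty.

    Mirrors the CLI query 'Images | sort_by(@, &CreationDate) | [-1].ImageId'.
    CreationDate is an ISO 8601 timestamp string in a fixed format, so sorting
    the strings lexically also sorts them chronologically.
    """
    if not images:
        return None
    newest = sorted(images, key=lambda image: image["CreationDate"])[-1]
    return newest["ImageId"]


# Hypothetical sample data: not real AMI IDs.
sample_images = [
    {"ImageId": "ami-0aaaaaaaaaaaaaaaa", "CreationDate": "2025-01-10T12:00:00.000Z"},
    {"ImageId": "ami-0bbbbbbbbbbbbbbbb", "CreationDate": "2025-03-02T08:30:00.000Z"},
    {"ImageId": "ami-0cccccccccccccccc", "CreationDate": "2024-11-20T17:45:00.000Z"},
]

print(latest_image_id(sample_images))  # ami-0bbbbbbbbbbbbbbbb
```

Sorting on the raw timestamp string avoids date parsing. This only works because every `CreationDate` uses the same zero-padded format.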
Step 2: Launch Instance
Launch an instance from the AMI. This example uses trn1.2xlarge; substitute any supported Inf2, Trn1, Trn2, or Trn3 instance type:
aws ec2 run-instances \
--image-id ami-xxxxxxxxxxxxxxxxx \
--instance-type trn1.2xlarge \
--key-name your-key-pair \
--security-group-ids sg-xxxxxxxxx \
--subnet-id subnet-xxxxxxxxx
Replace:
ami-xxxxxxxxxxxxxxxxx with the AMI ID from Step 1
your-key-pair with your SSH key pair name
sg-xxxxxxxxx with your security group ID
subnet-xxxxxxxxx with your subnet ID
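A common mistake is launching with a placeholder still in place. The minimal sketch below builds the same `run-instances` argument list from named values and rejects any value that still looks like one of the example placeholders. The helper `run_instances_command`, the marker strings, and the sample IDs are all hypothetical.

```python
import shlex
from typing import List

# Substrings that suggest a value was copied from the example without editing.
PLACEHOLDER_MARKERS = ("xxxxxxxxx", "your-key-pair")


def run_instances_command(
    image_id: str,
    instance_type: str,
    key_name: str,
    security_group_id: str,
    subnet_id: str,
) -> List[str]:
    """Build the `aws ec2 run-instances` argument list used in Step 2."""
    for value in (image_id, instance_type, key_name, security_group_id, subnet_id):
        if any(marker in value for marker in PLACEHOLDER_MARKERS):
            raise ValueError(f"placeholder not replaced: {value!r}")
    return [
        "aws", "ec2", "run-instances",
        "--image-id", image_id,
        "--instance-type", instance_type,
        "--key-name", key_name,
        "--security-group-ids", security_group_id,
        "--subnet-id", subnet_id,
    ]


# Hypothetical example values; substitute your own.
cmd = run_instances_command(
    image_id="ami-0123456789abcdef0",
    instance_type="trn1.2xlarge",
    key_name="my-neuron-key",
    security_group_id="sg-0123456789abcdef0",
    subnet_id="subnet-0123456789abcdef0",
)
print(shlex.join(cmd))  # shell-quoted command line, ready to copy
```

Keeping the command as a list means it could be passed directly to `subprocess.run(cmd, check=True)` with no shell quoting involved.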
Step 3: Connect to Instance
ssh -i your-key-pair.pem ubuntu@<instance-public-ip>
Step 4: Activate Environment
The DLAMI includes a pre-configured virtual environment:
source /opt/aws_neuronx_venv_jax/bin/activate
Step 5: Verify Installation
python3 << EOF
import jax
import jax_neuronx
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF
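Expected output (the JAX version, number of devices, and memory vary with the Neuron SDK release and instance type):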
JAX version: 0.7.0
Devices: [NeuronDevice(id=0), NeuronDevice(id=1)]
+--------+--------+--------+-----------+
| DEVICE | CORES | MEMORY | CONNECTED |
+--------+--------+--------+-----------+
| 0 | 2 | 32 GB | Yes |
| 1 | 2 | 32 GB | Yes |
+--------+--------+--------+-----------+
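To check the device table from a script instead of by eye, here is a minimal sketch that parses the simplified table shown above. It assumes exactly that four-column layout (DEVICE, CORES, MEMORY, CONNECTED). The columns printed by `neuron-ls` can differ between Neuron SDK releases, so treat the column indices as an assumption and adjust them to your actual output. The helper `parse_device_rows` is hypothetical.

```python
from typing import Dict, List, Union

# The sample table from the expected output above.
SAMPLE_TABLE = """\
+--------+--------+--------+-----------+
| DEVICE | CORES | MEMORY | CONNECTED |
+--------+--------+--------+-----------+
| 0 | 2 | 32 GB | Yes |
| 1 | 2 | 32 GB | Yes |
+--------+--------+--------+-----------+
"""


def parse_device_rows(table: str) -> List[Dict[str, Union[int, str, bool]]]:
    """Return one dict per device row, skipping border lines and the header."""
    rows = []
    for line in table.splitlines():
        if not line.startswith("|"):
            continue  # border lines start with '+'
        cells = [cell.strip() for cell in line.strip("|").split("|")]
        if not cells[0].isdigit():
            continue  # header row
        rows.append({
            "device": int(cells[0]),
            "cores": int(cells[1]),
            "memory": cells[2],
            "connected": cells[3] == "Yes",
        })
    return rows


devices = parse_device_rows(SAMPLE_TABLE)
total_cores = sum(row["cores"] for row in devices)
print(f"{len(devices)} devices, {total_cores} NeuronCores")  # 2 devices, 4 NeuronCores
```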
⚠️ Troubleshooting: Module not found
If you see ModuleNotFoundError: No module named 'jax_neuronx':
Verify virtual environment is activated:
which python # Should show: /opt/aws_neuronx_venv_jax/bin/python
Check Python version:
python --version # Should be 3.10 or higher
Reinstall jax-neuronx:
pip install --force-reinstall jax-neuronx
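The first two checks can be combined into one script that you run with `python` after activating the environment. Below is a minimal sketch. It assumes the venv path from Step 4 and the 3.10 minimum stated above, and it inspects `sys.prefix`, which points at the virtual environment directory whenever that environment's interpreter is the one running. The function name `environment_problems` is hypothetical.

```python
import sys
from pathlib import Path
from typing import List, Sequence

EXPECTED_VENV = Path("/opt/aws_neuronx_venv_jax")  # venv path used in Step 4
MIN_PYTHON = (3, 10)  # minimum version from the troubleshooting step above


def environment_problems(prefix: str, version_info: Sequence[int]) -> List[str]:
    """Return human-readable problems with the interpreter (empty list if none)."""
    problems = []
    if Path(prefix) != EXPECTED_VENV:
        problems.append(
            f"interpreter prefix is {prefix}, expected {EXPECTED_VENV}; "
            f"run: source {EXPECTED_VENV}/bin/activate"
        )
    if tuple(version_info[:2]) < MIN_PYTHON:
        problems.append(
            "Python {}.{} found; 3.10 or higher is expected".format(*version_info[:2])
        )
    return problems


if __name__ == "__main__":
    # Check the interpreter that is running this script.
    for problem in environment_problems(sys.prefix, sys.version_info):
        print(problem)
```

An empty result means both checks pass. Only then is reinstalling `jax-neuronx` the likely fix.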
⚠️ Troubleshooting: No Neuron devices found
If neuron-ls shows no devices:
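Verify instance type (these commands use IMDSv2, which works even when IMDSv1 is disabled on the instance):
TOKEN=$(curl -s -X PUT http://169.254.169.254/latest/api/token -H "X-aws-ec2-metadata-token-ttl-seconds: 21600")
curl -s -H "X-aws-ec2-metadata-token: $TOKEN" http://169.254.169.254/latest/meta-data/instance-type  # Should show trn1.*, trn2.*, trn3.*, or inf2.*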
Check Neuron driver:
lsmod | grep neuron # Should show neuron driver loaded
Restart Neuron runtime:
sudo systemctl restart neuron-monitor
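To check the instance type returned by the metadata query programmatically, here is a small sketch that matches the family prefix against the families in the prerequisites table. Prefix matching also accepts variants such as `trn1n`. That is an assumption; the Neuron documentation has the authoritative list of supported instance types. The helper `is_supported_instance_type` is hypothetical.

```python
from typing import Tuple

# Instance families from the prerequisites table. Prefix matching also
# accepts variants such as trn1n (an assumption, not an official list).
SUPPORTED_PREFIXES: Tuple[str, ...] = ("inf2", "trn1", "trn2", "trn3")


def is_supported_instance_type(instance_type: str) -> bool:
    """Return True if a type such as 'trn1.2xlarge' belongs to a supported family."""
    family = instance_type.strip().lower().split(".", 1)[0]
    return family.startswith(SUPPORTED_PREFIXES)


for name in ("trn1.2xlarge", "inf2.xlarge", "inf1.xlarge", "m5.large"):
    print(name, is_supported_instance_type(name))
```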
Ubuntu 22.04
Step 1: Find the Latest AMI
Get the latest JAX DLAMI for Ubuntu 22.04:
aws ec2 describe-images \
--owners amazon \
--filters "Name=name,Values=Deep Learning AMI Neuron JAX * (Ubuntu 22.04)*" \
--query 'Images | sort_by(@, &CreationDate) | [-1].ImageId' \
--output text
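The three operating-system variants of this guide differ only in the OS label inside the `--filters` name pattern. Below is a minimal sketch that builds the filter string for each label. It assumes the name pattern shown in Step 1 stays the same across DLAMI releases, and `name_filter` is a hypothetical helper.

```python
DLAMI_NAME_PATTERN = "Deep Learning AMI Neuron JAX * ({os_label})*"
OS_LABELS = ("Ubuntu 24.04", "Ubuntu 22.04", "Amazon Linux 2023")


def name_filter(os_label: str) -> str:
    """Return the --filters argument for one of the OS labels in this guide."""
    if os_label not in OS_LABELS:
        raise ValueError(f"unsupported OS label: {os_label!r}")
    return "Name=name,Values=" + DLAMI_NAME_PATTERN.format(os_label=os_label)


for label in OS_LABELS:
    print(name_filter(label))
```

The filter value contains spaces, parentheses, and `*`, so it must stay quoted on the command line, as it is in Step 1.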
Step 2: Launch Instance
aws ec2 run-instances \
--image-id ami-xxxxxxxxxxxxxxxxx \
--instance-type trn1.2xlarge \
--key-name your-key-pair \
--security-group-ids sg-xxxxxxxxx \
--subnet-id subnet-xxxxxxxxx
Step 3: Connect to Instance
ssh -i your-key-pair.pem ubuntu@<instance-public-ip>
Step 4: Activate Environment
source /opt/aws_neuronx_venv_jax/bin/activate
Step 5: Verify Installation
python3 << EOF
import jax
import jax_neuronx
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF
⚠️ Troubleshooting: Module not found
If you see ModuleNotFoundError: No module named 'jax_neuronx':
Verify virtual environment is activated
Check Python version:
python --version  # Should be 3.10 or higher
Reinstall jax-neuronx:
pip install --force-reinstall jax-neuronx
⚠️ Troubleshooting: No Neuron devices found
If neuron-ls shows no devices:
Verify instance type
Check Neuron driver:
lsmod | grep neuron  # Should show neuron driver loaded
Restart Neuron runtime:
sudo systemctl restart neuron-monitor
Amazon Linux 2023
Step 1: Find the Latest AMI
Get the latest JAX DLAMI for Amazon Linux 2023:
aws ec2 describe-images \
--owners amazon \
--filters "Name=name,Values=Deep Learning AMI Neuron JAX * (Amazon Linux 2023)*" \
--query 'Images | sort_by(@, &CreationDate) | [-1].ImageId' \
--output text
Step 2: Launch Instance
aws ec2 run-instances \
--image-id ami-xxxxxxxxxxxxxxxxx \
--instance-type trn1.2xlarge \
--key-name your-key-pair \
--security-group-ids sg-xxxxxxxxx \
--subnet-id subnet-xxxxxxxxx
Step 3: Connect to Instance
ssh -i your-key-pair.pem ec2-user@<instance-public-ip>
Note
Amazon Linux 2023 uses ec2-user instead of ubuntu.
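If you script the connection step across these DLAMI variants, the login user is the only part of the `ssh` command that changes. Below is a small sketch that maps each OS label in this guide to its default user. The key file name and IP address are placeholders (203.0.113.10 is from the documentation-reserved range), and `ssh_command` is a hypothetical helper.

```python
from typing import List

# Default login users for the DLAMI variants in this guide.
DEFAULT_SSH_USER = {
    "Ubuntu 24.04": "ubuntu",
    "Ubuntu 22.04": "ubuntu",
    "Amazon Linux 2023": "ec2-user",
}


def ssh_command(os_label: str, key_file: str, host: str) -> List[str]:
    """Build the Step 3 ssh argument list with the correct default user."""
    user = DEFAULT_SSH_USER[os_label]  # raises KeyError for an unknown OS label
    return ["ssh", "-i", key_file, f"{user}@{host}"]


print(" ".join(ssh_command("Amazon Linux 2023", "your-key-pair.pem", "203.0.113.10")))
# ssh -i your-key-pair.pem ec2-user@203.0.113.10
```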
Step 4: Activate Environment
source /opt/aws_neuronx_venv_jax/bin/activate
Step 5: Verify Installation
python3 << EOF
import jax
import jax_neuronx
print(f"JAX version: {jax.__version__}")
print(f"Devices: {jax.devices()}")
# Check Neuron devices
import subprocess
result = subprocess.run(['neuron-ls'], capture_output=True, text=True)
print(result.stdout)
EOF
⚠️ Troubleshooting: Module not found
If you see ModuleNotFoundError: No module named 'jax_neuronx':
Verify virtual environment is activated
Check Python version:
python --version  # Should be 3.10 or higher
Reinstall jax-neuronx:
pip install --force-reinstall jax-neuronx
⚠️ Troubleshooting: No Neuron devices found
If neuron-ls shows no devices:
Verify instance type
Check Neuron driver:
lsmod | grep neuron  # Should show neuron driver loaded
Restart Neuron runtime:
sudo systemctl restart neuron-monitor
Next Steps
Now that JAX is installed:
Try a Quick Example:
import jax
import jax.numpy as jnp

# Simple element-wise operation on the default (Neuron) device
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
result = jnp.multiply(x, y)
print(result)
Read Documentation:
Explore Setup Guide:
Additional Resources
Neuron DLAMI User Guide - DLAMI documentation
Neuron Containers - Container-based deployment
Installation Troubleshooting - Common issues and solutions
AWS Neuron SDK Release Notes - Version compatibility information