This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3
Install JAX for Neuron
Install JAX with AWS Neuron support for training and inference on Inferentia and Trainium instances.
Supported Instances: Inf2, Trn1, Trn2, Trn3
JAX Version: 0.7+ with Neuron PJRT plugin
Beta Release
JAX NeuronX is currently in beta. Some JAX functionality may not be fully supported. We welcome your feedback and contributions.
Prerequisites
Before installing, ensure you have:

| Requirement | Details |
|---|---|
| Instance Type | Inf2, Trn1, Trn2, or Trn3 instance |
| Operating System | Ubuntu 24.04, Ubuntu 22.04, or Amazon Linux 2023 |
| Python Version | Python 3.10, 3.11, or 3.12 |
| AWS Account | With EC2 launch permissions |
| SSH Access | Key pair for instance connection |
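The Python and operating system rows of this table can be checked from a shell on the instance. The following is a minimal sketch using only the Python standard library: it reads `/etc/os-release` and compares `ID`/`VERSION_ID` against the values we assume these distributions report (`ubuntu` with `22.04` or `24.04`, `amzn` with `2023`). It does not check the instance type, account permissions, or SSH access, and the helper names are illustrative, not part of any Neuron tool.

```python
import sys

# Python versions come from the prerequisites table. The os-release IDs
# ("ubuntu", "amzn") are assumptions about what these distributions report.
SUPPORTED_PYTHON = {(3, 10), (3, 11), (3, 12)}
SUPPORTED_OS = {("ubuntu", "22.04"), ("ubuntu", "24.04"), ("amzn", "2023")}


def parse_os_release(text):
    """Parse KEY=value lines from /etc/os-release into a dict (quotes stripped)."""
    fields = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        fields[key] = value.strip().strip("\"'")
    return fields


def check_prerequisites(python_version, os_release_text):
    """Return a list of problems; an empty list means both checks passed."""
    problems = []
    major, minor = python_version[0], python_version[1]
    if (major, minor) not in SUPPORTED_PYTHON:
        problems.append(f"Python {major}.{minor} is not 3.10, 3.11, or 3.12")
    fields = parse_os_release(os_release_text)
    os_id, os_version = fields.get("ID", ""), fields.get("VERSION_ID", "")
    if (os_id, os_version) not in SUPPORTED_OS:
        problems.append(
            f"OS {os_id} {os_version} is not Ubuntu 22.04/24.04 or Amazon Linux 2023"
        )
    return problems


if __name__ == "__main__":
    try:
        with open("/etc/os-release", encoding="utf-8") as f:
            os_release = f.read()
    except OSError:
        os_release = ""  # No os-release file: the OS check will report a problem.
    issues = check_prerequisites(sys.version_info, os_release)
    print("\n".join(issues) if issues else "Python and OS match the prerequisites table.")
```

Passing the version and file contents in as arguments, rather than reading them inside `check_prerequisites`, keeps the check easy to run against sample inputs.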
What You’ll Get
After installation, you’ll have:

- JAX 0.7+ with the Neuron PJRT plugin
- jax-neuronx, a package for Neuron-specific features
- libneuronxla, the PJRT plugin that integrates Neuron devices natively with JAX
- neuronx-cc, the Neuron compiler, which compiles JAX computations for Neuron devices
- Neuron Runtime, which executes compiled models on Neuron devices
Version Information

| Component | Version |
|---|---|
| JAX | 0.7.0+ |
| jax-neuronx | 0.7.0+ |
| libneuronxla | latest |
| neuronx-cc | 2.15.0+ |
| Python | 3.10, 3.11, 3.12 |
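The minimum versions in this table can be compared against what is installed in the active environment. This is a minimal, standard-library-only sketch that reads installed versions with `importlib.metadata.version`. It assumes the pip distribution names match the component names shown (`jax`, `jax-neuronx`, `libneuronxla`, `neuronx-cc`). The version parsing is deliberately simple: it ignores pre-release and local suffixes (so `0.7.0rc1` counts as `0.7.0`), and `packaging.version` is the more rigorous option if it is available. libneuronxla is only checked for presence because the table lists it as `latest`; Python is covered by the prerequisites table.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Minimums from the version table; None means "check presence only" (listed as "latest").
MINIMUMS = {
    "jax": (0, 7, 0),
    "jax-neuronx": (0, 7, 0),
    "libneuronxla": None,
    "neuronx-cc": (2, 15, 0),
}


def numeric_prefix(ver):
    """'2.15.128.0+abc' -> (2, 15, 128, 0); drops local and pre-release suffixes."""
    parts = []
    for piece in ver.split("+", 1)[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):  # e.g. "0rc1": keep 0, stop at the suffix
            break
    return tuple(parts)


def meets(found, minimum):
    """Compare version tuples after zero-padding them to the same length."""
    width = max(len(found), len(minimum))
    return found + (0,) * (width - len(found)) >= minimum + (0,) * (width - len(minimum))


def check_versions():
    """Map each distribution name to a short status string."""
    report = {}
    for dist, minimum in MINIMUMS.items():
        try:
            found = version(dist)
        except PackageNotFoundError:
            report[dist] = "not installed"
            continue
        if minimum is None or meets(numeric_prefix(found), minimum):
            report[dist] = f"ok ({found})"
        else:
            wanted = ".".join(map(str, minimum))
            report[dist] = f"older than {wanted} ({found})"
    return report


if __name__ == "__main__":
    for dist, status in check_versions().items():
        print(f"{dist:14} {status}")
```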
Next Steps
After installation:

- Verify Installation: Run verification commands in the installation guide
- Read the Guide: JAX NeuronX plugin Setup
- Explore JAX on Neuron: JAX Support on Neuron
- API Reference: API Reference Guide for JAX Neuronx
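The verification commands in the installation guide are the authoritative check. As a quick complement, the minimal sketch below uses only public JAX APIs (`jax.devices`, `jax.default_backend`, `jax.jit`) to list the visible devices and run one compiled computation. It does not assert a particular platform name for Neuron devices; on a working install you would expect the listed devices to be Neuron devices rather than CPU. The first call to a jitted function triggers compilation, which on Neuron goes through the compiler and can take noticeably longer than later calls. If JAX is not importable, the function returns `None`; `run_smoke_test` is an illustrative name.

```python
import importlib.util


def run_smoke_test():
    """List JAX devices and run one jitted op; return None if JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax
    import jax.numpy as jnp

    @jax.jit
    def add_one(x):
        return x + 1.0

    # The first call compiles add_one for the default backend.
    result = add_one(jnp.arange(4, dtype=jnp.float32))
    return jax.default_backend(), [str(d) for d in jax.devices()], result.tolist()


if __name__ == "__main__":
    outcome = run_smoke_test()
    if outcome is None:
        print("JAX is not installed in this environment.")
    else:
        backend, devices, values = outcome
        print("Default backend:", backend)
        print("Devices:", devices)
        print("add_one(arange(4)) =", values)
```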