This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3

Install JAX for Neuron

Install JAX with AWS Neuron support for training and inference on Inferentia and Trainium instances.

Supported Instances: Inf2, Trn1, Trn2, Trn3

JAX Version: 0.7+ with Neuron PJRT plugin

Beta Release

JAX NeuronX is currently in beta. Some JAX functionality may not be fully supported. We welcome your feedback and contributions.

Choose an Installation Method

🚀 AWS Deep Learning AMI

Recommended for most users

Pre-configured environment with all dependencies

✅ All dependencies included

✅ Tested configurations

✅ Multiple Python versions

⏱️ Setup time: ~5 minutes

🐳 Deep Learning Container

For containerized deployments

Pre-configured Docker images from Amazon Elastic Container Registry (Amazon ECR)

✅ Docker-based isolation

✅ Training and inference images

⏱️ Setup time: ~10 minutes

🔧 Manual Installation

For custom environments

Install on existing systems or custom setups

✅ Existing system integration

✅ Custom Python versions

✅ Full control over dependencies

⏱️ Setup time: ~15 minutes

Prerequisites

Before installing, ensure you have:

Instance Type: An Inf2, Trn1, Trn2, or Trn3 instance

Operating System: Ubuntu 24.04, Ubuntu 22.04, or Amazon Linux 2023

Python Version: Python 3.10, 3.11, or 3.12

AWS Account: An account with EC2 launch permissions

SSH Access: A key pair for connecting to the instance
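The Python and operating-system requirements above can be checked on the instance before you install anything. The following is a minimal sketch, assuming the distribution is identified through the standard /etc/os-release file (Amazon Linux 2023 reports ID "amzn"); the helper names are hypothetical and not part of any Neuron tooling, and the instance-type requirement is not checked.

```python
import sys

# Supported values copied from the Prerequisites list on this page.
SUPPORTED_PYTHON = {(3, 10), (3, 11), (3, 12)}
SUPPORTED_OS = {("ubuntu", "22.04"), ("ubuntu", "24.04"), ("amzn", "2023")}


def parse_os_release(text: str) -> dict:
    """Parse the KEY=value lines of an /etc/os-release file into a dict."""
    fields = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        fields[key] = value.strip().strip('"')
    return fields


def check_prerequisites(python_version=None, os_release_text=None) -> list:
    """Return a list of problems; an empty list means both checks passed.

    Both arguments are optional overrides (useful for testing); by default
    the running interpreter and /etc/os-release are inspected.
    """
    problems = []
    major_minor = tuple((python_version or sys.version_info)[:2])
    if major_minor not in SUPPORTED_PYTHON:
        problems.append(f"Unsupported Python {major_minor[0]}.{major_minor[1]}")

    if os_release_text is None:
        try:
            with open("/etc/os-release") as f:
                os_release_text = f.read()
        except OSError:
            os_release_text = ""  # Not a typical Linux system.
    fields = parse_os_release(os_release_text)
    os_id = (fields.get("ID", ""), fields.get("VERSION_ID", ""))
    if os_id not in SUPPORTED_OS:
        problems.append(f"Unsupported OS {os_id[0] or 'unknown'} {os_id[1]}".strip())
    return problems


if __name__ == "__main__":
    print(check_prerequisites() or "All prerequisite checks passed")
```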

What You’ll Get

After installation, you’ll have:

  • JAX 0.7+ with Neuron PJRT plugin

  • jax-neuronx package for Neuron-specific features

  • libneuronxla PJRT plugin for native JAX device integration

  • neuronx-cc compiler, which compiles JAX (XLA) computations for execution on NeuronCores

  • Neuron Runtime for model execution
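To see whether the PJRT plugin has registered with JAX, you can ask JAX which default backend and devices it found. This is a minimal sketch; the platform string and device names that JAX reports for Neuron hardware are not documented on this page, so the code prints them rather than checking for a specific value. On a machine where no accelerator plugin loads, JAX falls back to its CPU backend.

```python
def describe_jax_devices():
    """Return (default backend name, device descriptions), or None if JAX is absent."""
    try:
        import jax
    except ImportError:
        return None
    # default_backend() names the platform JAX will use for computations;
    # devices() lists the devices on that platform.
    return jax.default_backend(), [str(d) for d in jax.devices()]


if __name__ == "__main__":
    info = describe_jax_devices()
    if info is None:
        print("JAX is not installed in this environment")
    else:
        backend, devices = info
        print(f"Default backend: {backend}")
        for device in devices:
            print(f"  {device}")
```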

Version Information

JAX: 0.7.0+

jax-neuronx: 0.7.0+

libneuronxla: Latest release

neuronx-cc: 2.15.0+

Python: 3.10, 3.11, 3.12
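The minimum versions above can be compared against what is actually installed in your environment. This is a minimal sketch using only the standard library, assuming the PyPI distribution names match the component names listed here; it reads package metadata instead of importing the packages, so it also runs on machines without Neuron devices. The helper names are hypothetical.

```python
import re
from importlib import metadata

# Minimum versions from the Version Information list on this page.
# libneuronxla is listed only as "latest", so it gets a presence check.
MINIMUMS = {
    "jax": (0, 7, 0),
    "jax-neuronx": (0, 7, 0),
    "neuronx-cc": (2, 15, 0),
    "libneuronxla": None,
}


def leading_release(version: str) -> tuple:
    """Extract the leading numeric release, e.g. '0.7.0rc1' -> (0, 7, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    return tuple(int(p) for p in match.group(0).split(".")) if match else ()


def installed_version(name: str):
    """Return the installed distribution version, or None if it is absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def check_versions(installed=None) -> dict:
    """Compare versions against MINIMUMS.

    `installed` optionally maps distribution names to version strings
    (handy for testing); by default the current environment is queried.
    """
    results = {}
    for name, minimum in MINIMUMS.items():
        version = installed.get(name) if installed is not None else installed_version(name)
        if version is None:
            results[name] = "missing"
        elif minimum is None:
            results[name] = f"installed ({version})"
        else:
            release = leading_release(version)
            # Pad short versions such as "0.7" so they compare like "0.7.0".
            padded = release + (0,) * (len(minimum) - len(release))
            results[name] = "ok" if padded >= minimum else f"too old ({version})"
    return results


if __name__ == "__main__":
    for package, status in check_versions().items():
        print(f"{package:>14}: {status}")
```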

Next Steps

After installation:

  1. Verify Installation: Run the verification commands in the installation guide

  2. Read the Guide: JAX NeuronX plugin Setup

  3. Explore JAX on Neuron: JAX Support on Neuron

  4. API Reference: API Reference Guide for JAX Neuronx
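In addition to the commands in the installation guide, a quick end-to-end check is to jit-compile a small function and compare its result with NumPy. This is a hypothetical smoke test, not the verification procedure from the guide; the loose tolerance is an assumption meant to allow for reduced-precision arithmetic on the accelerator, and the weights are scaled by 1/sqrt(size) so the outputs stay in a range where that tolerance is meaningful.

```python
import importlib.util


def jax_available() -> bool:
    """True if the jax package can be imported in this environment."""
    return importlib.util.find_spec("jax") is not None


def smoke_test(size: int = 128) -> bool:
    """Return True if a jitted matmul + tanh on JAX's default device matches NumPy."""
    import jax
    import jax.numpy as jnp
    import numpy as np  # Installed as a JAX dependency.

    @jax.jit
    def step(x, w):
        return jnp.tanh(x @ w)

    key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key_x, (size, size), dtype=jnp.float32)
    # Scale the weights so x @ w has entries of roughly unit magnitude.
    w = jax.random.normal(key_w, (size, size), dtype=jnp.float32) * (size ** -0.5)

    result = np.asarray(step(x, w))  # First call compiles, then runs.
    expected = np.tanh(np.asarray(x) @ np.asarray(w))
    return bool(np.allclose(result, expected, rtol=1e-2, atol=1e-2))


if __name__ == "__main__":
    if not jax_available():
        print("JAX is not installed in this environment")
    else:
        print("PASS" if smoke_test() else "FAIL")
```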
