This document is relevant for: Inf2, Trn1, Trn2, Trn3

JAX Support on Neuron

Running JAX on Neuron enables high-performance, cost-effective deep learning acceleration on Amazon EC2 instances based on AWS Trainium and AWS Inferentia.

The JAX NeuronX plugin is a set of modular packages that integrate the AWS Trainium and Inferentia machine learning accelerators into JAX as pluggable devices through the PJRT (Plugin Runtime) mechanism. This gives Neuron accelerators native JAX device support, so existing JAX code can run on them with minimal changes.
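Because the plugin exposes Neuron accelerators as ordinary PJRT devices, device discovery and compilation should go through the standard JAX APIs rather than Neuron-specific calls. The sketch below uses only core JAX (`jax.devices`, `jax.device_put`, `jax.jit`). It assumes that, with JAX NeuronX installed, the Neuron devices are what `jax.devices()` returns by default; without the plugin it simply runs on whatever backend is available, so the platform name it prints is not guaranteed.

```python
import jax
import jax.numpy as jnp


def describe_devices():
    """Return (platform name, device count) for JAX's default backend."""
    devices = jax.devices()  # devices of the default backend
    return devices[0].platform, len(devices)


@jax.jit
def scaled_sum(x, y):
    # XLA compiles this for the backend that holds the inputs.
    return jnp.sum(2.0 * x + y)


platform, count = describe_devices()
print(f"default backend: {platform}, {count} device(s)")

# Place inputs explicitly on the first default device, then run the jitted function.
device = jax.devices()[0]
x = jax.device_put(jnp.arange(4.0), device)
y = jax.device_put(jnp.ones(4), device)
result = scaled_sum(x, y)
print(float(result))
```

Nothing in the function body refers to Neuron. Which hardware executes it is decided by the installed PJRT plugin, not by the model code.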

JAX NeuronX includes the following key components:

  • libneuronxla: Neuron's integration with PJRT, JAX's runtime interface, built using the PJRT C-API plugin mechanism. Installing this package makes Trainium and Inferentia available natively as JAX devices.

  • jax-neuronx: A package containing Neuron-specific JAX features, such as the JAX interface for the Neuron Kernel Interface (NKI). It also serves as a meta-package that installs a tested combination of jax, jaxlib, libneuronxla, and neuronx-cc alongside jax-neuronx.
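The meta-package exists to pin a tested combination, so it can help to check which versions actually ended up in an environment, particularly after installing a custom combination. The following is a minimal sketch that uses only the standard library's `importlib.metadata`. The distribution names come from the component list above, and the helper name `installed_versions` is hypothetical. The script only reports versions. It cannot tell whether a combination is a tested one, so compare its output against the release notes.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the JAX NeuronX component list.
JAX_NEURONX_PACKAGES = ("jax-neuronx", "jax", "jaxlib", "libneuronxla", "neuronx-cc")


def installed_versions(packages=JAX_NEURONX_PACKAGES):
    """Hypothetical helper: map each distribution name to its version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


for name, ver in installed_versions().items():
    print(f"{name:<14} {ver or 'not installed'}")
```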

Key capabilities of JAX NeuronX include:

  • Native JAX device integration: Seamless integration with JAX through the PJRT C-API plugin mechanism

  • Flexible installation: Install the tested meta-package or a custom combination of packages

  • NKI support: Access to Neuron Kernel Interface (NKI) through the JAX interface for custom kernel development

  • Broad compatibility: Support for multiple JAX and jaxlib versions through the PJRT C-API mechanism

  • Training and inference: Support for both training and inference workloads on Trainium and Inferentia instances
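Training workloads are written with the usual JAX transformations. The following is a generic sketch, not Neuron-specific code: a toy linear-regression training loop on synthetic data using `jax.value_and_grad`, `jax.tree_util.tree_map`, and `jax.jit`. It runs on JAX's default device. The assumption is that with JAX NeuronX installed that default is a Neuron device, so the jitted step would be compiled for it. The model, data, and learning rate are illustrative only.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative value for this toy problem


def loss_fn(params, x, y):
    # Mean squared error of a linear model: pred = x @ w + b.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    # One step of plain gradient descent over the parameter pytree.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Synthetic data generated from known weights and bias.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w + 0.3

params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
losses = []
for _ in range(200):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.6f}")
```

The first call to `train_step` triggers compilation. Later calls with the same input shapes and dtypes reuse the compiled program, which is why the loop keeps the batch shape fixed.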

Beta Release

JAX NeuronX is currently in beta. Some JAX functionality may not be fully supported. We welcome your feedback and contributions.

JAX NeuronX Component Release Notes

Review the JAX NeuronX release notes for all versions of the Neuron SDK.

Setup Guide

Install and configure JAX NeuronX for Trn1, Trn2, and Inf2 instances

API Reference Guide

Comprehensive API reference for JAX NeuronX features and environment variables

Known Issues

Review known issues and limitations in the current JAX NeuronX release

Neuron Kernel Interface (NKI)

Learn about NKI for custom kernel development with JAX
