JAX Setup

This guide provides step-by-step instructions for installing and configuring JAX with the NeuronX plugin on AWS Trainium and Inferentia instances. JAX NeuronX enables high-performance machine learning workloads by integrating JAX with AWS’s custom ML accelerators.

For more installation and deployment options, see JAX NeuronX plugin Setup.

Note

This setup guide applies to Inf2, Trn1, Trn1n, and Trn2 instances.
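After installation, you can run a quick sanity check to confirm that JAX imports cleanly and to see which backend and devices it picked up. The sketch below is illustrative and not part of the official setup steps. It uses only core JAX calls (`jax.default_backend`, `jax.devices`, `jax.jit`). The helper name `check_jax_backend` is hypothetical. This guide does not specify the exact backend or device names that the NeuronX plugin reports, so the script prints them rather than asserting particular values.

```python
import jax
import jax.numpy as jnp


def check_jax_backend():
    """Print the active JAX backend and devices, then run a small jitted op.

    Hypothetical helper for illustration; not part of the JAX NeuronX plugin.
    """
    backend = jax.default_backend()  # name of the backend JAX selected by default
    devices = jax.devices()          # devices visible to that backend
    print(f"Default backend: {backend}")
    for d in devices:
        print(f"  {d}")

    @jax.jit
    def add_and_scale(x, y):
        # A trivial computation, compiled for the default backend.
        return (x + y) * 2.0

    x = jnp.arange(4, dtype=jnp.float32)
    y = jnp.ones(4, dtype=jnp.float32)
    result = add_and_scale(x, y)
    print("Result:", result)
    return backend, result


if __name__ == "__main__":
    check_jax_backend()
```

If the device list shows only CPU devices on an Inf2 or Trn instance, the NeuronX plugin was probably not picked up. In that case, recheck the installation steps for your operating system.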

JAX setup on Ubuntu 22

Ubuntu 22 (Ubuntu22 AMI)