JAX Setup
This guide provides step-by-step instructions for installing and configuring JAX with the NeuronX plugin on AWS Trainium and Inferentia instances. The JAX NeuronX plugin connects JAX to AWS’s custom ML accelerators, so JAX machine learning workloads can run on Trainium and Inferentia hardware.
For more installation and deployment options, see JAX NeuronX plugin Setup.
Note
This setup guide applies to Inf2, Trn1, Trn1n, and Trn2 instances.