This document is relevant for: Inf2, Trn1, Trn2

JAX Neuron plugin Setup

The JAX Neuron plugin is a set of modularized JAX plugin packages integrating AWS Trainium and Inferentia machine learning accelerators into JAX as pluggable devices. It includes the following Python packages, all hosted on the AWS Neuron pip repository.

  • libneuronxla: A package containing Neuron’s integration with PJRT, JAX’s runtime interface, built using the PJRT C-API plugin mechanism. Installing this package lets you use Trainium and Inferentia natively as JAX devices.

  • jax-neuronx: A package containing Neuron-specific JAX features, such as the Neuron NKI JAX interface. It also serves as a meta-package that provides a tested combination of the jax-neuronx, jax, jaxlib, libneuronxla, and neuronx-cc packages. Using the features provided in jax-neuronx lets you take full advantage of Trainium and Inferentia.
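
JAX discovers PJRT plugins at import time through the jax_plugins namespace package and the jax_plugins entry-point group. As a rough check, assuming libneuronxla registers itself through that entry-point group (this page does not confirm it), the sketch below lists whatever the group advertises in the current Python environment. The function name jax_plugin_entry_points is hypothetical, and an empty result does not by itself mean the Neuron plugin is missing, because a plugin may use the namespace package instead.

```python
from importlib.metadata import entry_points


def jax_plugin_entry_points():
    """Return sorted (name, target) pairs registered under the 'jax_plugins' group.

    Hypothetical helper: it only reports entry points, not namespace-package plugins.
    """
    eps = entry_points()
    if hasattr(eps, "select"):
        # Python 3.10+: EntryPoints / SelectableGroups support .select()
        selected = eps.select(group="jax_plugins")
    else:
        # Python 3.8/3.9: entry_points() returns a dict of group -> list
        selected = eps.get("jax_plugins", [])
    return sorted((ep.name, ep.value) for ep in selected)


if __name__ == "__main__":
    plugins = jax_plugin_entry_points()
    if not plugins:
        print("no entry points registered under 'jax_plugins'")
    for name, target in plugins:
        print(f"{name} -> {target}")
```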

Note

If you encounter a connectivity issue during model loading on a Trn1 instance running Ubuntu, it is likely caused by Ubuntu limitations with multiple network interfaces. To resolve it, follow the steps described here.

We recommend launching instances from a DLAMI (AWS Deep Learning AMI), since DLAMIs include the required fix.

Launch the Instance
Install Drivers and Tools

Ubuntu

Amazon Linux 2023

Install the JAX Neuron Plugin

We provide two methods for installing the JAX Neuron plugin. The first is to install the jax-neuronx meta-package from the AWS Neuron pip repository. This method provides a production-ready JAX environment in which jax-neuronx’s major dependencies (jax, jaxlib, libneuronxla, and neuronx-cc) have been thoroughly tested by the AWS Neuron team and are pinned to specific versions at install time.

python3 -m pip install "jax-neuronx[stable]" --extra-index-url=https://pip.repos.neuron.amazonaws.com

The second is to install the jax, jaxlib, libneuronxla, and neuronx-cc packages separately, with jax-neuronx as an optional addition. Because libneuronxla supports a broad range of jaxlib versions through the PJRT C-API mechanism, this method gives you flexibility in choosing jax and jaxlib versions, so you can bring the JAX Neuron plugin into your own JAX environment.

python3 -m pip install jax==0.4.31 jaxlib==0.4.31 jax-neuronx libneuronxla "neuronx-cc==2.*" --extra-index-url=https://pip.repos.neuron.amazonaws.com
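
To see which versions pip actually resolved with either method (for example, the pins applied by jax-neuronx[stable]), a minimal sketch like the following reads installed package metadata with the standard-library importlib.metadata module. The helper name installed_versions is hypothetical; if you skipped jax-neuronx in the second method, it is simply reported as not installed.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as used in the pip commands above.
PACKAGES = ["jax", "jaxlib", "jax-neuronx", "libneuronxla", "neuronx-cc"]


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:14s} {ver or 'not installed'}")
```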

We can now run some simple JAX programs on the Trainium or Inferentia accelerators.

~$ python3 -c 'import jax; print(jax.numpy.multiply(1, 1))'
Platform 'neuron' is experimental and not all JAX functionality may be correctly supported!
.
Compiler status PASS
1
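
A slightly larger smoke test is sketched below: it prints the default backend and visible devices, then runs a jit-compiled function. The log above names the platform 'neuron', so on a Trainium or Inferentia instance the backend is expected to be reported as "neuron" (an assumption, not verified here); the same script also runs on a plain CPU installation of JAX, where the backend is "cpu". The function name scale_and_add is illustrative only.

```python
import jax
import jax.numpy as jnp

# Which backend JAX selected and which devices it can see.
print("default backend:", jax.default_backend())
print("devices:", jax.devices())


@jax.jit
def scale_and_add(x, y):
    # Compiled once by the active backend, then executed on its default device.
    return 2.0 * x + y


x = jnp.arange(4.0)  # [0., 1., 2., 3.]
y = jnp.ones(4)      # [1., 1., 1., 1.]
result = scale_and_add(x, y)
print(result)        # [1. 3. 5. 7.]
```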

Compatibility between the jaxlib and libneuronxla packages can be determined from their PJRT C-API versions. For more information, see the PJRT integration guide.

To determine compatible JAX versions, call libneuronxla.supported_clients() to query the known supported client packages and their versions.

Help on function supported_clients in module libneuronxla.version:

supported_clients()
    Return a description of supported client (jaxlib, torch-xla, etc.) versions,
    as a list of strings formatted as `"<package> <version> (PJRT C-API <c-api version>)"`.
    For example,
    >>> import libneuronxla
    >>> libneuronxla.supported_clients()
    ['jaxlib 0.4.31 (PJRT C-API 0.54)', 'torch_xla 2.2.0 (PJRT C-API 0.35)', 'torch_xla 2.3.0 (PJRT C-API 0.46)']

Note that the list of supported client packages and versions covers known versions only and may be incomplete. Other versions may also work, including Google’s future jaxlib releases, as long as the PJRT C-API stays compatible with the installed release of libneuronxla. For this reason, libneuronxla does not declare a dependency on jaxlib, which gives you more freedom when coordinating jax and libneuronxla installations.
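
To check a specific jaxlib version against that list programmatically, the sketch below parses entries in the documented "<package> <version> (PJRT C-API <c-api version>)" format. The helper names parse_supported_clients and is_known_supported are hypothetical, and the sample list is copied from the supported_clients() docstring; on an instance with libneuronxla installed you would pass libneuronxla.supported_clients() instead. In line with the note on incomplete lists, a False result only means the version is not among the known entries, not that it is incompatible.

```python
import re

# Matches the documented entry format, e.g. "jaxlib 0.4.31 (PJRT C-API 0.54)".
_ENTRY = re.compile(r"^(?P<package>\S+) (?P<version>\S+) \(PJRT C-API (?P<capi>[^)]+)\)$")


def parse_supported_clients(entries):
    """Turn supported_clients() strings into (package, version, c_api) tuples."""
    parsed = []
    for entry in entries:
        match = _ENTRY.match(entry.strip())
        if match is None:
            raise ValueError(f"unexpected entry format: {entry!r}")
        parsed.append((match["package"], match["version"], match["capi"]))
    return parsed


def is_known_supported(entries, package, pkg_version):
    """True if (package, version) is listed; False means 'not known', not 'incompatible'."""
    return any(p == package and v == pkg_version
               for p, v, _ in parse_supported_clients(entries))


if __name__ == "__main__":
    # Sample copied from the supported_clients() docstring.
    sample = ['jaxlib 0.4.31 (PJRT C-API 0.54)',
              'torch_xla 2.2.0 (PJRT C-API 0.35)',
              'torch_xla 2.3.0 (PJRT C-API 0.46)']
    print(parse_supported_clients(sample))
    print(is_known_supported(sample, "jaxlib", "0.4.31"))
    print(is_known_supported(sample, "jaxlib", "0.4.35"))
```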
