This document is relevant for: Inf1, Inf2, Trn1, Trn2
AWS Neuron SDK 2.26.0: JAX support release notes
Date of release: September 18, 2025
Released versions
0.6.2.1.0.*
Improvements
This release introduces support for JAX version 0.6.2.
Known issues
The Threefry RNG algorithm is not completely supported. Use the rbg algorithm instead by setting the following config option:
    jax.config.update("jax_default_prng_impl", "rbg")
For JAX versions older than 0.4.34, caching does not work out of the box. Use this code to enable caching support:
    import jax
    import jax_neuronx
    from jax._src import compilation_cache
    compilation_cache.set_cache_dir('./cache_directory')
For JAX versions older than 0.4.34, buffer donation does not work out of the box. Add the following snippet to your script to enable it:
    import jax
    jax._src.interpreters.mlir._platforms_with_donation.append('neuron')
Mesh configurations that use non-connected Neuron cores may crash during execution, and you may observe compilation or Neuron runtime errors for such configurations. To determine device connectivity, run neuron-ls --topology.
Not all dtypes supported by JAX work on Neuron. See Data Types for the supported data types.
jax.random.randint does not produce the expected distribution of values. Run it on CPU instead.
Dynamic loops are not supported for jax.lax.while_loop. Only static while loops are supported.
jax.lax.cond is not supported.
Host callbacks are not supported. As a result, callback-based APIs from jax.debug and jax.experimental.checkify are not supported.
jax.dlpack is not supported.
jax.experimental.sparse is not supported.
jax.lax.sort only supports comparators with LE, GE, LT, and GT operations.
jax.lax.reduce_precision is not supported.
Certain operations (for example, RNG weight initialization) might result in slow compilations. Try running such operations on the CPU backend, or set the following environment variable:
    NEURON_RUN_TRIVIAL_COMPUTATION_ON_CPU=1
Neuron only supports float8_e4m3 and float8_e5m2 for FP8 dtypes.
Complex dtypes (jnp.complex64 and jnp.complex128) are not supported.
Variadic reductions are not supported.
Out-of-bounds access in scatter/gather operations can result in runtime errors.
Dot operations on int dtypes are not supported.
lax.DotAlgorithmPreset, which can be passed to jax.lax.dot and jax.lax.dot_general through their precision parameter, is not always respected: dot operations are performed in the operand dtypes.
Previous release notes