This document is relevant for: Inf1, Inf2, Trn1, Trn2
AWS Neuron SDK 2.26.0: JAX support release notes
Date of release: September 18, 2025
Released versions
- 0.6.2.1.0.*
Improvements
- This release introduces support for JAX version 0.6.2.
Known issues
- The Threefry RNG algorithm is not completely supported. Use the rbg algorithm instead. This can be configured by setting the following config option: jax.config.update("jax_default_prng_impl", "rbg")
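  For example, a minimal sketch of applying this setting at program start (the seed and shape below are illustrative):

    import jax

    # Switch the default PRNG implementation from Threefry to rbg.
    jax.config.update("jax_default_prng_impl", "rbg")

    key = jax.random.PRNGKey(0)
    sample = jax.random.normal(key, (4,))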
- For JAX versions older than 0.4.34, caching does not work out of the box. Use this code to enable caching support:

    import jax
    import jax_neuronx
    from jax._src import compilation_cache

    compilation_cache.set_cache_dir('./cache_directory')
- For JAX versions older than 0.4.34, buffer donation does not work out of the box. Add the following snippet to your script to enable it:

    jax._src.interpreters.mlir._platforms_with_donation.append('neuron')
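  Once the platform is registered, donation can be requested through jit as usual. A sketch, assuming JAX older than 0.4.34 (the update function below is illustrative):

    import functools
    import jax

    # Enable buffer donation on the neuron platform (JAX < 0.4.34 only).
    jax._src.interpreters.mlir._platforms_with_donation.append('neuron')

    @functools.partial(jax.jit, donate_argnums=(0,))
    def sgd_step(params, grads):
        # The donated params buffer may be reused for the output.
        return params - 0.01 * grads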
- Mesh configurations that use non-connected Neuron cores may crash during execution. You may observe compilation or Neuron runtime errors for such configurations. Device connectivity can be determined by using neuron-ls --topology.
- Not all dtypes supported by JAX work on Neuron. See Data Types for the list of supported data types.
- jax.random.randint does not produce the expected distribution of values. Run it on the CPU backend instead.
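  A sketch of the CPU workaround (the shape and bounds below are illustrative):

    import jax

    # Generate the integers on the CPU backend; move them to the Neuron
    # device afterwards if needed.
    with jax.default_device(jax.devices("cpu")[0]):
        key = jax.random.PRNGKey(0)
        ints = jax.random.randint(key, (8,), 0, 100)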
- Dynamic loops are not supported for jax.lax.while_loop. Only static while loops are supported.
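  One possible rewrite is to give the loop a fixed trip count, for example with jax.lax.fori_loop (the bound of 10 below is illustrative):

    import jax
    import jax.numpy as jnp

    def body(i, x):
        return x * 0.5

    # The trip count is a static Python integer, so no dynamic loop is needed.
    y = jax.lax.fori_loop(0, 10, body, jnp.ones(()))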
- jax.lax.cond is not supported.
- Host callbacks are not supported. As a result, APIs based on callbacks from jax.debug and jax.experimental.checkify are not supported.
- jax.dlpack is not supported.
- jax.experimental.sparse is not supported.
- jax.lax.sort only supports comparators with LE, GE, LT, and GT operations.
- jax.lax.reduce_precision is not supported.
- Certain operations (for example, RNG weight initialization) might result in slow compilation. Run such operations on the CPU backend, or set the following environment variable: NEURON_RUN_TRIVIAL_COMPUTATION_ON_CPU=1.
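  For example, the environment variable can be set before JAX is initialized, or the initialization can be kept on the CPU backend (the sizes below are illustrative):

    import os

    # Ask Neuron to run trivial computations on the CPU; set this before
    # JAX is initialized.
    os.environ["NEURON_RUN_TRIVIAL_COMPUTATION_ON_CPU"] = "1"

    import jax

    # Alternatively, perform RNG-based weight initialization on the CPU backend.
    with jax.default_device(jax.devices("cpu")[0]):
        key = jax.random.PRNGKey(42)
        weights = jax.random.normal(key, (1024, 1024))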
- Neuron only supports float8_e4m3 and float8_e5m2 for FP8 dtypes.
- Complex dtypes (jnp.complex64 and jnp.complex128) are not supported.
- Variadic reductions are not supported. 
- Out-of-bounds access for scatter/gather operations can result in runtime errors. 
- Dot operations on int dtypes are not supported.
- lax.DotAlgorithmPreset is not always respected; dot operations are performed in the operand dtypes. The algorithm preset is a configurable parameter for jax.lax.dot and jax.lax.dot_general.
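  One workaround sketch is to cast the operands to the dtype the dot should run in, rather than relying on the algorithm preset (shapes and dtypes below are illustrative):

    import jax.numpy as jnp

    a = jnp.ones((128, 128), dtype=jnp.bfloat16)
    b = jnp.ones((128, 128), dtype=jnp.bfloat16)

    # Cast the operands so the matmul runs in float32 instead of bfloat16.
    c = jnp.dot(a.astype(jnp.float32), b.astype(jnp.float32))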
Previous release notes
