This document is relevant for: Inf2, Trn1, Trn2
JAX Neuron Known issues
The Threefry RNG algorithm is not completely supported. Use the rbg algorithm instead. This can be configured by setting the following config option:
    jax.config.update("jax_default_prng_impl", "rbg")
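The ordering matters when switching the PRNG implementation: the option should be set before any key is created. The following is a minimal sketch of that ordering. It runs on any JAX backend; the seed and sample shape are arbitrary illustration values, not something the Neuron documentation prescribes.

```python
import jax

# Select the rbg PRNG implementation before creating any keys.
jax.config.update("jax_default_prng_impl", "rbg")

key = jax.random.PRNGKey(0)            # arbitrary seed for illustration
key, subkey = jax.random.split(key)    # split as usual; only the implementation changed
samples = jax.random.normal(subkey, (4,))
```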
For JAX versions older than 0.4.34, caching does not work out of the box. Use the following to enable caching support:
    import jax
    import jax_neuronx
    from jax._src import compilation_cache
    compilation_cache.set_cache_dir('./cache_directory')
For JAX versions older than 0.4.34, buffer donation does not work out of the box. Add the following snippet to your script to enable it:
    jax._src.interpreters.mlir._platforms_with_donation.append('neuron')
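Because both workarounds rely on private `jax._src` modules, it may be safer to apply them only when the installed JAX is actually older than 0.4.34. The sketch below shows one possible way to gate them. The helper names `needs_legacy_workarounds` and `apply_legacy_workarounds` are hypothetical and not part of JAX or jax_neuronx; the calls inside are the ones from the two snippets above.

```python
import re


def needs_legacy_workarounds(jax_version: str) -> bool:
    """Return True when the given JAX version string is older than 0.4.34."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", jax_version)
    if match is None:
        raise ValueError(f"unrecognized JAX version string: {jax_version!r}")
    return tuple(int(part) for part in match.groups()) < (0, 4, 34)


def apply_legacy_workarounds(cache_dir: str = "./cache_directory") -> bool:
    """Apply the caching and buffer-donation workarounds only if they are needed.

    Hypothetical helper: returns True if the workarounds were applied.
    """
    import jax

    if not needs_legacy_workarounds(jax.__version__):
        return False

    import jax_neuronx  # noqa: F401  (imported as in the caching snippet)
    from jax._src import compilation_cache
    from jax._src.interpreters import mlir

    compilation_cache.set_cache_dir(cache_dir)
    if "neuron" not in mlir._platforms_with_donation:
        mlir._platforms_with_donation.append("neuron")
    return True
```

Calling `apply_legacy_workarounds()` once at the top of a training script, before any function is jitted, keeps the private-API usage in one place.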
jax.random.randint does not produce the expected distribution of values. Run it on CPU instead.
Dynamic loops are not supported for jax.lax.while_loop. Only static while loops are supported.
jax.lax.cond is not supported.
Host callbacks are not supported. As a result, APIs based on callbacks from jax.debug and jax.experimental.checkify are not supported.
Mesh configurations that use non-connected Neuron cores might crash during execution. You might observe compilation or Neuron runtime errors for such configurations. Device connectivity can be determined by using neuron-ls --topology.
Not all dtypes supported by JAX work on Neuron. Check the data types supported in Neuron.
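For the loop and conditional restrictions above, one common pattern is to write loops with a trip count known at trace time and to replace a branch with an elementwise select. The sketch below illustrates both on plain JAX. Whether a particular program written this way compiles on Neuron is an assumption that should be verified; `NUM_STEPS`, `damped_sum`, and `leaky` are hypothetical names used for illustration.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 5  # Python int, so the trip count is known at trace time


@jax.jit
def damped_sum(x):
    """Repeat acc = 0.5 * acc + x a fixed number of times."""
    def body(i, acc):
        return 0.5 * acc + x

    # Static integer bounds give a loop with a fixed trip count.
    return jax.lax.fori_loop(0, NUM_STEPS, body, jnp.zeros_like(x))


@jax.jit
def leaky(x):
    """Select between two expressions instead of branching with jax.lax.cond."""
    # jnp.where evaluates both expressions and picks one per element.
    return jnp.where(x > 0, x, 0.1 * x)


total = damped_sum(jnp.float32(1.0))
activated = leaky(jnp.array([-2.0, 3.0]))
```

Note that the select form computes both expressions for every element, so it fits cheap branches better than expensive ones.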
jax.dlpack is not supported.
jax.experimental.sparse is not supported.
jax.lax.sort only supports comparators with LE, GE, LT and GT operations.
jax.lax.reduce_precision is not supported.
Certain operations (for example, RNG weight initialization) might result in slow compilation. Try running such operations on CPU.
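Several items above recommend running an operation on CPU, for example jax.random.randint or RNG weight initialization. A minimal sketch of one way to do that uses the `jax.default_device` context manager and then moves the result to the default device. It assumes the host CPU backend is visible to JAX alongside the Neuron devices; `init_on_cpu` and its parameters are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Assumption: the host CPU backend is available next to the accelerator backend.
cpu = jax.devices("cpu")[0]


def init_on_cpu(seed, in_dim, out_dim):
    """Draw initial weights and random integer indices on the host CPU."""
    with jax.default_device(cpu):
        w_key, idx_key = jax.random.split(jax.random.PRNGKey(seed))
        weights = jax.random.normal(w_key, (in_dim, out_dim)) / jnp.sqrt(in_dim)
        indices = jax.random.randint(idx_key, (in_dim,), 0, out_dim)
    return weights, indices


weights, indices = init_on_cpu(seed=0, in_dim=8, out_dim=4)

# Transfer the CPU results to the first default device for the rest of the program.
weights = jax.device_put(weights, jax.devices()[0])
indices = jax.device_put(indices, jax.devices()[0])
```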