This document is relevant for: Inf2, Trn1, Trn2
JAX NeuronX Known Issues
The Threefry RNG algorithm is not fully supported. Use the rbg algorithm instead. To configure this, set the following config option:
    jax.config.update("jax_default_prng_impl", "rbg")
For JAX versions older than 0.4.34, caching does not work out of the box. Use the following to enable caching support:
    import jax
    import jax_neuronx
    from jax._src import compilation_cache
    compilation_cache.set_cache_dir('./cache_directory')
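To confirm that the rbg setting from the first item is active, a script can inspect the key it gets back. The following is a minimal sketch, assuming the option is set before any keys are created; the seed and sample shape are arbitrary, and nothing in it is specific to Neuron.

```python
import jax

# Select the rbg PRNG implementation before creating any keys.
jax.config.update("jax_default_prng_impl", "rbg")

key = jax.random.PRNGKey(0)
# rbg keys carry four uint32 words; Threefry keys carry two.
print(key.shape)

# Splitting and sampling use the same API for both implementations.
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, (3,))
print(sample.shape)
```

If the printed key shape has two elements instead of four, the option was most likely set after a key had already been created.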
For JAX versions older than 0.4.34, buffer donation does not work out of the box. Add the following snippet to your script to enable it:
    jax._src.interpreters.mlir._platforms_with_donation.append('neuron')
jax.random.randint does not produce the expected distribution of values. Run it on the CPU instead.
Dynamic loops are not supported for jax.lax.while_loop. Only static while loops are supported.
jax.lax.cond is not supported.
Host callbacks are not supported. As a result, APIs based on callbacks from jax.debug and jax.experimental.checkify are not supported.
Mesh configurations that use non-connected Neuron cores might crash during execution. You might observe compilation or Neuron runtime errors for such configurations. Device connectivity can be determined by using neuron-ls --topology.
Not all dtypes supported by JAX work on Neuron. Check Data Types for supported data types.
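For the control-flow items above (jax.lax.cond and dynamic jax.lax.while_loop), one common JAX pattern is to compute both branches and select between them, and to give loops a trip count that is a plain Python integer. The sketch below illustrates that pattern only; the function names are hypothetical, and whether a given program then compiles on Neuron is an assumption to verify on the target device.

```python
import jax
import jax.numpy as jnp

def clip_by_norm(x, limit):
    # Stand-in for jax.lax.cond: evaluate both outcomes, then select one
    # with a scalar predicate. Both branches are always computed.
    norm = jnp.linalg.norm(x)
    scaled = x * (limit / norm)
    return jnp.where(norm > limit, scaled, x)

def power_iteration(matrix, vec, num_steps=8):
    # num_steps is a Python int, so the trip count is known at trace time.
    def body(_, v):
        v = matrix @ v
        return v / jnp.linalg.norm(v)
    return jax.lax.fori_loop(0, num_steps, body, vec)

clipped = clip_by_norm(jnp.array([3.0, 4.0]), 1.0)
unchanged = clip_by_norm(jnp.array([0.3, 0.4]), 1.0)
dominant = power_iteration(jnp.array([[2.0, 0.0], [0.0, 1.0]]),
                           jnp.array([1.0, 1.0]))
```

The select-based form trades extra compute for the absence of a conditional, so it fits best when both branches are cheap.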
jax.dlpack is not supported.
jax.experimental.sparse is not supported.
jax.lax.sort only supports comparators with LE, GE, LT and GT operations.
jax.lax.reduce_precision is not supported.
Certain operations (for example, RNG weight initialization) might result in slow compilations. Try running such operations on the CPU backend, or set the following environment variable: NEURON_RUN_TRIVIAL_COMPUTATION_ON_CPU=1.
Neuron only supports float8_e4m3 and float8_e5m2 for FP8 dtypes.
Complex dtypes (jnp.complex64 and jnp.complex128) are not supported.
Variadic reductions are not supported.
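Running an operation on the CPU backend, as suggested for slow-compiling weight initialization (and for jax.random.randint), can be done with jax.default_device. The sketch below is one possible arrangement, assuming the CPU backend is available alongside the default backend; init_params and the layer sizes are hypothetical.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
# Default backend device; on a Neuron host this is assumed to be a Neuron core.
accelerator = jax.devices()[0]

def init_params(seed, in_dim, out_dim):
    # Hypothetical helper: create the key and the weights on the CPU backend
    # so the initialization is not compiled for the accelerator.
    with jax.default_device(cpu):
        key = jax.random.PRNGKey(seed)
        w = jax.random.normal(key, (in_dim, out_dim)) / jnp.sqrt(in_dim)
        b = jnp.zeros((out_dim,))
    return {"w": w, "b": b}

params_cpu = init_params(0, 4, 2)
# Move the whole parameter tree to the accelerator in one call.
params = jax.device_put(params_cpu, accelerator)
```

Creating the key inside the CPU context keeps the whole initialization on the CPU; only the finished arrays are transferred.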
Out-of-bounds access in scatter/gather operations can result in runtime errors.
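One way to guard against out-of-bounds indices is to clamp or mask them before the gather or scatter is issued. The sketch below is illustrative only: the helper names are hypothetical, it handles one-dimensional arrays, and it clamps negative indices to zero rather than wrapping them as NumPy-style indexing would.

```python
import jax.numpy as jnp

def safe_gather(x, idx):
    # Clamp indices into [0, len - 1] so the gather never reads out of bounds.
    safe_idx = jnp.clip(idx, 0, x.shape[0] - 1)
    return x[safe_idx]

def safe_scatter_add(x, idx, updates):
    # Zero out updates whose index is out of range, then scatter them to a
    # clamped index, where adding zero leaves the target unchanged.
    valid = (idx >= 0) & (idx < x.shape[0])
    safe_idx = jnp.clip(idx, 0, x.shape[0] - 1)
    return x.at[safe_idx].add(jnp.where(valid, updates, 0))

x = jnp.arange(5) * 10
gathered = safe_gather(x, jnp.array([0, 4, 7, 100]))
scattered = safe_scatter_add(jnp.zeros(5), jnp.array([1, 7]),
                             jnp.array([1.0, 2.0]))
```

Clamping changes which element an invalid index reads, so it suits cases where any in-range value is acceptable; masking is the safer choice for updates.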
Dot operations on int dtypes are not supported.
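A possible workaround for integer dot products is to cast the operands to a floating-point dtype, multiply, and round back. The sketch below assumes the values are small enough for float32 to hold every partial sum exactly (magnitude below 2**24) and that the backend does not lower the float32 matmul to a reduced-precision format; int_matmul_via_float is a hypothetical helper, and results should be checked on the target device.

```python
import jax.numpy as jnp

def int_matmul_via_float(a, b):
    # Cast to float32, multiply, then round back to the integer dtype.
    # Exact only while every partial sum stays below 2**24 in magnitude.
    out = jnp.matmul(a.astype(jnp.float32), b.astype(jnp.float32))
    return jnp.round(out).astype(a.dtype)

a = jnp.array([[1, 2], [3, 4]], dtype=jnp.int32)
b = jnp.array([[5, 6], [7, 8]], dtype=jnp.int32)
product = int_matmul_via_float(a, b)
```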
lax.DotAlgorithmPreset is not always respected. Dot operations occur in operand dtypes. This is a configurable parameter for jax.lax.dot and jax.lax.dot_general.
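Because dot operations occur in the operand dtypes, casting the operands explicitly is one way to state the intended compute dtype without relying on a preset. The sketch below assumes float32 is the desired dtype for bfloat16 inputs; dot_in_dtype is a hypothetical helper, not part of JAX or Neuron.

```python
import jax
import jax.numpy as jnp

def dot_in_dtype(a, b, compute_dtype=jnp.float32):
    # Cast both operands so the dot is expressed in compute_dtype,
    # independent of any algorithm preset.
    return jax.lax.dot(a.astype(compute_dtype), b.astype(compute_dtype))

a = jnp.ones((2, 3), dtype=jnp.bfloat16)
b = jnp.ones((3, 2), dtype=jnp.bfloat16)
out = dot_in_dtype(a, b)
```

The explicit cast costs extra memory traffic for the widened operands, which may matter for large matrices.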