This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3

NCC_ESPP004#

Error message: The compiler encountered a data type that is not supported for code generation.

Erroneous code example:

import numpy as np
import jax.numpy as jnp
import jax
from jax._src import dtypes
from jax._src.lax import lax as lax_internal

# float4_e2m1fn type not supported
dtype = np.dtype(dtypes.float4_e2m1fn)
val = lax_internal._convert_element_type(0, dtype, weak_type=False)

Use a supported data type:

import numpy as np
import jax.numpy as jnp
import jax
from jax._src import dtypes
from jax._src.lax import lax as lax_internal

# float4_e2m1fn type not supported
dtype = jnp.bfloat16
val = lax_internal._convert_element_type(0, dtype, weak_type=False)

More information on supported data types https://awsdocs-neuron.readthedocs-hosted.com/en/latest/about-neuron/arch/neuron-features/data-types.html

This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3