This document is relevant for: Inf1, Inf2, Trn1, Trn2, Trn3

NCC_EVRF031

Error message: The compiler encountered a scatter out-of-bounds error. The indices created via iota instruction contain values that are beyond the size of the operand dimension.

Erroneous code example:

import jax.numpy as jnp
from jax import lax

# size 3 in dimension 0
operand = jnp.zeros((3, 4), dtype=jnp.float32)

# iota generates indices [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
indices = lax.iota(jnp.int32, 10)  # ERROR: size 10 > operand dimension 3
indices = indices.reshape(10, 1)

updates = jnp.ones((10, 4), dtype=jnp.float32)  # ERROR: 10 updates but operand only has 3 rows

result = lax.scatter(
    operand,
    indices,  # ERROR: index values in [0, 10) but operand dimension only allows indices in [0, 3)
    updates,
    lax.ScatterDimensionNumbers(
        update_window_dims=(1,),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    ),
)

To fix this, ensure that the iota size does not exceed the size of the operand dimension being indexed, for example by using the same size for both:

import jax.numpy as jnp
from jax import lax

N = 3  # number of rows (operand dimension 0)
D = 4  # number of columns (operand dimension 1)
operand = jnp.zeros((N, D), dtype=jnp.float32)

# FIXED: match iota size to operand dimension
indices = lax.iota(jnp.int32, N)  # size N is the same as operand dimension 0
indices = indices.reshape(N, 1)

# FIXED: updates size matches operand dimension
updates = jnp.ones((N, D), dtype=jnp.float32)

result = lax.scatter(
    operand,
    indices,  # FIXED: indices now in valid range [0, 3)
    updates,
    lax.ScatterDimensionNumbers(
        update_window_dims=(1,),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    ),
)
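
One way to catch this kind of mismatch before it reaches the compiler is to derive the iota size from the array shapes and check it in plain Python. The following is a minimal sketch, assuming the updates are meant to overwrite the leading rows of the operand. The helper name scatter_rows is hypothetical and is not part of JAX or the Neuron SDK.

```python
import jax.numpy as jnp
from jax import lax


def scatter_rows(operand, updates):
    """Overwrite the first updates.shape[0] rows of operand (hypothetical helper)."""
    num_rows = operand.shape[0]
    num_updates = updates.shape[0]

    # Array shapes are static Python integers, so this check runs
    # before any scatter is built from out-of-range iota indices.
    if num_updates > num_rows:
        raise ValueError(
            f"iota size {num_updates} exceeds operand dimension {num_rows}"
        )

    # Indices 0 .. num_updates - 1, all within [0, num_rows).
    indices = lax.iota(jnp.int32, num_updates).reshape(num_updates, 1)

    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=(1,),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    return lax.scatter(operand, indices, updates, dnums)


# Usage: 2 update rows written into a 3-row operand.
operand = jnp.zeros((3, 4), dtype=jnp.float32)
updates = jnp.ones((2, 4), dtype=jnp.float32)
result = scatter_rows(operand, updates)
```

Passing an updates array with 10 rows to this helper raises a ValueError in Python, so the mismatch surfaces with a readable message at the call site.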
