This document is relevant for: Inf2, Trn1, Trn2
nki.language.gather_flattened
- nki.language.gather_flattened(data, indices, *, mask=None, dtype=None, **kwargs)
Gather elements from data according to the indices.

This instruction gathers elements from the data tensor using integer indices provided in the indices tensor. For each element in the indices tensor, it retrieves the corresponding value from the data tensor, using the index value to select from the free dimension of data. The gather instruction effectively performs up to 128 parallel gather operations, with each operation using the corresponding partition of data and indices.

The output tensor has the same shape as the indices tensor, with each output element containing the value from data at the position specified by the corresponding index. Out-of-bounds indices return garbage values.

Both data and indices must be 2-, 3-, or 4-dimensional. The indices tensor must contain uint32 values.

For indexing purposes, all free dimensions are flattened and indexed as the same “row”. Consider this example:
data = [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]]
indices = [[[0, 1], [1, 3]], [[3, 3], [1, 0]]]
nl.gather_flattened(data, indices)

produces this result:

[[[1., 2.], [2., 4.]], [[8., 8.], [6., 5.]]]
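To target a specific element inside a multi-dimensional free region, its coordinates first have to be converted to a position in the flattened "row". The sketch below assumes row-major (C-order) flattening, which is consistent with the example above (index 3 in the first partition selects 4., the element at free-dimension position [1, 1]). It uses plain NumPy on the host and is only an illustration; the names free_shape, rows, and cols are made up for this example.

```python
import numpy as np

# Shape of the free dimensions in the example above (partition dim excluded).
free_shape = (2, 2)

# Free-dimension coordinates (row, col) of the elements we want to read.
rows = np.array([0, 1, 1])
cols = np.array([1, 0, 1])

# Row-major flattening: flat = row * free_shape[1] + col.
# Cast to uint32 because gather_flattened requires uint32 indices.
flat = np.ravel_multi_index((rows, cols), free_shape).astype(np.uint32)
print(flat)  # [1 2 3]
```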
With the exception of handling out-of-bounds indices, this behavior is equivalent to:
indices_flattened = indices.reshape(indices.shape[0], -1)
data_flattened = data.reshape(data.shape[0], -1)
result = np.take_along_axis(data_flattened, indices_flattened, axis=-1)
result = result.reshape(indices.shape)
((Similar to torch.gather, applied along the flattened free dimensions))
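The NumPy equivalence above can be wrapped in a small host-side reference function and checked against the worked example. This is a minimal sketch for in-bounds indices only. The function name gather_flattened_reference is hypothetical and not part of NKI, and the cast to int64 is a convenience for NumPy indexing, not something the instruction does.

```python
import numpy as np

def gather_flattened_reference(data, indices):
    """NumPy model of the documented semantics (in-bounds indices only)."""
    # Collapse every free dimension into one row per partition.
    data_flat = data.reshape(data.shape[0], -1)
    idx_flat = indices.reshape(indices.shape[0], -1).astype(np.int64)
    # Gather along the flattened row, independently for each partition.
    out = np.take_along_axis(data_flat, idx_flat, axis=-1)
    # The output takes the shape of the indices tensor.
    return out.reshape(indices.shape)

data = np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32)
indices = np.array([[[0, 1], [1, 3]], [[3, 3], [1, 0]]], dtype=np.uint32)

result = gather_flattened_reference(data, indices)
expected = np.array([[[1., 2.], [2., 4.]], [[8., 8.], [6., 5.]]], dtype=np.float32)
assert np.array_equal(result, expected)
```

A reference like this may be useful for comparing kernel output against expected values in a unit test.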
- Parameters:
data – the source tensor to gather values from
indices – tensor containing uint32 indices to gather across the flattened free dimension.
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
dtype – (optional) data type to cast the output type to (see Supported Data Types for more information); if not specified, it will default to be the same as the data type of the input tile.
- Returns:
a tensor with the same shape as indices containing gathered values from data
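Because out-of-bounds indices return garbage values rather than raising an error, it can be worth validating indices on the host before they reach the kernel. The helper below is a hypothetical sketch in plain NumPy (prepare_indices is not an NKI function). It checks the documented constraints (2 to 4 dimensions, integer values within the flattened free size) and casts to uint32.

```python
import numpy as np

def prepare_indices(indices, data_shape):
    """Hypothetical host-side check: validate range, then cast to uint32."""
    indices = np.asarray(indices)
    if not np.issubdtype(indices.dtype, np.integer):
        raise TypeError("indices must be an integer array")
    if indices.ndim not in (2, 3, 4) or len(data_shape) not in (2, 3, 4):
        raise ValueError("data and indices must be 2-, 3-, or 4-dimensional")
    # Number of elements in the flattened free dimensions of data.
    free_size = int(np.prod(data_shape[1:]))
    if indices.size and (indices.min() < 0 or indices.max() >= free_size):
        raise ValueError(f"indices must lie in [0, {free_size})")
    return indices.astype(np.uint32)

# Usage: data with 2 partitions and a (2, 2) free region, so valid indices are 0..3.
safe = prepare_indices([[0, 3], [1, 2]], data_shape=(2, 2, 2))
```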
Example:
import neuronxcc.nki.language as nl
from neuronxcc.nki.typing import tensor

##################################################################
# Example 1: Gather values from a tensor using indices
##################################################################
# Create source tensor
N = 32
M = 64
data = nl.rand((N, M), dtype=nl.float32)

# Create indices tensor - gather every 5th element
indices = nl.zeros((N, 10), dtype=nl.uint32)
for i in nl.static_range(N):
    for j in nl.static_range(10):
        indices[i, j] = j * 5

# Gather values from data according to indices
result = nl.gather_flattened(data=data, indices=indices)
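For readers without Neuron hardware at hand, the same access pattern can be modeled on the host with NumPy. This is a sketch of the expected result under the documented semantics, not NKI code: np.take_along_axis stands in for the gather, and the random generator seed is arbitrary.

```python
import numpy as np

N, M = 32, 64
rng = np.random.default_rng(0)
data = rng.random((N, M), dtype=np.float32)

# Same index pattern as Example 1: columns 0, 5, ..., 45 in every partition.
indices = np.tile(np.arange(10, dtype=np.uint32) * 5, (N, 1))

# Host-side stand-in for the gather (int64 cast is for NumPy indexing only).
result = np.take_along_axis(data, indices.astype(np.int64), axis=-1)

# The output has the shape of indices and equals a strided slice of data.
assert result.shape == (N, 10)
assert np.array_equal(result, data[:, 0:50:5])
```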