This document is relevant for: Inf2, Trn1, Trn2
nki.language.gather_flattened
- nki.language.gather_flattened(data, indices, *, mask=None, dtype=None, **kwargs)
Gather elements from data according to the indices.
This instruction gathers elements from the data tensor using integer indices provided in the indices tensor. For each element in the indices tensor, it retrieves the corresponding value from the data tensor, using the index value to select from the free dimension of data. The gather instruction effectively performs up to 128 parallel gather operations, with each operation using the corresponding partition of data and indices.
The output tensor has the same shape as the indices tensor, with each output element containing the value from data at the position specified by the corresponding index. Out-of-bounds indices return garbage values.
Both data and indices must be 2-, 3-, or 4-dimensional. The indices tensor must contain uint32 values.
For indexing purposes, all free dimensions are flattened and indexed as the same “row”. Consider this example:
data = [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]]
indices = [[[0, 1], [1, 3]], [[3, 3], [1, 0]]]
nl.gather_flattened(data, indices)
produces this result:
[[[1., 2.], [2., 4.]], [[8., 8.], [6., 5.]]]
With the exception of handling out-of-bounds indices, this behavior is equivalent to:
indices_flattened = indices.reshape(indices.shape[0], -1)
data_flattened = data.reshape(data.shape[0], -1)
result = np.take_along_axis(data_flattened, indices_flattened, axis=-1)
result.reshape(indices.shape)
(Similar to torch.gather applied along the flattened free dimension.)
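The flattening rule described above can be checked on the host with plain NumPy. The following is a minimal sketch, not NKI code: the helper name gather_flattened_reference is hypothetical, and the cast of the indices to int64 is an assumption made only to keep NumPy indexing portable. It reproduces the worked example, in which each index addresses the flattened row [1, 2, 3, 4] or [5, 6, 7, 8].

```python
import numpy as np

def gather_flattened_reference(data, indices):
    """NumPy stand-in for the documented semantics (hypothetical helper, not part of NKI)."""
    # Collapse every free dimension into a single "row" per partition.
    data_flattened = data.reshape(data.shape[0], -1)
    indices_flattened = indices.reshape(indices.shape[0], -1)
    # Assumption: cast to int64 so NumPy fancy indexing accepts the indices everywhere.
    gathered = np.take_along_axis(
        data_flattened, indices_flattened.astype(np.int64), axis=-1
    )
    # The output takes the shape of the indices tensor.
    return gathered.reshape(indices.shape)

data = np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32)
indices = np.array([[[0, 1], [1, 3]], [[3, 3], [1, 0]]], dtype=np.uint32)

result = gather_flattened_reference(data, indices)
print(result.tolist())
```

Unlike the hardware instruction, NumPy raises an error on out-of-bounds indices instead of returning garbage values, so this sketch is only a reference for in-range inputs.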
- Parameters:
data – the source tensor to gather values from
indices – tensor containing uint32 indices to gather across the flattened free dimension.
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tile.
- Returns:
a tensor with the same shape as indices containing gathered values from data
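Because out-of-bounds indices return garbage values rather than raising an error, it can help to validate index tensors on the host before they reach a kernel. The sketch below is one possible debugging aid written in NumPy; check_gather_inputs is a hypothetical helper and not part of the NKI API. It checks only the constraints stated above: dimensionality, the uint32 index type, and that every index falls inside the flattened free dimension of data.

```python
import numpy as np

def check_gather_inputs(data, indices):
    """Host-side sanity check (hypothetical helper, not part of NKI)."""
    if data.ndim not in (2, 3, 4) or indices.ndim not in (2, 3, 4):
        raise ValueError("data and indices must be 2-, 3-, or 4-dimensional")
    if indices.dtype != np.uint32:
        raise TypeError("indices must contain uint32 values")
    # Length of the flattened "row": the product of all free dimensions of data.
    row_length = int(np.prod(data.shape[1:]))
    if indices.size and int(indices.max()) >= row_length:
        raise IndexError(
            f"index {int(indices.max())} is out of bounds for row length {row_length}"
        )
    return row_length

data = np.zeros((2, 2, 2), dtype=np.float32)
valid_indices = np.array([[0, 3], [1, 2]], dtype=np.uint32)
row_length = check_gather_inputs(data, valid_indices)
print(row_length)
```

A check like this is best kept out of the kernel itself and run only while developing or testing, since it inspects the index values on the host.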
Example:
import neuronxcc.nki.language as nl
from neuronxcc.nki.typing import tensor

##################################################################
# Example 1: Gather values from a tensor using indices
##################################################################
# Create source tensor
N = 32
M = 64
data = nl.rand((N, M), dtype=nl.float32)

# Create indices tensor - gather every 5th element
indices = nl.zeros((N, 10), dtype=nl.uint32)
for i in nl.static_range(N):
    for j in nl.static_range(10):
        indices[i, j] = j * 5

# Gather values from data according to indices
result = nl.gather_flattened(data=data, indices=indices)
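To see what Example 1 is expected to produce, the same shapes and index pattern can be mirrored in NumPy. This is a host-side sketch under the equivalence stated earlier, not output captured from a Neuron device; the random generator and seed are arbitrary choices for illustration. With indices 0, 5, ..., 45 in every row, the gathered result should equal every fifth column of data.

```python
import numpy as np

N, M = 32, 64
rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility only
data = rng.random((N, M), dtype=np.float32)

# Same pattern as the NKI example: indices[i, j] = j * 5 for every row i.
indices = np.tile(np.arange(10, dtype=np.uint32) * 5, (N, 1))

# data is already 2-dimensional, so no flattening is needed here.
expected = np.take_along_axis(data, indices.astype(np.int64), axis=-1)

# Every fifth column of data, columns 0 through 45.
strided = data[:, 0:50:5]
print(expected.shape, np.array_equal(expected, strided))
```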