This document is relevant for: Inf2, Trn1, Trn2

nki.language.gather_flattened

nki.language.gather_flattened(data, indices, *, mask=None, dtype=None, **kwargs)

Gather elements from data according to the indices.

This instruction gathers elements from the data tensor using the integer indices in the indices tensor. For each element of indices, it reads the value of data at that index along the free dimension of the same partition. In effect, the instruction performs up to 128 independent gathers in parallel, one per partition, each using the matching partition of data and indices.

The output tensor has the same shape as the indices tensor. Each output element holds the value of data at the position given by the corresponding index. Out-of-bounds indices produce garbage (undefined) values.

Both data and indices must be 2-, 3-, or 4-dimensional. The indices tensor must contain uint32 values.
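A host-side pre-check of the constraints above can catch bad arguments before a kernel runs. The sketch below uses NumPy arrays as stand-ins for NKI tensors; the helper name `check_gather_flattened_args` is hypothetical, the 128-partition limit comes from the description above, and the requirement that data and indices share a partition size is an assumption inferred from the per-partition behavior.

```python
import numpy as np


def check_gather_flattened_args(data: np.ndarray, indices: np.ndarray) -> None:
    """Hypothetical host-side check mirroring the documented constraints."""
    # Both tensors must be 2-, 3-, or 4-dimensional.
    for name, arr in (("data", data), ("indices", indices)):
        if arr.ndim not in (2, 3, 4):
            raise ValueError(f"{name} must be 2-, 3-, or 4-D, got {arr.ndim}-D")
    # Indices must be uint32.
    if indices.dtype != np.uint32:
        raise TypeError(f"indices must be uint32, got {indices.dtype}")
    # Assumption: partition p of indices reads from partition p of data.
    if data.shape[0] != indices.shape[0]:
        raise ValueError("data and indices must have the same partition size")
    if data.shape[0] > 128:
        raise ValueError("partition dimension must not exceed 128")
    # Out-of-bounds indices yield garbage on device, so reject them here.
    row_len = int(np.prod(data.shape[1:]))
    if indices.size and int(indices.max()) >= row_len:
        raise IndexError(
            f"index {int(indices.max())} >= flattened row length {row_len}"
        )
```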

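For indexing, all free dimensions of a partition are flattened into a single "row", and each index selects a position in that row. Consider this example, in which each partition of data flattens to a row of four values: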

data =
[[[1., 2.],
  [3., 4.]],
 [[5., 6.],
  [7., 8.]]]
indices =
[[[0, 1],
  [1, 3]],
 [[3, 3],
  [1, 0]]]
nl.gather_flattened(data, indices) produces this result:
[[[1., 2.],
  [2., 4.]],
 [[8., 8.],
  [6., 5.]]]

Except for the handling of out-of-bounds indices, this behavior is equivalent to the following NumPy code:

import numpy as np

indices_flattened = indices.reshape(indices.shape[0], -1)
data_flattened = data.reshape(data.shape[0], -1)
result = np.take_along_axis(data_flattened, indices_flattened, axis=-1)
result = result.reshape(indices.shape)
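A self-contained NumPy model of this behavior can be handy for testing kernels on the host. The sketch below wraps the NumPy equivalence in a function; the name `gather_flattened_ref` and its `dtype` argument are hypothetical stand-ins, not part of NKI, and the model does not reproduce the garbage values returned for out-of-bounds indices. Applied to the example data and indices, it returns the result shown in the example.

```python
import numpy as np


def gather_flattened_ref(data, indices, dtype=None):
    """Hypothetical NumPy model of nl.gather_flattened (in-bounds indices only)."""
    # Flatten every free dimension so each partition becomes one row.
    data_flat = data.reshape(data.shape[0], -1)
    idx_flat = indices.reshape(indices.shape[0], -1).astype(np.intp)
    # Row p of the result gathers from row p of data.
    out = np.take_along_axis(data_flat, idx_flat, axis=-1).reshape(indices.shape)
    # Optional cast, standing in for the dtype parameter.
    return out if dtype is None else out.astype(dtype)


data = np.array([[[1., 2.], [3., 4.]],
                 [[5., 6.], [7., 8.]]], dtype=np.float32)
indices = np.array([[[0, 1], [1, 3]],
                    [[3, 3], [1, 0]]], dtype=np.uint32)
print(gather_flattened_ref(data, indices))
```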

(Similar to torch.gather applied along the flattened free dimensions.)

Parameters:
  • data – the source tensor to gather values from

  • indices – tensor containing uint32 indices to gather across the flattened free dimension.

  • mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)

  • dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, the output has the same data type as data.

Returns:

a tensor with the same shape as indices containing gathered values from data

Example:

import neuronxcc.nki.language as nl

##################################################################
# Example 1: Gather values from a tensor using indices
##################################################################
# Create source tensor
N = 32
M = 64
data = nl.rand((N, M), dtype=nl.float32)

# Create indices tensor: in every partition, gather columns 0, 5, ..., 45
# (every 5th element, all within the M = 64 columns of data)
indices = nl.zeros((N, 10), dtype=nl.uint32)
for i in nl.static_range(N):
    for j in nl.static_range(10):
        indices[i, j] = j * 5

# Gather values from data according to indices
result = nl.gather_flattened(data=data, indices=indices)
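Because every partition of indices in Example 1 holds 5 * j at column j, result[i, j] should equal data[i, 5 * j]. The NumPy sketch below reproduces that expectation on the host; it does not call NKI, a seeded NumPy generator stands in for nl.rand, and all names are illustrative.

```python
import numpy as np

N, M, K = 32, 64, 10
rng = np.random.default_rng(0)
data = rng.random((N, M), dtype=np.float32)  # stand-in for nl.rand

# Same index pattern as Example 1: columns 0, 5, ..., 45 in every partition.
indices = np.tile(np.arange(K, dtype=np.uint32) * 5, (N, 1))

# Per-partition gather along the (already flat) free dimension.
expected = np.take_along_axis(data, indices.astype(np.intp), axis=1)

# Equivalent to a strided slice of the first 5 * K columns.
assert np.array_equal(expected, data[:, : 5 * K : 5])
print(expected.shape)  # (32, 10)
```

For a 2-D data tensor the free dimension is already flat, so the gather reduces to an ordinary per-row take.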
