This document is relevant for: Trn1, Trn2, Trn3
nki.language.gather_flattened
- nki.language.gather_flattened(data, indices, axis=0, dtype=None)
Gather elements from the data tensor by index, after flattening its free dimension.
This instruction gathers elements from the data tensor using the integer indices provided in the indices tensor. For each element of indices, it retrieves the corresponding value from data, using the index value to select a position along the free dimension of data.
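To make the description concrete, here is a minimal host-side NumPy sketch of one reading of these semantics. It is an assumption, not the NKI implementation: it assumes dimension 0 of `data` is the partition dimension, that all remaining (free) dimensions are flattened into one, and that each index selects an element from the flattened free dimension of the same partition. The helper name `gather_flattened_ref` is hypothetical, and the `axis` parameter is not modeled because its interaction with flattening is not specified here.

```python
import numpy as np


def gather_flattened_ref(data, indices, dtype=None):
    """Hypothetical reference model of gather_flattened (assumed semantics).

    Assumes dim 0 is the partition dimension and that every index selects
    from the flattened free dimension of its own partition.
    """
    data = np.asarray(data)
    indices = np.asarray(indices)
    if data.shape[0] != indices.shape[0]:
        raise ValueError("data and indices must have the same partition size")
    # Flatten all free dimensions of data into one: shape (P, F).
    flat = data.reshape(data.shape[0], -1)
    # Flatten the index tile the same way: shape (P, K).
    idx = indices.reshape(indices.shape[0], -1)
    # Per partition p: out[p, k] = flat[p, idx[p, k]]; restore the index shape.
    out = np.take_along_axis(flat, idx, axis=1).reshape(indices.shape)
    # Mirror the optional dtype cast; by default keep the dtype of data.
    return out if dtype is None else out.astype(dtype)


# Two partitions, each with a 2x2 free shape that flattens to 4 elements.
data = np.arange(8).reshape(2, 2, 2)  # partition 0 -> [0, 1, 2, 3], partition 1 -> [4, 5, 6, 7]
indices = np.array([[3, 0], [1, 2]])
print(gather_flattened_ref(data, indices).tolist())  # [[3, 0], [5, 6]]
```

A model like this can serve as a host-side expected value when checking a kernel's output, provided the assumed semantics match the actual instruction on your target.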
Warning
This API is experimental and may change in future releases.
- Parameters:
data – input tensor to gather from.
indices – integer tensor of indices to gather.
axis – axis along which to gather.
dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tile.
- Returns:
gathered tensor.
Examples:
```python
import nki.language as nl

# nki.language.gather_flattened -- gather elements by index
# data_tensor, indices_tensor and actual_tensor are assumed to be kernel arguments (not defined here).
data = nl.load(data_tensor[0:128, 0:512])
indices = nl.load(indices_tensor[0:128, 0:512])
result = nl.gather_flattened(data, indices)
nl.store(actual_tensor[0:128, 0:512], result)
```