This document is relevant for: Inf2, Trn1, Trn2
nki.isa.local_gather
- nki.isa.local_gather(src_buffer, index, num_elem_per_idx=1, num_valid_indices=None, *, mask=None)
Gather SBUF data in src_buffer using index on GpSimd Engine.

Each of the eight GpSimd cores in GpSimd Engine connects to 16 contiguous SBUF partitions (e.g., core[0] is connected to partition[0:16]) and gathers from its connected 16 SBUF partitions independently and in parallel. The indices used for gather on each core must also come from the same 16 connected SBUF partitions.
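The core-to-partition mapping described above can be sketched in plain Python. This is only an illustration: the helper names core_partition_range and core_for_partition are hypothetical and not part of NKI, and the sketch assumes that core i serves partitions [16 * i, 16 * i + 16), which generalizes the core[0] example given in the text.

```python
NUM_GPSIMD_CORES = 8
PARTITIONS_PER_CORE = 16


def core_partition_range(core_id):
    """Return the half-open (start, end) partition range served by a core.

    Hypothetical helper; assumes core i owns partitions [16*i, 16*i + 16).
    """
    if not 0 <= core_id < NUM_GPSIMD_CORES:
        raise ValueError(f"core_id must be in [0, {NUM_GPSIMD_CORES})")
    start = core_id * PARTITIONS_PER_CORE
    return start, start + PARTITIONS_PER_CORE


def core_for_partition(partition):
    """Return the index of the GpSimd core connected to an SBUF partition."""
    total = NUM_GPSIMD_CORES * PARTITIONS_PER_CORE
    if not 0 <= partition < total:
        raise ValueError(f"partition must be in [0, {total})")
    return partition // PARTITIONS_PER_CORE
```

Because a core only sees its own 16 partitions, both the data and the indices for a given gather have to sit in the same range returned by core_partition_range.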
During execution of the instruction, each GpSimd core reads a 16-partition slice from index and flattens all indices into a 1D array indices_1d (along the partition dimension first). By default, with no num_valid_indices specified, each GpSimd core treats all indices from its corresponding 16-partition index slice as valid. However, when the number of valid indices per core is not a multiple of 16, users can explicitly specify the valid index count per core in num_valid_indices. Note that num_valid_indices must not exceed the total element count in each 16-partition index slice (i.e., num_valid_indices <= index.size / (index.shape[0] / 16)).

Next, each GpSimd core uses the flattened indices_1d indices as partition offsets to gather from the connected 16-partition slice of src_buffer. Optionally, this API also allows gathering multiple contiguous elements starting at each index to improve gather throughput, as indicated by num_elem_per_idx. The behavior of out-of-bound index access is undefined.

Even though all eight GpSimd cores can gather with completely different indices, a common use case for this API is to make all cores gather with the same set of indices (i.e., partition offsets). In this case, users can generate indices into 16 partitions, replicate them eight times to 128 partitions, and then feed them into local_gather.

As an example, if src_buffer is (128, 512, 4) in shape and index is (128, 4) in shape, where the partition dimension size is 128, local_gather effectively performs the following operation:

    import numpy as np

    num_gpsimd_cores = 8
    num_partitions_per_core = 16

    src_buffer = np.random.random_sample([128, 512, 4]).astype(np.float32) * 100
    index_per_core = np.random.randint(low=0, high=512, size=(16, 4), dtype=np.uint16)
    # replicate 8 times for 8 GpSimd cores
    index = np.tile(index_per_core, (num_gpsimd_cores, 1))
    num_elem_per_idx = 4
    # scaled indices; not used by the NumPy reference below
    index_hw = index * num_elem_per_idx
    num_valid_indices = 64
    output_shape = (128, 4, 16, 4)

    # integer division keeps the counts usable as array dimensions
    num_active_cores = index.shape[0] // num_partitions_per_core
    num_valid_indices = num_valid_indices if num_valid_indices \
        else index.size // num_active_cores
    output_np = np.empty(shape=(128, num_valid_indices, num_elem_per_idx),
                         dtype=src_buffer.dtype)

    for i_core in range(num_gpsimd_cores):
        start_par = i_core * num_partitions_per_core
        end_par = (i_core + 1) * num_partitions_per_core
        # order='F' flattens along the partition dimension first
        indices_1d = index[start_par:end_par].flatten(order='F')[0: num_valid_indices]
        output_np[start_par:end_par, :, :] = np.take(
            src_buffer[start_par:end_par], indices_1d, axis=1)

    output_np = output_np.reshape(output_shape)
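To make the flattening order and the effect of num_valid_indices easier to see, here is a small NumPy sketch on a deliberately tiny index slice. The values and the (16, 2) shape are made up for illustration; only the flatten(order="F") convention, the np.tile replication, and the upper bound on num_valid_indices are taken from the description above.

```python
import numpy as np

PARTITIONS_PER_CORE = 16
NUM_GPSIMD_CORES = 8

# Hypothetical index slice for one core: 16 partitions x 2 indices each,
# filled so that index_slice[p, j] == 2 * p + j.
index_slice = np.arange(32, dtype=np.uint16).reshape(PARTITIONS_PER_CORE, 2)

# order="F" varies the partition axis fastest, so all 16 partitions of
# column 0 come first, followed by all 16 partitions of column 1.
indices_1d = index_slice.flatten(order="F")

# With num_valid_indices set, only the leading entries are used.
num_valid_indices = 20
valid_indices = indices_1d[:num_valid_indices]

# Common case: every core gathers with the same indices, so the
# 16-partition slice is replicated to all 128 partitions.
index = np.tile(index_slice, (NUM_GPSIMD_CORES, 1))

# Upper bound on num_valid_indices per core, as stated in the text.
max_valid_per_core = index.size // (index.shape[0] // PARTITIONS_PER_CORE)
```

In this sketch the 20 valid indices are the 16 entries of column 0 followed by the first 4 entries of column 1, which is the case where the valid count per core is not a multiple of 16.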
local_gather preserves the input data type of src_buffer in the gather output. Therefore, no data type casting is allowed in this API. The indices in the index tile must be of type uint16.

This API has three tile size constraints [subject to future relaxation]:
- The partition axis size of src_buffer must match that of index and must be a multiple of 16. In other words, src_buffer.shape[0] == index.shape[0] and src_buffer.shape[0] % 16 == 0.
- The number of contiguous elements to gather per index per partition, num_elem_per_idx, must be one of the following values: [1, 2, 4, 8, 16, 32].
- The number of indices for gather per core must be less than or equal to 4096.
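The constraints listed above can be checked on the host before calling the instruction. The following is a minimal sketch, not an NKI API: check_local_gather_args is a hypothetical helper, and it assumes that the "number of indices per core" limit applies to the effective num_valid_indices (defaulting to all indices in a 16-partition slice).

```python
import math

import numpy as np

PARTITIONS_PER_CORE = 16
ALLOWED_ELEMS_PER_IDX = (1, 2, 4, 8, 16, 32)
MAX_INDICES_PER_CORE = 4096


def check_local_gather_args(src_shape, index_shape, index_dtype,
                            num_elem_per_idx=1, num_valid_indices=None):
    """Return a list of violated constraints (empty if none were found).

    Hypothetical host-side helper; it only mirrors the documented rules.
    """
    problems = []
    if np.dtype(index_dtype) != np.dtype(np.uint16):
        problems.append("index dtype must be uint16")
    if src_shape[0] != index_shape[0]:
        problems.append("src_buffer and index partition sizes differ")
    if src_shape[0] % PARTITIONS_PER_CORE != 0:
        problems.append("partition size must be a multiple of 16")
    if num_elem_per_idx not in ALLOWED_ELEMS_PER_IDX:
        problems.append("num_elem_per_idx must be one of 1, 2, 4, 8, 16, 32")

    # The per-core index count is only well defined when the index tile
    # spans a positive multiple of 16 partitions.
    if index_shape[0] > 0 and index_shape[0] % PARTITIONS_PER_CORE == 0:
        active_cores = index_shape[0] // PARTITIONS_PER_CORE
        per_core = math.prod(index_shape) // active_cores
        n_valid = per_core if num_valid_indices is None else num_valid_indices
        if n_valid > per_core:
            problems.append("num_valid_indices exceeds indices per core")
        if n_valid > MAX_INDICES_PER_CORE:
            problems.append("more than 4096 indices per core")
    return problems
```

For instance, check_local_gather_args((128, 512, 4), (128, 4), np.uint16, 4, 64) reports no problems for the shapes used elsewhere on this page.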
Estimated instruction cost:

150 + (num_valid_indices * num_elem_per_idx)/C GpSimd Engine cycles, where C can be calculated using ((28 + t * num_elem_per_idx)/(t * num_elem_per_idx)) / min(4/dtype_size, num_elem_per_idx). dtype_size is the size of src_buffer.dtype in bytes. Currently, t is a constant 4, but it is subject to change in future software implementations.
- Parameters:
src_buffer – an input tile for gathering.
index – an input tile with indices used for gathering.
num_elem_per_idx – an optional integer value to read multiple contiguous elements per index per partition; default is 1.
num_valid_indices – an optional integer value to specify the number of valid indices per GpSimd core; default is index.size / (index.shape[0] / 16).
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
- Returns:
an output tile of the gathered data
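The estimated instruction cost given earlier on this page can be evaluated with a few lines of Python. This is a literal transcription of that formula under the stated t = 4; the function name estimate_local_gather_cycles is hypothetical, and the result is only the documented estimate, not a measurement.

```python
def estimate_local_gather_cycles(num_valid_indices, num_elem_per_idx,
                                 dtype_size, t=4):
    """Evaluate the documented cost estimate in GpSimd Engine cycles.

    dtype_size is the size of src_buffer.dtype in bytes. t is currently
    documented as the constant 4 and may change in future software.
    """
    c = ((28 + t * num_elem_per_idx) / (t * num_elem_per_idx)) \
        / min(4 / dtype_size, num_elem_per_idx)
    return 150 + (num_valid_indices * num_elem_per_idx) / c


# float32 source (4 bytes), 64 valid indices per core, 4 elements per index
cycles = estimate_local_gather_cycles(64, 4, dtype_size=4)
```

Evaluating the estimate for several num_elem_per_idx values can help when choosing how many contiguous elements to gather per index.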
Example:
    import neuronxcc.nki.isa as nisa
    import neuronxcc.nki.language as nl
    from neuronxcc.nki.typing import tensor

    ##################################################################
    # Example 1: gather src_buffer using index
    # Gather input: src_buffer_tile with shape (128, 512, 4)
    # Gather indices: index_tile with shape (128, 4)
    # We use num_valid_indices indices per core, and read num_elem_per_idx
    # contiguous elements per partition.
    ##################################################################
    src_buffer_tile: tensor[128, 512, 4] = nl.load(src_buffer)
    index_tile: tensor[128, 4] = nl.load(index)
    output_tile: tensor[128, 4, 16, 4] = nisa.local_gather(
        src_buffer_tile, index_tile, num_elem_per_idx, num_valid_indices)

    nl.store(output, output_tile)
Click here to download the full NKI code example with equivalent numpy implementation.