This document is relevant for: Trn2, Trn3
nki.isa.dma_transpose
- nki.isa.dma_transpose(dst, src, axes=None, dge_mode=dge_mode.unknown, oob_mode=oob_mode.error, name=None)
Perform a transpose on input src using the DMA Engine. The permutation of the transpose follows the rules described below:
For 2-d input tile, the permutation will be [1, 0]
For 3-d input tile, the permutation will be [2, 1, 0]
For 4-d input tile, the permutation will be [3, 1, 2, 0]
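The permutation rules above can be modelled on the host with NumPy. This is only a sketch of the index mapping, not the DMA instruction: `DEFAULT_AXES` and `reference_transpose` are hypothetical names introduced here, and the sketch assumes that omitting `axes` selects the permutation for the rank of `src`.

```python
import numpy as np

# Permutations listed above, keyed by the rank of the source tile.
DEFAULT_AXES = {2: (1, 0), 3: (2, 1, 0), 4: (3, 1, 2, 0)}

def reference_transpose(src, axes=None):
    """Host-side NumPy model of the permutation (hypothetical helper)."""
    if axes is None:
        # Assumption: axes=None means "use the rule for this rank".
        axes = DEFAULT_AXES[src.ndim]
    # Axis i of the result corresponds to axis axes[i] of the source.
    return np.transpose(src, axes)

src = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
out = reference_transpose(src)
print(out.shape)  # (5, 3, 4, 2)
```

For the 4-d case only the outermost and innermost axes swap, so `out[d, b, c, a]` equals `src[a, b, c, d]`.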
DMA Direct Transpose Constraints
The only valid dge_mode values are unknown and hwdge. If hwdge, this instruction will be lowered to a Hardware DGE transpose. This has additional restrictions:
- src.shape[0] == 16
- src.shape[-1] % 128 == 0
- src.dtype is 2 bytes
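The hwdge restrictions can be checked on the host before launching a kernel. The following is a minimal sketch in plain Python; `check_hwdge_transpose` is a hypothetical helper, not part of NKI, and it only covers the three shape and dtype rules listed above.

```python
def check_hwdge_transpose(src_shape, src_itemsize):
    """Return the list of violated hwdge restrictions (empty if none).

    src_shape: tuple of ints, shape of the source tile.
    src_itemsize: size of one source element in bytes.
    """
    problems = []
    if src_shape[0] != 16:
        problems.append("src.shape[0] must be 16")
    if src_shape[-1] % 128 != 0:
        problems.append("src.shape[-1] must be a multiple of 128")
    if src_itemsize != 2:
        problems.append("src.dtype must be 2 bytes wide")
    return problems

print(check_hwdge_transpose((16, 256), 2))       # no violations
print(len(check_hwdge_transpose((32, 100), 4)))  # all three rules violated
```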
DMA Indirect Transpose Constraints
The only valid dge_mode values are unknown and swdge. This instruction will be lowered to a Software DGE transpose (dma_gather_transpose). This has additional restrictions:
- When src is 4D: len(src[1]) or len(src[2]) must be 1
- src.shape[-1] <= 128
- src.dtype is 2 bytes
- src tensor must be on HBM
- indices must be 2-d
- indices.shape[0] * indices.shape[1] must be >= src.shape[0]
- src.shape[0] must be divisible by 16
- indices.shape[0] must be in [16, 128] and divisible by 16
- When indices.shape[1] > 1: indices.shape[0] must be exactly 128
- indices.dtype is np.uint32
- indices tensor must be on SBUF
- TRN2+ only
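The shape-related restrictions for the indirect case can be sketched as a host-side check. `check_indirect_transpose` is a hypothetical helper, not an NKI function. It reads the 4D rule as "axis 1 or axis 2 of src has size 1", which is an assumption about the intended meaning, and it does not check memory placement (HBM, SBUF), the indices dtype, or the target hardware.

```python
def check_indirect_transpose(src_shape, indices_shape, src_itemsize):
    """Return the list of violated indirect-transpose restrictions."""
    if len(indices_shape) != 2:
        return ["indices must be 2-d"]
    rows, cols = indices_shape
    problems = []
    # Assumption: the 4D rule means src.shape[1] == 1 or src.shape[2] == 1.
    if len(src_shape) == 4 and src_shape[1] != 1 and src_shape[2] != 1:
        problems.append("4D src needs size 1 on axis 1 or axis 2")
    if src_shape[-1] > 128:
        problems.append("src.shape[-1] must be <= 128")
    if src_itemsize != 2:
        problems.append("src.dtype must be 2 bytes wide")
    if rows * cols < src_shape[0]:
        problems.append("indices must hold at least src.shape[0] entries")
    if src_shape[0] % 16 != 0:
        problems.append("src.shape[0] must be divisible by 16")
    if not (16 <= rows <= 128 and rows % 16 == 0):
        problems.append("indices.shape[0] must be in [16, 128] and divisible by 16")
    if cols > 1 and rows != 128:
        problems.append("indices.shape[0] must be 128 when indices.shape[1] > 1")
    return problems

print(check_indirect_transpose((128, 128), (128, 1), 2))  # no violations
```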
Indirect transpose effectively performs the following operation:
    flat_indices = indices.T.flatten()[:src.shape[0]]
    gathered = src[flat_indices, :]
    dst = gathered.T

Indirect transpose example (vector_offset on src):

    import nki
    import nki.isa as nisa
    import nki.language as nl

    @nki.jit
    def gather_transpose_kernel(src_hbm, idx_hbm):
        P, F = 128, 128
        output = nl.ndarray((F, P), dtype=src_hbm.dtype, buffer=nl.shared_hbm)
        idx_sb = nl.load(idx_hbm)
        dst_sb = nl.ndarray((F, P), dtype=src_hbm.dtype, buffer=nl.sbuf)
        nisa.memset(dst=dst_sb, value=0)
        src_ap = src_hbm.ap(
            pattern=[[F, P], [1, F]],
            vector_offset=idx_sb,
            indirect_dim=0,
        )
        nisa.dma_transpose(dst=dst_sb, src=src_ap, axes=(1, 0))
        nisa.dma_copy(dst=output, src=dst_sb)
        return output
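The gather-then-transpose semantics stated above can be reproduced on the host with NumPy, which may help when validating kernel output. This is a sketch of the reference computation only. `reference_gather_transpose` is a hypothetical helper, and the shapes and the random permutation are illustrative values chosen to satisfy the listed restrictions.

```python
import numpy as np

def reference_gather_transpose(src, indices):
    """NumPy model of the indirect transpose (hypothetical helper)."""
    # Read indices column by column, keeping the first src.shape[0] entries.
    flat_indices = indices.T.flatten()[: src.shape[0]]
    gathered = src[flat_indices, :]  # gather rows of src
    return gathered.T                # transpose the gathered rows

rng = np.random.default_rng(0)
src = rng.integers(0, 100, size=(256, 8)).astype(np.int16)  # 2-byte dtype
perm = rng.permutation(256)
# Lay the permutation out column-major in a (128, 2) uint32 index tile.
indices = perm.astype(np.uint32).reshape(2, 128).T
dst = reference_gather_transpose(src, indices)
print(dst.shape)  # (8, 256)
```

Column `k` of `dst` holds row `perm[k]` of `src`, because `indices.T.flatten()` recovers the permutation in its original order.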
- Parameters:
dst – the destination of transpose, must be a tile in SBUF.
src – the source of transpose, must be a tile in HBM or SBUF, with src.dtype == dst.dtype.
axes – transpose axes where the i-th axis of the transposed tile will correspond to the axes[i] of the source. Supported axes are (1, 0), (2, 1, 0), and (3, 1, 2, 0).
dge_mode – (optional) specify which Descriptor Generation Engine (DGE) mode to use for DMA descriptor generation: nki.isa.dge_mode.none (turn off DGE) or nki.isa.dge_mode.swdge (software DGE) or nki.isa.dge_mode.hwdge (hardware DGE) or nki.isa.dge_mode.unknown (by default, let compiler select the best DGE mode). Hardware based DGE is only supported for NeuronCore-v3 or newer. See Trainium2 arch guide for more information.
oob_mode – (optional) Specifies how to handle runtime out-of-bounds (oob) array indices during indirect access operations. Valid modes are:
- oob_mode.error: (Default) Raises an error when encountering runtime out-of-bounds indices.
- oob_mode.skip: Silently skips any operations involving out-of-bounds indices. Only valid when src uses indirect indexing.