This document is relevant for: Trn2, Trn3

nki.isa.dma_transpose

nki.isa.dma_transpose(dst, src, axes=None, priority=None, dge_mode=dge_mode.unknown, oob_mode=oob_mode.error, name=None)

Transpose the input src using the DMA engine.

The transpose permutation is determined by the rank of the input tile:

  1. For a 2-d input tile, the permutation is [1, 0]

  2. For a 3-d input tile, the permutation is [2, 1, 0]

  3. For a 4-d input tile, the permutation is [3, 1, 2, 0]
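The permutation rules above can be modelled on the host with NumPy. This is a minimal sketch for checking expected shapes and values, not the NKI call itself; the names PERMUTATIONS and reference_transpose are hypothetical helpers introduced here for illustration.

```python
import numpy as np

# Permutation documented for each supported input rank.
PERMUTATIONS = {2: (1, 0), 3: (2, 1, 0), 4: (3, 1, 2, 0)}


def reference_transpose(src):
    """Host-side NumPy model of the documented permutation (hypothetical helper)."""
    if src.ndim not in PERMUTATIONS:
        raise ValueError(f"unsupported rank: {src.ndim}")
    return np.transpose(src, PERMUTATIONS[src.ndim])


x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
y = reference_transpose(x)
print(y.shape)  # (5, 3, 4, 2)
```

Note that the 4-d case swaps only the first and last axes; the two middle axes keep their positions.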

DMA Direct Transpose Constraints

The only valid dge_mode values are unknown and hwdge. If hwdge is used, this instruction is lowered to a Hardware DGE transpose, which has additional restrictions:

  1. src.shape[0] == 16

  2. src.shape[-1] % 128 == 0

  3. src.dtype is 2 bytes
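The three Hardware DGE restrictions above can be checked on the host before launching a kernel. The sketch below is a hypothetical helper (check_hwdge_constraints is not part of NKI) that takes a shape tuple and a NumPy dtype and returns the list of violated rules.

```python
import numpy as np


def check_hwdge_constraints(src_shape, src_dtype):
    """Return the violated Hardware DGE restrictions (empty list if none).

    Hypothetical host-side helper; it only encodes the three rules listed above.
    """
    problems = []
    if src_shape[0] != 16:
        problems.append("src.shape[0] must equal 16")
    if src_shape[-1] % 128 != 0:
        problems.append("src.shape[-1] must be a multiple of 128")
    if np.dtype(src_dtype).itemsize != 2:
        problems.append("src.dtype must be 2 bytes wide")
    return problems


print(check_hwdge_constraints((16, 256), np.float16))  # []
print(check_hwdge_constraints((32, 100), np.float32))  # lists three violations
```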

DMA Indirect Transpose Constraints

The only valid dge_mode values are unknown and swdge. This instruction is lowered to a Software DGE transpose (dma_gather_transpose), which has additional restrictions:

  1. When src is 4D: len(src[1]) or len(src[2]) must be 1

  2. src.shape[-1] <= 128

  3. src.dtype is 2 bytes

  4. src tensor must be on HBM

  5. indices must be 2-d

  6. indices.shape[0] * indices.shape[1] must be >= src.shape[0]

  7. src.shape[0] must be divisible by 16

  8. indices.shape[0] must be in [16, 128] and divisible by 16

  9. When indices.shape[1] > 1: indices.shape[0] must be exactly 128

  10. indices.dtype is np.uint32

  11. indices tensor must be on SBUF

  12. TRN2+ only
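The shape and dtype restrictions in the list above can be sketched as a host-side check. The helper name check_swdge_constraints is hypothetical. Rule 1 is interpreted here as "src.shape[1] or src.shape[2] equals 1", which is an assumption about the intended meaning. Memory placement (rules 4 and 11) and the hardware generation (rule 12) cannot be verified from shapes alone and are not checked.

```python
import numpy as np


def check_swdge_constraints(src_shape, src_dtype, idx_shape, idx_dtype):
    """Return violated Software DGE shape/dtype rules (hypothetical helper)."""
    problems = []
    # Rule 1 (assumed reading): one of the two middle dims of a 4D src is 1.
    if len(src_shape) == 4 and 1 not in (src_shape[1], src_shape[2]):
        problems.append("4D src needs src.shape[1] == 1 or src.shape[2] == 1")
    if src_shape[-1] > 128:
        problems.append("src.shape[-1] must be <= 128")
    if np.dtype(src_dtype).itemsize != 2:
        problems.append("src.dtype must be 2 bytes wide")
    if src_shape[0] % 16 != 0:
        problems.append("src.shape[0] must be divisible by 16")
    if np.dtype(idx_dtype) != np.dtype(np.uint32):
        problems.append("indices.dtype must be np.uint32")
    if len(idx_shape) != 2:
        problems.append("indices must be 2-d")
        return problems  # remaining rules need a 2-d index shape
    rows, cols = idx_shape
    if rows * cols < src_shape[0]:
        problems.append("indices must hold at least src.shape[0] entries")
    if not (16 <= rows <= 128) or rows % 16 != 0:
        problems.append("indices.shape[0] must be in [16, 128] and divisible by 16")
    if cols > 1 and rows != 128:
        problems.append("indices.shape[0] must be 128 when indices.shape[1] > 1")
    return problems


print(check_swdge_constraints((128, 128), np.float16, (128, 1), np.uint32))  # []
```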

Indirect transpose effectively performs the following operation:

flat_indices = indices.T.flatten()[:src.shape[0]]
gathered = src[flat_indices, :]
dst = gathered.T
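The gather-then-transpose operation described above can be run on the host with NumPy to produce reference results for a 2-d src. This is a minimal sketch; the function name reference_indirect_transpose and the sample data are assumptions chosen for illustration.

```python
import numpy as np


def reference_indirect_transpose(src, indices):
    """Host-side NumPy model of the documented indirect transpose (2-d src)."""
    # Read the index tensor column by column and keep src.shape[0] entries.
    flat_indices = indices.T.flatten()[: src.shape[0]]
    # Gather the selected rows, then transpose the gathered block.
    gathered = src[flat_indices, :]
    return gathered.T


# 256 rows of 8 two-byte elements; indices of shape (128, 2) select rows in reverse.
src = np.arange(256 * 8, dtype=np.int16).reshape(256, 8)
indices = np.arange(256, dtype=np.uint32)[::-1].reshape(2, 128).T

dst = reference_indirect_transpose(src, indices)
print(dst.shape)  # (8, 256)
```

Because the indices are flattened after a transpose, column 0 of the index tensor is consumed in full before column 1.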

Indirect transpose example with 1D indices (indices.shape=[128, 1]):

import nki
import nki.isa as nisa
import nki.language as nl

@nki.jit
def gather_transpose_kernel(src_hbm, idx_hbm):
    P, F = 128, 128
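    # Output tensor in HBM that receives the transposed result.
    output = nl.ndarray((P, F), dtype=src_hbm.dtype, buffer=nl.shared_hbm)

    # The indices tensor must be on SBUF, so load it from HBM first.
    idx_sb = nl.load(idx_hbm)

    # Destination tile in SBUF, zero-initialized before the transpose.
    dst_sb = nl.ndarray((P, F), dtype=src_hbm.dtype, buffer=nl.sbuf)
    nisa.memset(dst=dst_sb, value=0)

    # Access pattern over src, indexed indirectly along dim 0 through idx_sb.
    src_ap = src_hbm.ap(
        pattern=[[P, F], [1, P]],
        vector_offset=idx_sb,
        indirect_dim=0,
    )
    nisa.dma_transpose(dst=dst_sb, src=src_ap, axes=(1, 0))

    # Copy the SBUF result back to HBM.
    nisa.dma_copy(dst=output, src=dst_sb)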
    return output

Indirect transpose example with 2D indices (indices.shape=[128, N] where N > 1):

@nki.jit
def gather_transpose_2d_kernel(src_hbm, idx_hbm):
    N_COLS = 2  # Number of columns in index tensor
    P = 128  # Partition dimension (max 128)
    F = 128 * N_COLS  # Free dimension: 256

    output = nl.ndarray((P, F), dtype=src_hbm.dtype, buffer=nl.shared_hbm)

    idx_sb = nl.load(idx_hbm)

    dst_sb = nl.ndarray((P, F), dtype=src_hbm.dtype, buffer=nl.sbuf)
    nisa.memset(dst=dst_sb, value=0)

    src_ap = src_hbm.ap(
        pattern=[[P, F], [1, P]],
        vector_offset=idx_sb,
        indirect_dim=0,
    )
    nisa.dma_transpose(dst=dst_sb, src=src_ap, axes=(1, 0))

    nisa.dma_copy(dst=output, src=dst_sb)
    return output

4D indirect transpose example with 2D indices:

@nki.jit
def gather_transpose_4d_kernel(src_hbm, idx_hbm):
    T, d1, d2, d3 = src_hbm.shape
    _, N = idx_hbm.shape
    F = 128 * N

    idx_sb = nl.load(idx_hbm)

    dst_sb = nl.ndarray((d3, d1, d2, F), dtype=src_hbm.dtype, buffer=nl.sbuf)
    nisa.memset(dst=dst_sb, value=0)

    src_ap = src_hbm.ap(
        pattern=[[d1 * d2 * d3, F], [d2 * d3, d1], [d3, d2], [1, d3]],
        vector_offset=idx_sb,
        indirect_dim=0,
    )

    nisa.dma_transpose(dst=dst_sb, src=src_ap, axes=(3, 1, 2, 0))

    output = nl.ndarray((d3, d1, d2, F), dtype=src_hbm.dtype, buffer=nl.shared_hbm)
    nisa.dma_copy(dst=output, src=dst_sb)

    return output

Parameters:
  • dst – the destination of transpose, must be a tile in SBUF.

  • src – the source of transpose, must be a tile in HBM or SBUF. src.dtype must equal dst.dtype.

  • axes – transpose axes, where the i-th axis of the transposed tile corresponds to axis axes[i] of the source. Supported axes are (1, 0), (2, 1, 0), and (3, 1, 2, 0).

  • priority – (optional) DMA quality-of-service priority level from 0 to 3, where a lower value means higher priority (NeuronCore-v4+ only). Currently not supported when DGE is turned off (dge_mode=nki.isa.dge_mode.none).

  • dge_mode – (optional) the Descriptor Generation Engine (DGE) mode to use for DMA descriptor generation: nki.isa.dge_mode.none (turn off DGE), nki.isa.dge_mode.swdge (software DGE), nki.isa.dge_mode.hwdge (hardware DGE), or nki.isa.dge_mode.unknown (default; the compiler selects the best DGE mode). Hardware DGE is only supported on NeuronCore-v3 or newer. See the Trainium2 architecture guide for more information.

  • oob_mode – (optional) specifies how to handle runtime out-of-bounds (oob) array indices during indirect access operations. Valid modes are:

    • oob_mode.error: (Default) Raises an error when encountering runtime out-of-bounds indices.

    • oob_mode.skip: Silently skips any operations involving out-of-bounds indices. Only valid when src uses indirect indexing.
