This document is relevant for: Trn2, Trn3

nki.isa.dma_compute#

nki.isa.dma_compute(dst, srcs, reduce_op, scales=None, unique_indices=True, name=None)[source]#

Perform math operations using compute logic inside DMA engines with element-wise scaling and reduction.

This instruction leverages the compute capabilities within DMA engines to perform scaled element-wise operations followed by reduction across multiple source tensors. The computation follows the pattern: dst = reduce_op(srcs[0] * scales[0], srcs[1] * scales[1], ...), where each source tensor is first multiplied by its corresponding scale factor, then all scaled results are combined using the specified reduction operation. Currently, only nl.add is supported for reduce_op, and all values in scales must be 1.0 (or scales can be None which defaults to all 1.0).

The DMA engines perform all computations in float32 precision internally. Input tensors are automatically cast from their source data types to float32 before computation, and the final float32 result is cast to the output data type in a pipelined fashion.
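The scale-and-reduce pattern described above (with the current restrictions of `reduce_op=nl.add` and unit scales) can be modeled in plain NumPy. The following is a semantic sketch, not NKI code; the function name is illustrative:

```python
import numpy as np

def dma_compute_reference(srcs, scales=None, out_dtype=np.float16):
    """Reference model of the dma_compute pattern:
    dst = reduce_op(srcs[0] * scales[0], srcs[1] * scales[1], ...)
    with reduce_op fixed to addition, computed internally in float32."""
    if scales is None:
        scales = [1.0] * len(srcs)  # default: all scales are 1.0
    # Inputs are cast to float32 before computation.
    acc = np.zeros_like(srcs[0], dtype=np.float32)
    for src, scale in zip(srcs, scales):
        acc += src.astype(np.float32) * scale
    # The final float32 result is cast to the output data type.
    return acc.astype(out_dtype)

a = np.array([1.0, 2.0], dtype=np.float16)
b = np.array([3.0, 4.0], dtype=np.float16)
print(dma_compute_reference([a, b]))  # [4. 6.]
```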

Read-Modify-Write with vector_offset (scatter and gather).

When one of the source tensors has a vector_offset (indirect indexing), dma_compute performs a read-modify-write (RMW) in one of two modes:

Scatter RMW: dst(HBM)[indices] = dst(HBM)[indices] + src(SB)
  • dst is in HBM with indirect indexing

  • One source matches dst and has vector_offset

  • The other source is data in SBUF

Gather RMW: dst(SB) = dst(SB) + src(HBM)[indices]
  • dst is in SBUF

  • One source is data in HBM with vector_offset

  • The other source matches dst

Both modes require:
  • Exactly 2 source tensors

  • All scales must be 1.0 (or None)

  • unique_indices must be True (non-unique indices not yet supported)
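The two RMW modes can be modeled in NumPy as follows. This is a semantic sketch under the unique-indices assumption; the array names, sizes, and index values are illustrative:

```python
import numpy as np

# Scatter RMW: dst(HBM)[indices] = dst(HBM)[indices] + src(SB)
hbm = np.zeros(8, dtype=np.float32)                       # destination in HBM
sbuf_data = np.array([1.0, 2.0, 3.0], dtype=np.float32)   # source data in SBUF
indices = np.array([0, 3, 5])         # unique_indices=True: no duplicates
hbm[indices] += sbuf_data             # read-modify-write at the indexed slots

# Gather RMW: dst(SB) = dst(SB) + src(HBM)[indices]
sbuf = np.array([10.0, 20.0, 30.0], dtype=np.float32)     # destination in SBUF
hbm_src = np.arange(8, dtype=np.float32)                  # source data in HBM
sbuf += hbm_src[indices]              # gather the indexed elements, accumulate
```

Because unique_indices must be True, plain fancy-index assignment models the scatter correctly; with duplicate indices, NumPy would need `np.add.at` to accumulate all contributions, which is why non-unique indices are called out as unsupported.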

Memory types.

Both the srcs input tensors and the dst output tensor can be in HBM or SBUF. All srcs and dst tensors must have compile-time-known addresses (unless vector_offset is used for indirect access).

Data types.

All input srcs tensors and the output dst tensor can be any supported NKI data types (see Supported Data Types for more information). The DMA engines automatically cast input data types to float32 before performing the scaled reduction computation. The float32 computation results are then cast to the data type of dst in a pipelined fashion.

Layout.

The computation is performed element-wise across all tensors, with the reduction operation applied across the scaled source tensors at each element position.

Tile size.

The element count of dst and of each tensor in srcs must match exactly. srcs may contain at most 16 source tensors.

Parameters:
  • dst – the output tensor to store the computed results

  • srcs – a list of input tensors to be scaled and reduced

  • reduce_op – the reduction operation to apply (currently only nl.add is supported)

  • scales – (optional) a list of scale factors corresponding to each tensor in srcs. Must be all 1.0 if provided. Defaults to None (equivalent to [1.0, 1.0, …]).

  • unique_indices – (optional) Whether scatter indices are unique. Must be True when using vector_offset (non-unique not yet supported). Default: True.
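Putting the pieces together, a minimal kernel that sums two tensors with dma_compute might look like the following NKI-style sketch. This is pseudocode under assumed Neuron SDK imports and is not runnable outside a Neuron environment; the kernel name and shapes are illustrative:

```python
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa

@nki.jit
def sum_two_tensors(a, b):
    # Output tensor in HBM with the same shape and dtype as the inputs.
    out = nl.ndarray(a.shape, dtype=a.dtype, buffer=nl.shared_hbm)
    # out = a * 1.0 + b * 1.0, computed in float32 inside the DMA engines,
    # then cast back to out's dtype. Only nl.add and unit scales are
    # currently supported.
    nisa.dma_compute(dst=out, srcs=[a, b], reduce_op=nl.add,
                     scales=[1.0, 1.0])
    return out
```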
