This document is relevant for: Inf2, Trn1, Trn2
nki.isa.tensor_tensor_scan
- nki.isa.tensor_tensor_scan(data0, data1, initial, op0, op1, reverse0=False, reverse1=False, *, dtype=None, mask=None, **kwargs)
Perform a scan operation of two input tiles using Vector Engine.
Mathematically, the tensor_tensor_scan instruction on Vector Engine performs the following computation per partition:
    # Let's assume we work with numpy, and data0 and data1 are 2D
    # (with shape[0] being the partition axis)
    import numpy as np

    result = np.ndarray(data0.shape, dtype=data0.dtype)
    result[:, 0] = op1(op0(data0[:, 0], initial), data1[:, 0])

    for i in range(1, data0.shape[1]):
        result[:, i] = op1(op0(data0[:, i], result[:, i-1]), data1[:, i])
The two input tiles (data0 and data1) must have the same partition axis size and the same number of elements per partition. The third input, initial, can be either a float32 compile-time scalar constant that is broadcast along the partition axis of data0/data1, or a tile with the same partition axis size as data0/data1 and one element per partition.
The two input tiles, data0 and data1, cannot both reside in PSUM. The three legal cases are:
- Both data0 and data1 are in SBUF.
- data0 is in SBUF, while data1 is in PSUM.
- data0 is in PSUM, while data1 is in SBUF.
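The constraints above can be sketched as a few checks in plain NumPy. This is only an illustrative model, not part of NKI: check_scan_inputs is a hypothetical helper, tiles are assumed to be 2D arrays with shape[0] as the partition axis, and the buffer a tile lives in is represented by an assumed string label ("sbuf" or "psum").

```python
import numpy as np


def check_scan_inputs(data0, data1, initial, buf0="sbuf", buf1="sbuf"):
    """Hypothetical helper: validate the documented input constraints.

    buf0/buf1 are assumed string labels ("sbuf" or "psum"), not NKI objects.
    """
    # Same partition axis size and same number of elements per partition.
    if data0.shape != data1.shape:
        raise ValueError("data0 and data1 must have the same shape")
    # data0 and data1 cannot both reside in PSUM.
    if buf0 == "psum" and buf1 == "psum":
        raise ValueError("data0 and data1 cannot both be in PSUM")
    # initial is either a scalar constant or one element per partition.
    if np.isscalar(initial):
        return
    if np.asarray(initial).shape != (data0.shape[0], 1):
        raise ValueError("initial must be a scalar or have shape (P, 1)")


a = np.zeros((128, 512), dtype=np.float32)
b = np.zeros((128, 512), dtype=np.float32)

# Scalar initial, both tiles in SBUF.
check_scan_inputs(a, b, initial=0.0)
# Per-partition initial, data1 in PSUM.
check_scan_inputs(a, b, initial=np.zeros((128, 1), dtype=np.float32), buf1="psum")
```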
The scan operation supported by this API has two programmable math operators, set in the op0 and op1 fields. Both op0 and op1 can be any binary arithmetic operator supported by NKI (see Supported Math Operators for NKI ISA for details). The input operands of op0 can optionally be reversed by setting reverse0 to True (and those of op1 by setting reverse1). Reversing operands is useful for non-commutative operators, such as subtract.
Input/output data types can be any supported NKI data type (see Supported Data Types), but the engine automatically casts input data to float32 and performs the computation in float32 math. The float32 results are cast to the target data type specified in the dtype field before being written to the output tile. If the dtype field is not specified, it defaults to the data type of data0 or data1, whichever has the higher precision.
Estimated instruction cost:
max(MIN_II, 2N) Vector Engine cycles, where N is the number of elements per partition in data0/data1. MIN_II is the minimum instruction initiation interval for small input tiles and is roughly 64 engine cycles.
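As a rough illustration of the cost formula, the sketch below evaluates max(MIN_II, 2N) for a few tile widths. The helper name estimated_scan_cycles is hypothetical, and MIN_II is fixed at 64 here even though the text only gives it as an approximate value.

```python
# Assumption: MIN_II is "roughly 64" in the text; an exact 64 is used for illustration.
MIN_II = 64


def estimated_scan_cycles(n, min_ii=MIN_II):
    """Estimated Vector Engine cycles for n elements per partition."""
    return max(min_ii, 2 * n)


for n in (8, 32, 64, 512, 1024):
    print(f"N={n:5d} -> ~{estimated_scan_cycles(n)} cycles")
```

Under this estimate, any tile with 32 or fewer elements per partition costs the same minimum of about 64 cycles, and the cost grows linearly with N beyond that point.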
- Parameters:
data0 – lhs input operand of the scan operation
data1 – rhs input operand of the scan operation
initial – starting state of the scan; can be an SBUF/PSUM tile with 1 element/partition or a scalar compile-time constant
op0 – a binary arithmetic math operator (see Supported Math Operators for NKI ISA for supported operators)
op1 – a binary arithmetic math operator (see Supported Math Operators for NKI ISA for supported operators)
reverse0 – reverse ordering of inputs to op0; if false, data0 is the lhs of op0; if true, data0 is the rhs of op0
reverse1 – reverse ordering of inputs to op1; if false, data1 is the rhs of op1; if true, data1 is the lhs of op1
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tiles or, if they differ, to whichever input type has the highest precision (see NKI Type Promotion for more information)
- Returns:
an output tile of the scan operation
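To make the reverse0 and reverse1 descriptions concrete, here is a minimal NumPy model of the per-partition recurrence with both flags. It is a sketch for reasoning about operand order only: scan_ref is a hypothetical name, it runs on the host rather than on the Vector Engine, and it always returns float32 instead of applying the dtype rules.

```python
import numpy as np


def scan_ref(data0, data1, initial, op0, op1, reverse0=False, reverse1=False):
    """Hypothetical NumPy model of the scan recurrence (not the NKI API)."""
    d0 = np.asarray(data0, dtype=np.float32)
    d1 = np.asarray(data1, dtype=np.float32)
    # Accept a scalar or a (P, 1) array and turn it into one value per partition.
    init = np.asarray(initial, dtype=np.float32).reshape(-1)
    prev = np.broadcast_to(init, (d0.shape[0],)).copy()
    out = np.empty_like(d0)
    for i in range(d0.shape[1]):
        # reverse0=False: data0 is the lhs of op0; True: data0 is the rhs.
        tmp = op0(prev, d0[:, i]) if reverse0 else op0(d0[:, i], prev)
        # reverse1=False: data1 is the rhs of op1; True: data1 is the lhs.
        prev = op1(d1[:, i], tmp) if reverse1 else op1(tmp, d1[:, i])
        out[:, i] = prev
    return out


x = np.array([[1.0, 2.0, 3.0]])
y = np.array([[10.0, 20.0, 30.0]])

# op0 = subtract: data0[i] - prev by default, prev - data0[i] when reversed.
default0 = scan_ref(x, y, 0.0, np.subtract, np.add)
reversed0 = scan_ref(x, y, 0.0, np.subtract, np.add, reverse0=True)

# op1 = subtract: tmp - data1[i] by default, data1[i] - tmp when reversed.
default1 = scan_ref(x, y, 1.0, np.multiply, np.subtract)
reversed1 = scan_ref(x, y, 1.0, np.multiply, np.subtract, reverse1=True)

print(default0, reversed0)
print(default1, reversed1)
```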
Example:
    import neuronxcc.nki.isa as nisa
    import neuronxcc.nki.language as nl
    import numpy as np

    ##################################################################
    # Example 1: scan two tiles, a and b, of the same
    # shape (128, 1024) using multiply/add and get
    # the scan result in tile c
    ##################################################################
    c = nl.ndarray(shape=(128, 1024), dtype=nl.float32)

    c[:, 0:512] = nisa.tensor_tensor_scan(a[:, 0:512], b[:, 0:512],
                                          initial=0, op0=np.multiply, op1=np.add)

    c[:, 512:1024] = nisa.tensor_tensor_scan(a[:, 512:1024], b[:, 512:1024],
                                             initial=c[:, 511],
                                             op0=np.multiply, op1=np.add)
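Example 1 splits the scan into two calls and passes the last column of the first result as the initial state of the second. The NumPy sketch below checks, on small host arrays, that chaining chunks this way matches a single scan over the full width. It does not call NKI: scan_ref is a hypothetical reference model of the recurrence, and the shapes (4, 16) are chosen only to keep the example small.

```python
import numpy as np


def scan_ref(data0, data1, initial, op0, op1):
    """Hypothetical NumPy model of the scan recurrence (not the NKI API)."""
    init = np.asarray(initial, dtype=np.float32).reshape(-1)
    prev = np.broadcast_to(init, (data0.shape[0],)).copy()
    out = np.empty(data0.shape, dtype=np.float32)
    for i in range(data0.shape[1]):
        prev = op1(op0(data0[:, i], prev), data1[:, i])
        out[:, i] = prev
    return out


rng = np.random.default_rng(0)
a = rng.uniform(-1.0, 1.0, size=(4, 16)).astype(np.float32)
b = rng.uniform(-1.0, 1.0, size=(4, 16)).astype(np.float32)

# One scan over the full width.
full = scan_ref(a, b, 0.0, np.multiply, np.add)

# Two chained scans: the second starts from the last column of the first.
c = np.empty_like(full)
c[:, 0:8] = scan_ref(a[:, 0:8], b[:, 0:8], 0.0, np.multiply, np.add)
c[:, 8:16] = scan_ref(a[:, 8:16], b[:, 8:16], c[:, 7:8], np.multiply, np.add)

print(np.allclose(full, c))
```

With multiply and add, each output element is data0[:, i] * previous + data1[:, i], so carrying the previous chunk's last column forward is all the second call needs to continue the recurrence.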