This document is relevant for: Inf2, Trn1, Trn1n
nki.isa.tensor_tensor_scan
- nki.isa.tensor_tensor_scan(data0, data1, initial, op0, op1, reverse0=False, reverse1=False, dtype=None, mask=None, **kwargs)
Perform a scan operation on two input tiles using the Vector Engine.
Mathematically, the tensor_tensor_scan instruction on the Vector Engine performs the following computation per partition:
# Let's assume we work with numpy, and data0 and data1 are 2D
# (with shape[0] being the partition axis)
import numpy as np
result = np.ndarray(data0.shape, dtype=data0.dtype)
result[:, 0] = op1(op0(data0[:, 0], initial), data1[:, 0])
for i in range(1, data0.shape[1]):
    result[:, i] = op1(op0(data0[:, i], result[:, i-1]), data1[:, i])
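To make the recurrence concrete, the sketch below wraps the pseudocode above in a host-side NumPy function and runs it on small arrays. It is a reference model under stated assumptions, not the NKI API: the name scan_reference is hypothetical, inputs are plain NumPy arrays of shape (partitions, elements), initial may be a Python scalar or a (partitions, 1) array, reverse0/reverse1 are left at their default False, and the default output type is approximated with NumPy's promotion rules. With op0=np.multiply and op1=np.add the scan evaluates result[:, i] = data0[:, i] * result[:, i-1] + data1[:, i]; setting data0 to all ones reduces it to a running sum of data1 seeded with initial.

```python
import numpy as np


def scan_reference(data0, data1, initial, op0, op1, dtype=None):
    """Hypothetical NumPy model of tensor_tensor_scan (reverse0/reverse1 left False).

    data0, data1: arrays of shape (partitions, elements).
    initial: a scalar (broadcast across partitions) or a (partitions, 1) array.
    """
    # Inputs are cast to float32 and the recurrence is evaluated in float32.
    d0 = np.asarray(data0, dtype=np.float32)
    d1 = np.asarray(data1, dtype=np.float32)
    # A scalar is broadcast to every partition; a tile gives one value per partition.
    prev = np.broadcast_to(
        np.asarray(initial, dtype=np.float32).reshape(-1), (d0.shape[0],)
    )
    result = np.empty(d0.shape, dtype=np.float32)
    for i in range(d0.shape[1]):
        # result[:, i] = op1(op0(data0[:, i], previous result), data1[:, i])
        prev = op1(op0(d0[:, i], prev), d1[:, i])
        result[:, i] = prev
    # Output dtype: the requested one, else the higher-precision input type
    # (approximated here with NumPy's promotion rules).
    if dtype is None:
        dtype = np.result_type(np.asarray(data0).dtype, np.asarray(data1).dtype)
    return result.astype(dtype)


# With data0 == 1 everywhere, the multiply/add scan is a running sum of data1
# seeded with initial.
data0 = np.ones((2, 4), dtype=np.float32)
data1 = np.array([[1, 2, 3, 4], [4, 3, 2, 1]], dtype=np.float32)

# Scalar initial, broadcast to both partitions.
print(scan_reference(data0, data1, 0.0, np.multiply, np.add).tolist())
# [[1.0, 3.0, 6.0, 10.0], [4.0, 7.0, 9.0, 10.0]]

# Tile initial with one starting value per partition.
init_tile = np.array([[10.0], [100.0]], dtype=np.float32)
print(scan_reference(data0, data1, init_tile, np.multiply, np.add).tolist())
# [[11.0, 13.0, 16.0, 20.0], [104.0, 107.0, 109.0, 110.0]]

# Float32 results cast to a narrower output type.
print(scan_reference(data0, data1, 0.0, np.multiply, np.add, dtype=np.float16).dtype)
# float16
```

The same two calls also show the two accepted forms of initial: a scalar seeds every partition identically, while a one-element-per-partition tile lets each partition start from its own state.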
The two input tiles (data0 and data1) must have the same partition axis size and the same number of elements per partition. The third input, initial, can be either a float32 compile-time scalar constant, which is broadcast along the partition axis of data0/data1, or a tile with the same partition axis size as data0/data1 and one element per partition. All input and output tiles can be in either SBUF or PSUM.

The scan operation supported by this API has two programmable math operators, given in the op0 and op1 fields. Both op0 and op1 can be any binary arithmetic operator supported by NKI (see Supported Math Operators for details). We can optionally reverse the input operands of op0 by setting reverse0 to True (or those of op1 by setting reverse1 to True). Reversing operands is useful for non-commutative operators, such as subtract.

Input/output data types can be any supported NKI data type (see Supported Data Types), but the engine automatically casts input data to float32 and performs the computation in float32 math. The float32 results are cast to the target data type specified in the dtype field before being written into the output tile. If the dtype field is not specified, it defaults to the data type of data0 or data1, whichever has the higher precision.

Estimated instruction cost: 2N Vector Engine cycles, where N is the number of elements per partition in data0/data1.

- Parameters:
data0 – lhs input operand of the scan operation
data1 – rhs input operand of the scan operation
initial – starting state of the scan; can be an SBUF/PSUM tile with 1 element/partition or a scalar compile-time constant
op0 – a binary arithmetic math operator (see Supported Math Operators for supported operators)
op1 – a binary arithmetic math operator (see Supported Math Operators for supported operators)
reverse0 – reverse ordering of inputs to op0; if false, data0 is the lhs of op0; if true, data0 is the rhs of op0
reverse1 – reverse ordering of inputs to op1; if false, data1 is the rhs of op1; if true, data1 is the lhs of op1
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
dtype – (optional) data type to cast the output type to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tiles, or to whichever input type has the highest precision (see NKI Type Promotion for more information)
- Returns:
an output tile of the scan operation
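The reverse0/reverse1 flags described in the parameter list only swap operand order, so their effect is visible only with a non-commutative operator such as subtract. The sketch below is a hypothetical host-side NumPy model (the function name scan_with_reverse is illustrative, not part of NKI) that applies exactly the lhs/rhs rules listed above, evaluated in float32, on a single partition.

```python
import numpy as np


def scan_with_reverse(data0, data1, initial, op0, op1, reverse0=False, reverse1=False):
    """Hypothetical NumPy model of the reverse0/reverse1 operand ordering."""
    d0 = np.asarray(data0, dtype=np.float32)
    d1 = np.asarray(data1, dtype=np.float32)
    prev = np.broadcast_to(
        np.asarray(initial, dtype=np.float32).reshape(-1), (d0.shape[0],)
    )
    result = np.empty(d0.shape, dtype=np.float32)
    for i in range(d0.shape[1]):
        # reverse0=False: data0 is the lhs of op0; reverse0=True: data0 is the rhs.
        t = op0(prev, d0[:, i]) if reverse0 else op0(d0[:, i], prev)
        # reverse1=False: data1 is the rhs of op1; reverse1=True: data1 is the lhs.
        prev = op1(d1[:, i], t) if reverse1 else op1(t, d1[:, i])
        result[:, i] = prev
    return result


# op0 = subtract: data0 - previous vs. previous - data0.
print(scan_with_reverse([[5, 5, 5]], [[0, 0, 0]], 1.0, np.subtract, np.add).tolist())
# [[4.0, 1.0, 4.0]]
print(scan_with_reverse([[5, 5, 5]], [[0, 0, 0]], 1.0, np.subtract, np.add, reverse0=True).tolist())
# [[-4.0, -9.0, -14.0]]

# op1 = subtract: (data0 + previous) - data1 vs. data1 - (data0 + previous).
print(scan_with_reverse([[0, 0, 0]], [[1, 2, 3]], 10.0, np.add, np.subtract).tolist())
# [[9.0, 7.0, 4.0]]
print(scan_with_reverse([[0, 0, 0]], [[1, 2, 3]], 10.0, np.add, np.subtract, reverse1=True).tolist())
# [[-9.0, 11.0, -8.0]]
```

Without reverse0 the state alternates (4, 1, 4) because each step subtracts the previous result from data0; with reverse0=True the scan instead accumulates a running difference.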
Example:
##################################################################
# Example 1: scan two tiles, a and b, of the same
# shape (128, 1024) using multiply/add and get
# the scan result in tile c
##################################################################
import numpy as np
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl

# a and b are assumed to be (128, 1024) tiles already defined in the kernel
i_p = nl.arange(128)[:, None]
i_f = nl.arange(1024)[None, :]
i_f_tile0 = nl.arange(512)[None, :]
c = nl.ndarray(shape=(128, 1024), dtype=np.float32)
# scan the first 512 elements of each partition, starting from initial = 0
c[i_p, i_f_tile0] = nisa.tensor_tensor_scan(a[i_p, i_f_tile0], b[i_p, i_f_tile0], 0, np.multiply, np.add)
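Example 1 only scans the first 512 columns of each partition. Because each output element depends only on the previous result, a longer row can in principle be processed in chunks by passing the last column of one chunk as the initial tile of the next, since initial accepts a tile with one element per partition. The following sketch is a hypothetical host-side NumPy model rather than NKI code (scan_reference is an illustrative name); it checks that splitting a (128, 1024) multiply/add scan into two 512-column scans this way reproduces the single full-width scan.

```python
import numpy as np


def scan_reference(data0, data1, initial, op0, op1):
    """Hypothetical NumPy model of the documented scan recurrence (float32 math)."""
    d0 = np.asarray(data0, dtype=np.float32)
    d1 = np.asarray(data1, dtype=np.float32)
    # Scalar initial is broadcast; a (partitions, 1) tile gives one value per partition.
    prev = np.broadcast_to(
        np.asarray(initial, dtype=np.float32).reshape(-1), (d0.shape[0],)
    )
    result = np.empty(d0.shape, dtype=np.float32)
    for i in range(d0.shape[1]):
        prev = op1(op0(d0[:, i], prev), d1[:, i])
        result[:, i] = prev
    return result


rng = np.random.default_rng(0)
a = rng.random((128, 1024), dtype=np.float32)
b = rng.random((128, 1024), dtype=np.float32)

# One scan over the full row.
full = scan_reference(a, b, 0.0, np.multiply, np.add)

# Two 512-column scans: the second starts from the last column of the first.
c = np.empty((128, 1024), dtype=np.float32)
c[:, :512] = scan_reference(a[:, :512], b[:, :512], 0.0, np.multiply, np.add)
c[:, 512:] = scan_reference(a[:, 512:], b[:, 512:], c[:, 511:512], np.multiply, np.add)

print(np.array_equal(full, c))  # True
```

By the 2N cost estimate above, the two 512-element scans take 2 × (2 × 512) = 2048 Vector Engine cycles, the same total as a single 1024-element scan (2 × 1024 = 2048).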