This document is relevant for: Inf2, Trn1, Trn1n

nki.isa.tensor_tensor_scan

nki.isa.tensor_tensor_scan(data0, data1, initial, op0, op1, reverse0=False, reverse1=False, dtype=None, mask=None, **kwargs)

Perform a scan operation of two input tiles using Vector Engine.

Mathematically, the tensor_tensor_scan instruction on Vector Engine performs the following computation per partition:

# Let's assume we work with numpy, and data0 and data1 are 2D (with shape[0] being the partition axis)
import numpy as np

result = np.ndarray(data0.shape, dtype=data0.dtype)
result[:, 0] = op1(op0(data0[:, 0], initial), data1[:, 0])

for i in range(1, data0.shape[1]):
    result[:, i] = op1(op0(data0[:, i], result[:, i-1]), data1[:, i])

The two input tiles (data0 and data1) must have the same partition axis size and the same number of elements per partition. The third input, initial, can be either a float32 compile-time scalar constant, which is broadcast along the partition axis of data0/data1, or a tile with the same partition axis size as data0/data1 and one element per partition. All input and output tiles can be in either SBUF or PSUM.
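The recurrence and the two forms of initial described above can be modeled in plain NumPy. This is a minimal host-side sketch, not NKI code: scan_reference is a hypothetical helper name, and the float32 conversion mirrors the float32 math described in this document.

```python
import numpy as np

def scan_reference(data0, data1, initial, op0, op1):
    """NumPy model of the per-partition scan (hypothetical helper, not an NKI API)."""
    data0 = np.asarray(data0, dtype=np.float32)
    data1 = np.asarray(data1, dtype=np.float32)
    assert data0.shape == data1.shape, "data0 and data1 must have matching shapes"
    num_partitions, num_elements = data0.shape

    init = np.asarray(initial, dtype=np.float32)
    if init.ndim == 0:
        # Scalar constant: broadcast along the partition axis.
        state = np.full(num_partitions, float(init), dtype=np.float32)
    else:
        # Tile with one element per partition, e.g. shape (num_partitions, 1).
        state = init.reshape(num_partitions)

    result = np.empty_like(data0)
    for i in range(num_elements):
        state = op1(op0(data0[:, i], state), data1[:, i])
        result[:, i] = state
    return result

a = np.ones((2, 4), dtype=np.float32)
b = np.arange(8, dtype=np.float32).reshape(2, 4)

# Scalar initial: with a == 1, multiply/add reduces to a running sum of b.
scalar_init = scan_reference(a, b, 0.0, np.multiply, np.add)
# Per-partition initial: each partition starts from its own state.
tile_init = scan_reference(a, b, np.array([[10.0], [20.0]]), np.multiply, np.add)
```

With these inputs, scalar_init equals the row-wise cumulative sum of b, and tile_init is the same sum offset by 10 and 20 in partitions 0 and 1 respectively.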

The scan operation supported by this API has two programmable math operators in op0 and op1 fields. Both op0 and op1 can be any binary arithmetic operator supported by NKI (see Supported Math Operators for details). We can optionally reverse the input operands of op0 by setting reverse0 to True (or op1 by setting reverse1). Reversing operands is useful for non-commutative operators, such as subtract.
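To make the effect of reverse0 and reverse1 concrete, here is a small NumPy sketch of a single scan step. It assumes the operand ordering given in the parameter descriptions of this document (data0 is the lhs of op0 and data1 is the rhs of op1 unless reversed); scan_step is a hypothetical helper, not part of NKI.

```python
import numpy as np

def scan_step(x0, prev, x1, op0, op1, reverse0=False, reverse1=False):
    """One scan step: x0/x1 are the data0/data1 elements, prev is the previous result."""
    # reverse0 swaps the operands of op0: data0 moves from lhs to rhs.
    tmp = op0(prev, x0) if reverse0 else op0(x0, prev)
    # reverse1 swaps the operands of op1: data1 moves from rhs to lhs.
    return op1(x1, tmp) if reverse1 else op1(tmp, x1)

x0, prev, x1 = np.float32(5.0), np.float32(2.0), np.float32(1.0)

default_order = scan_step(x0, prev, x1, np.subtract, np.add)                    # (5 - 2) + 1
reversed_op0 = scan_step(x0, prev, x1, np.subtract, np.add, reverse0=True)      # (2 - 5) + 1
reversed_op1 = scan_step(x0, prev, x1, np.multiply, np.subtract, reverse1=True) # 1 - (5 * 2)
```

For commutative operators such as add or multiply, the reverse flags do not change the result; they only matter for operators like subtract.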

Input/output data types can be any supported NKI data type (see Supported Data Types), but the engine automatically casts input data to float32 and performs the computation in float32 math. The float32 results are cast to the target data type specified in the dtype field before being written to the output tile. If the dtype field is not specified, it defaults to the data type of data0 or data1, whichever has the higher precision.
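The casting behavior can be sketched in NumPy as follows. This is an illustrative model only: scan_with_cast is a hypothetical helper, and np.promote_types is used as a stand-in for the default output type, which is an assumption and may not match NKI Type Promotion in every case.

```python
import numpy as np

def scan_with_cast(data0, data1, initial, op0, op1, dtype=None):
    """Model: upcast inputs to float32, scan in float32, cast the result once at the end."""
    # Assumption: NumPy promotion approximates "whichever input has the higher precision".
    out_dtype = np.promote_types(data0.dtype, data1.dtype) if dtype is None else np.dtype(dtype)

    d0 = data0.astype(np.float32)
    d1 = data1.astype(np.float32)
    state = np.full(d0.shape[0], initial, dtype=np.float32)
    result32 = np.empty(d0.shape, dtype=np.float32)
    for i in range(d0.shape[1]):
        state = op1(op0(d0[:, i], state), d1[:, i])
        result32[:, i] = state
    return result32.astype(out_dtype)

a = np.ones((1, 4), dtype=np.float16)
b = np.full((1, 4), 0.5, dtype=np.float32)

default_out = scan_with_cast(a, b, 0.0, np.multiply, np.add)                # float32 output
half_out = scan_with_cast(a, b, 0.0, np.multiply, np.add, dtype=np.float16) # float16 output
```

The point of the sketch is that the running state stays in float32 throughout; only the values written to the output are narrowed.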

Estimated instruction cost:

2N Vector Engine cycles, where N is the number of elements per partition in data0/data1.
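As a quick arithmetic check of the estimate above, the following sketch evaluates 2N for a given free-axis size. The function name is hypothetical, and the result is only the stated estimate, not a measured latency.

```python
def estimated_scan_cycles(elements_per_partition: int) -> int:
    """Estimated Vector Engine cycles for one scan: 2N, with N elements per partition."""
    return 2 * elements_per_partition

# A tile with 512 elements per partition is estimated at 2 * 512 cycles.
cycles_512 = estimated_scan_cycles(512)
```

Note that the partition axis size does not appear in the estimate; only the number of elements per partition does.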

Parameters:
  • data0 – lhs input operand of the scan operation

  • data1 – rhs input operand of the scan operation

  • initial – starting state of the scan; can be an SBUF/PSUM tile with 1 element/partition or a scalar compile-time constant

  • op0 – a binary arithmetic math operator (see Supported Math Operators for supported operators)

  • op1 – a binary arithmetic math operator (see Supported Math Operators for supported operators)

  • reverse0 – reverse ordering of inputs to op0; if false, data0 is the lhs of op0; if true, data0 is the rhs of op0

  • reverse1 – reverse ordering of inputs to op1; if false, data1 is the rhs of op1; if true, data1 is the lhs of op1

  • mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)

  • dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tiles or, if they differ, to whichever input type has the higher precision (see NKI Type Promotion for more information)

Returns:

an output tile of the scan operation

Example:

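  ##################################################################
  # Example 1: scan two tiles, a and b, of the same
  # shape (128, 1024) using multiply/add and get
  # the scan result in tile c
  #
  # The snippet assumes nl, nisa and np refer to nki.language,
  # nki.isa and numpy, and that a and b are existing on-chip tiles.
  ##################################################################
  i_p = nl.arange(128)[:, None]
  i_f = nl.arange(1024)[None, :]
  i_f_tile0 = nl.arange(512)[None, :]

  c = nl.ndarray(shape=(128, 1024), dtype=np.float32)

  # This call scans the first 512 elements of each partition (i_f_tile0),
  # starting from initial=0: c[:, i] = a[:, i] * c[:, i-1] + b[:, i]
  c[i_p, i_f_tile0] = nisa.tensor_tensor_scan(a[i_p, i_f_tile0], b[i_p, i_f_tile0], 0,
                                              np.multiply, np.add)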
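The example above covers a 512-element slice of a 1024-element row. Since initial may be a tile with one element per partition, one plausible way to cover the remaining columns is to feed the last column of the first result into a second scan. The NumPy sketch below checks that this chaining reproduces a single scan over the full row; it is a host-side model with a hypothetical scan_reference helper, not NKI code, and the chaining pattern itself is an assumption rather than something this page prescribes.

```python
import numpy as np

def scan_reference(data0, data1, initial, op0, op1):
    """NumPy model of the scan recurrence (hypothetical helper, not an NKI API)."""
    # Accept a scalar or a (num_partitions, 1) tile as the starting state.
    init = np.asarray(initial, dtype=np.float32).reshape(-1)
    state = np.broadcast_to(init, (data0.shape[0],)).copy()
    result = np.empty(data0.shape, dtype=np.float32)
    for i in range(data0.shape[1]):
        state = op1(op0(data0[:, i], state), data1[:, i])
        result[:, i] = state
    return result

rng = np.random.default_rng(0)
a = rng.random((128, 1024), dtype=np.float32)
b = rng.random((128, 1024), dtype=np.float32)

# Single scan over all 1024 elements per partition.
whole = scan_reference(a, b, 0.0, np.multiply, np.add)

# Two 512-element scans; the second starts from the last column of the first.
first = scan_reference(a[:, :512], b[:, :512], 0.0, np.multiply, np.add)
second = scan_reference(a[:, 512:], b[:, 512:], first[:, 511:512], np.multiply, np.add)
chained = np.concatenate([first, second], axis=1)

matches = np.array_equal(whole, chained)
```

Because both paths apply the same float32 operations in the same order, the chained result matches the single scan exactly in this model.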
