This document is relevant for: Inf2, Trn1, Trn2
nki.isa.tensor_reduce
- nki.isa.tensor_reduce(op, data, axis, *, mask=None, dtype=None, negate=False, keepdims=False, **kwargs)
Apply a reduction operation to the free axes of an input data tile using Vector Engine.

The reduction operator is specified in the op input field (see Supported Math Operators for NKI ISA for a list of supported reduction operators). There are two types of reduction operators: 1) bitvec operators (e.g., bitwise_and, bitwise_or) and 2) arithmetic operators (e.g., add, subtract, multiply). For bitvec operators, the input/output data types must be integer types, and Vector Engine treats all input elements as bit patterns without any data type casting. For arithmetic operators, there is no restriction on the input/output data types, but the engine automatically casts the input data to float32 and performs the reduction in float32 math. The float32 reduction results are cast to the target data type specified in the dtype field before being written into the output tile. If the dtype field is not specified, it defaults to the data type of the input tile.

When the reduction op is an arithmetic operator, the instruction can also multiply the reduction results by -1.0 before writing them into the output tile, at no additional performance cost. This behavior is controlled by the negate input field.

The reduction axes are specified in the axis field using a list of integer(s) to indicate axis indices. The reduction axes can contain up to four free axes and must start at the most minor free axis. Since axis 0 is the partition axis in a tile, the reduction axes must contain axis 1 (most-minor). In addition, the reduction axes must be consecutive: e.g., [1, 2, 3, 4] is a legal axis field, but [1, 3, 4] is not.

Since this instruction only supports free axes reduction, the output tile must have the same partition axis size as the input data tile. To perform a partition axis reduction, we can either:
- invoke a nki.isa.nc_transpose instruction on the input tile and then apply this reduce instruction to the transposed tile, or
- invoke nki.isa.nc_matmul instructions to multiply a nki.language.ones([128, 1], dtype=data.dtype) vector with the input tile.
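The two partition-axis options can be checked against each other on the host. The following is a minimal sketch in plain NumPy, not NKI: it models only the arithmetic of "transpose, then reduce the free axis" and "multiply by a ones vector", and it ignores engine behavior, tile size limits, and data type handling. The array shape and the names data, via_transpose, and via_matmul are illustrative assumptions.

```python
import numpy as np

# Stand-in for an input tile: 128 partitions (axis 0) x 64 free elements (axis 1).
rng = np.random.default_rng(0)
data = rng.standard_normal((128, 64)).astype(np.float32)

# Option 1: transpose so the partition axis becomes a free axis,
# then reduce along free axis 1 of the transposed array.
via_transpose = data.T.sum(axis=1)               # shape (64,)

# Option 2: multiply by a ones vector. ones^T @ data contracts over
# the partition axis, which is exactly a sum over axis 0.
ones = np.ones((128, 1), dtype=data.dtype)
via_matmul = (ones.T @ data).reshape(-1)         # shape (64,)

# Both agree with a direct sum over the partition axis (up to float32 rounding).
reference = data.sum(axis=0)
```

Both routes yield one value per free-axis position; which one is preferable on hardware depends on factors this sketch does not model.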
Estimated instruction cost:
Cost (Vector Engine Cycles)    Condition
N/2                            both input and output data types are bfloat16 and the reduction operator is add or maximum
N                              otherwise
where,
- N is the number of elements per partition in data.
- MIN_II is the minimum instruction initiation interval for small input tiles. MIN_II is roughly 64 engine cycles.
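As a rough illustration of the cost table, the helper below encodes the two rows as written. It is a hypothetical function, not part of NKI, and it takes data types and the operator as plain strings purely for readability. The text above does not state how MIN_II enters the estimate, so this sketch leaves it out.

```python
def estimate_tensor_reduce_cycles(n_per_partition, in_dtype, out_dtype, op):
    """Hypothetical estimate of Vector Engine cycles from the cost table.

    n_per_partition: number of elements per partition in the input tile (N).
    in_dtype, out_dtype, op: plain strings, e.g. "bfloat16", "add".
    """
    # Row 1: both data types are bfloat16 and the operator is add or maximum.
    fast_path = (
        in_dtype == "bfloat16"
        and out_dtype == "bfloat16"
        and op in ("add", "maximum")
    )
    # Row 2 ("otherwise") costs N cycles.
    return n_per_partition / 2 if fast_path else n_per_partition


cycles_bf16_add = estimate_tensor_reduce_cycles(512, "bfloat16", "bfloat16", "add")
cycles_fp32_add = estimate_tensor_reduce_cycles(512, "float32", "float32", "add")
```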
- Parameters:
op – the reduction operator (see Supported Math Operators for NKI ISA for supported reduction operators)
data – the input tile to be reduced
axis – int or tuple/list of ints. The axis (or axes) along which to operate; must be free dimensions, not partition dimension (0); can only be the last contiguous dim(s) of the tile: [1], [1,2], [1,2,3], [1,2,3,4]
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
dtype – (optional) data type to cast the output type to (see Supported Data Types for more information); if not specified, it will default to be the same as the data type of the input tile.
negate – if True, reduction result is multiplied by -1.0; only applicable when op is an arithmetic operator
keepdims – If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
- Returns:
output tile of the reduction result
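To make the interaction of axis, dtype, negate, and keepdims concrete, here is a host-side NumPy model of the behavior described above for arithmetic operators. It is a sketch under stated assumptions, not the NKI implementation: tensor_reduce_ref is a hypothetical name, op is assumed to be a NumPy ufunc such as np.add or np.maximum, and masking and bitvec operators are not modeled.

```python
import numpy as np

def tensor_reduce_ref(op, data, axis, dtype=None, negate=False, keepdims=False):
    """NumPy model of the documented arithmetic-operator behavior (not the NKI API)."""
    axes = [axis] if isinstance(axis, int) else sorted(axis)
    # Documented rule: free axes only, starting at axis 1, consecutive, at most four.
    if not 1 <= len(axes) <= 4 or axes != list(range(1, len(axes) + 1)):
        raise ValueError(f"illegal axis field: {axes}")
    # Output data type defaults to the input data type.
    out_dtype = data.dtype if dtype is None else dtype
    # Arithmetic operators: cast input to float32 and reduce in float32.
    acc = op.reduce(data.astype(np.float32), axis=tuple(axes), keepdims=keepdims)
    if negate:
        acc = acc * np.float32(-1.0)
    # Cast the float32 result to the target data type.
    return np.asarray(acc).astype(out_dtype)


a = np.arange(24, dtype=np.int32).reshape(2, 3, 4)
b = tensor_reduce_ref(np.add, a, axis=[1, 2], negate=True, keepdims=True)
# b has shape (2, 1, 1), dtype int32, and values -66 and -210
```

With keepdims=True the reduced axes stay as size-one dimensions, so b broadcasts against a; an axis field such as [1, 3] raises ValueError in this model because the axes are not consecutive.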
Example:
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np
...

##################################################################
# Example 1: reduce add tile a of shape (128, 512)
# in the free dimension and return
# reduction result in tile b of shape (128, 1)
##################################################################
i_p_a = nl.arange(128)[:, None]
i_f_a = nl.arange(512)[None, :]
b = nisa.tensor_reduce(np.add, a[i_p_a, i_f_a], axis=[1])