This document is relevant for: Trn2, Trn3

nki.isa.tensor_reduce

nki.isa.tensor_reduce(dst, op, data, axis, negate=False, keepdims=False, name=None)

Apply a reduction operation to the free axes of an input data tile using Vector Engine.

The reduction operator is specified in the op input field (see Supported Math Operators for NKI ISA for a list of supported reduction operators). nisa.tensor_reduce supports two types of reduction operators: 1) bitvec operators (e.g., bitwise_and, bitwise_or) and 2) arithmetic operators (e.g., add, subtract, multiply).

The reduction axes are specified in the axis field as an int or a tuple/list of ints indicating which dimensions to reduce. The reduction axes must be the trailing free dimension(s) of the tile: a contiguous run that ends at the final dimension. Axis 0 (the partition axis) cannot be reduced.

For example, given a 4D tile (P, D1, D2, D3):

  • axis=(3,) reduces only D3

  • axis=(2, 3) reduces D2 and D3

  • axis=(1, 2, 3) reduces D1, D2, and D3
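The axis rule above can be sketched as a small pure-Python check. The helper name normalize_reduce_axis is hypothetical and not part of NKI. It assumes that axes are non-negative indices and that their order inside the tuple does not matter; neither point is stated in this page.

```python
def normalize_reduce_axis(rank, axis):
    """Return the reduction axes as a sorted tuple, or raise ValueError.

    Hypothetical helper, not an NKI API. `rank` counts the partition
    axis, so a (P, D1, D2, D3) tile has rank 4.
    """
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    if not axes:
        raise ValueError("axis must name at least one dimension")
    # Assumption: order inside the tuple is irrelevant, so sort first.
    axes = tuple(sorted(axes))
    if len(set(axes)) != len(axes):
        raise ValueError("axis contains duplicates")
    if axes[0] < 1:
        # Axis 0 is the partition axis; negative indices are not modeled.
        raise ValueError("axis 0 (partition axis) cannot be reduced")
    # Valid axes are exactly the trailing run k, k+1, ..., rank-1.
    expected = tuple(range(rank - len(axes), rank))
    if axes != expected:
        raise ValueError(f"axes {axes} are not the trailing free dims {expected}")
    return axes


# The three valid choices for a 4D tile (P, D1, D2, D3):
print(normalize_reduce_axis(4, 3))          # (3,)
print(normalize_reduce_axis(4, (2, 3)))     # (2, 3)
print(normalize_reduce_axis(4, (1, 2, 3)))  # (1, 2, 3)
```

Under this rule, axis=(1, 3) and axis=2 are rejected for a 4D tile because neither is a contiguous run ending at the final dimension.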

When the reduction op is an arithmetic operator, the instruction can also multiply the reduction result by -1.0 before writing it into the output tile, at no additional performance cost. This behavior is controlled by the negate input field.
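As a rough illustration of the negate semantics, the following pure-Python reference model reduces the single free axis of a 2D tile stored as a list of rows, one row per partition. The function reduce_free_axis is a hypothetical stand-in for the instruction and only mirrors the math, not the hardware behavior or data types.

```python
from functools import reduce
import operator


def reduce_free_axis(tile, op, negate=False):
    """Reduce each partition row of a 2D tile with the binary function `op`.

    Hypothetical reference model: `tile` is a list of rows, one row per
    partition. Each partition yields one element, optionally negated.
    """
    out = []
    for row in tile:
        value = reduce(op, row)
        out.append([-value if negate else value])
    return out


tile = [[1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0]]

print(reduce_free_axis(tile, operator.add))               # [[6.0], [15.0]]
print(reduce_free_axis(tile, operator.add, negate=True))  # [[-6.0], [-15.0]]
print(reduce_free_axis(tile, operator.mul, negate=True))  # [[-6.0], [-120.0]]
```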

Memory types.

Both the input data and dst tiles can be in SBUF or PSUM.

Data types.

For bitvec operators, the input/output data types must be integer types, and Vector Engine treats all input elements as bit patterns without any data type casting. For arithmetic operators, the input/output data types can be any supported NKI data type, but the engine automatically casts the input data to float32 and performs the reduction in float32 math. The float32 reduction result is then cast to the data type of dst.
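One practical consequence of the float32 rule is that an arithmetic reduction over integer data may not be exact once values exceed 2^24, whereas a bitvec reduction works on the raw bits. The sketch below models this in pure Python. The helper names are hypothetical, and the left-to-right accumulation order is an assumption: this page does not state the order the engine uses, so actual hardware results may differ.

```python
import operator
import struct
from functools import reduce


def to_float32(x):
    """Round a Python number to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", float(x)))[0]


def add_reduce_via_float32(row):
    """Model an add reduction carried out in float32 math.

    Assumption: sequential left-to-right accumulation, rounding to
    float32 after every add. The real accumulation order is not
    documented in this page.
    """
    acc = 0.0
    for x in row:
        acc = to_float32(acc + to_float32(x))
    return acc


row = [16777216, 1, 1]                   # 2**24, 1, 1
print(sum(row))                          # 16777218 (exact integer sum)
print(int(add_reduce_via_float32(row)))  # 16777216 in this model

# Bitvec reductions operate on the bit patterns, with no casting.
print(bin(reduce(operator.and_, [0b1100, 0b1010, 0b1110])))  # 0b1000
```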

Layout.

nisa.tensor_reduce only supports free axes reduction. Therefore, the partition dimension of the input data is considered the parallel compute dimension. To perform a partition axis reduction, we can either:

  1. invoke a nisa.nc_transpose instruction on the input tile and then this nisa.tensor_reduce on the transposed tile, or

  2. invoke nisa.nc_matmul instructions to multiply a nl.ones([128, 1], dtype=data.dtype) tile as a stationary tensor with the input tile as a moving tensor. See more discussion on Tensor Engine alternative usage in the Trainium architecture guide.
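The math behind the second option can be checked without hardware. The sketch below assumes that the matmul contracts over the partition dimension, computing stationary^T @ moving for a K x M stationary tensor and a K x N moving tensor; that convention is an assumption here and should be confirmed against the nisa.nc_matmul reference. The function matmul_transposed_stationary is a hypothetical pure-Python stand-in, and a small partition count is used instead of 128 for readability.

```python
def matmul_transposed_stationary(stationary, moving):
    """Compute stationary^T @ moving on nested lists.

    stationary is K x M, moving is K x N, and the result is M x N.
    Hypothetical model of a matmul contracting over the partition dim.
    """
    k = len(stationary)
    assert len(moving) == k, "both operands must share the contraction size"
    m = len(stationary[0])
    n = len(moving[0])
    return [
        [sum(stationary[p][i] * moving[p][j] for p in range(k)) for j in range(n)]
        for i in range(m)
    ]


P, F = 4, 3  # partitions and free elements (small for illustration)
data = [[float(p * F + f) for f in range(F)] for p in range(P)]
ones = [[1.0] for _ in range(P)]  # plays the role of the [P, 1] ones tile

partition_sums = matmul_transposed_stationary(ones, data)
print(partition_sums)  # [[18.0, 22.0, 26.0]]
```

Each output element is the sum of one column of data across all partitions, which is the partition-axis add reduction that nisa.tensor_reduce cannot perform directly.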

Tile size.

The partition dimension size of input data and output dst tiles must be the same and must not exceed 128. The number of elements per partition of data must not exceed the physical size of each SBUF partition. The number of elements per partition in dst must be consistent with the axis field. For example, if axis indicates all free dimensions of data are reduced, the number of elements per partition in dst must be 1.

Parameters:
  • dst – output tile of the reduction result

  • op – the reduction operator (see Supported Math Operators for NKI ISA for supported reduction operators)

  • data – the input tile to be reduced

  • axis – int or tuple/list of ints. The axis (or axes) along which to reduce; must be the last contiguous free dimension(s) ending at the final dim. For example, for a 4D tile (P, D1, D2, D3): valid values are (3,), (2, 3), or (1, 2, 3). Axis 0 (partition dim) cannot be reduced.

  • negate – if True, the reduction result is multiplied by -1.0; only applicable when op is an arithmetic operator

  • keepdims – if True, the reduced axes are kept in the result as dimensions of size one, so that the result broadcasts correctly against the input tile.
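To relate axis and keepdims to the size of dst, the following sketch computes the number of dst elements per partition and the keepdims=True shape for a given data shape. Both helper names are hypothetical and not NKI APIs, and the code assumes axis has already been validated as a trailing run of free dimensions.

```python
import math


def dst_elements_per_partition(data_shape, axis):
    """Number of dst elements per partition after reducing `axis`.

    Hypothetical helper: multiplies the free dims that are not reduced.
    Index 0 of data_shape is the partition dimension.
    """
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    kept = [d for i, d in enumerate(data_shape) if i >= 1 and i not in axes]
    return math.prod(kept)  # the empty product is 1


def keepdims_shape(data_shape, axis):
    """Result shape with keepdims=True: reduced axes become size one."""
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    return tuple(1 if i in axes else d for i, d in enumerate(data_shape))


shape = (128, 4, 8, 16)  # (P, D1, D2, D3)
print(dst_elements_per_partition(shape, 3))          # 32
print(dst_elements_per_partition(shape, (2, 3)))     # 4
print(dst_elements_per_partition(shape, (1, 2, 3)))  # 1
print(keepdims_shape(shape, (2, 3)))                 # (128, 4, 1, 1)
```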
