nki.isa.nc_matmul_mx#

nki.isa.nc_matmul_mx(dst, stationary, moving, stationary_scale, moving_scale, tile_position=None, tile_size=None, psum_accumulate_flag=3, name=None)[source]#

Compute matrix multiplication of MXFP8 quantized matrices with integrated dequantization using Tensor Engine.

Note

Available only on NeuronCore-v4 and beyond.

The NeuronCore-v4 Tensor Engine supports matrix multiplication of MXFP8 quantized matrices as defined in the OCP Microscaling standard. This instruction performs matrix multiplication between quantized stationary and moving matrices while applying dequantization scales during computation. The micro-scaling group size is 32 elements along the contraction dimension of both stationary and moving tensors. See Trainium3 arch guide for more detailed discussion.
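For reference, the following NumPy sketch illustrates the computation this instruction performs, assuming one E8M0 scale (power-of-two, bias 127, per the OCP MX spec) per 32-element group along the contraction dimension and per output column. The function name, argument shapes, and the unpacked value layout are illustrative only; the actual on-chip scale layout is described in the nisa.quantize_mx API doc.

    import numpy as np

    def mx_matmul_reference(stationary_q, stationary_scale, moving_q, moving_scale):
        """Illustrative (software) semantics of nc_matmul_mx.

        stationary_q: [K, M] quantized values, unpacked from the _x4 dtype
        moving_q:     [K, N] quantized values, unpacked from the _x4 dtype
        *_scale:      [K // 32, M] / [K // 32, N] uint8 E8M0 exponents,
                      one per 32-element group along the contraction dim K
        """
        K = stationary_q.shape[0]
        assert K % 32 == 0, "micro-scaling group size is 32 along the contraction dim"

        def dequantize(q, scale_e8m0):
            # E8M0 encodes a power-of-two scale with bias 127 (OCP MX spec).
            scales = np.exp2(scale_e8m0.astype(np.int32) - 127)
            # Broadcast each per-group scale over its 32 contraction elements.
            return q.astype(np.float32) * np.repeat(scales, 32, axis=0)

        # Contract along K: dst has shape [M, N].
        return dequantize(stationary_q, stationary_scale).T @ dequantize(moving_q, moving_scale)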

Tiling mode.

The NeuronCore Tensor Engine is built upon a systolic array with 128 rows and 128 columns of processing elements (PEs). For nc_matmul_mx, the Tensor Engine supports only row tiling mode, which allows multiple nc_matmul_mx instructions with a stationary partition dimension size smaller than 128 to run in parallel and improve hardware utilization. Row tiling mode slices the 128 PE rows into two 64-row tiles or four 32-row tiles.

The row tile size can be set in the tile_size field as a tuple (row_size, column_size), where column_size must be 128. The stationary tile size must not exceed the chosen tile_size.

A given nc_matmul_mx can pick the exact row tile within the 128x128 systolic array by specifying the starting row in tile_position as a tuple (start_row, start_column), where start_column must be 0. The start_row must be a multiple of row_size specified in tile_size and must not exceed 128.

For example, setting tile_position to (64, 0) and tile_size to (64, 128) means using the bottom half of the systolic array.

Note that tile_position and tile_size must both be set to enable tiling mode. If they are not set, the default is to use the full systolic array, which is equivalent to tile_position=(0, 0) and tile_size=(128, 128). The values in the tile_position and tile_size tuples can be integers or affine expressions.
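As a hedged sketch, the two halves of the PE array can be driven by two independent nc_matmul_mx calls. The tile variables below (dst0, st0, mv0, sc_st0, ...) are hypothetical SBUF/PSUM tiles assumed to be prepared elsewhere, for example via nisa.quantize_mx:

    import neuronxcc.nki.isa as nisa

    # dst0/dst1 are PSUM tiles; st*/mv*/sc_* are SBUF tiles prepared elsewhere (placeholders).

    # Top half of the systolic array: PE rows [0, 64).
    nisa.nc_matmul_mx(dst0, st0, mv0, sc_st0, sc_mv0,
                      tile_position=(0, 0), tile_size=(64, 128))

    # Bottom half of the systolic array: PE rows [64, 128).
    nisa.nc_matmul_mx(dst1, st1, mv1, sc_st1, sc_mv1,
                      tile_position=(64, 0), tile_size=(64, 128))

Because tile_size is (64, 128) here, each call's stationary partition dimension size must not exceed 64.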

Memory types.

The nc_matmul_mx instruction must read its inputs from SBUF and write its output to PSUM. Therefore, the stationary, moving, stationary_scale, and moving_scale tiles must be SBUF tiles, and the dst tile must be a PSUM tile.
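For example, a minimal allocation sketch (assuming the standard nki.language tile allocation API; shapes are illustrative):

    import neuronxcc.nki.language as nl

    # The matmul result must land in a PSUM tile; 128 partitions x 512 float32
    # elements is the largest float32 destination (see the tile size limits below).
    dst = nl.ndarray((128, 512), dtype=nl.float32, buffer=nl.psum)

    # stationary, moving, stationary_scale, and moving_scale must already be
    # SBUF tiles, e.g. produced by nisa.quantize_mx or loaded with nl.load.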

The psum_accumulate_flag controls whether the matmul result overwrites or accumulates on top of the dst PSUM tile. Multiple nisa.nc_matmul_mx instructions accumulating into the same PSUM tile can form an accumulation group before the PSUM tile content is evicted back to SBUF. The bits of psum_accumulate_flag are interpreted as follows:

  • bit[0] of psum_accumulate_flag: if set, indicates this nisa.nc_matmul_mx call is the first instruction in the accumulation group. The matmul result overwrites the existing content of the dst PSUM tile.

  • bit[1] of psum_accumulate_flag: if set, indicates this nisa.nc_matmul_mx call is the last instruction in the accumulation group. The matmul result accumulates onto the existing content of the dst PSUM tile.

  • bit[2] of psum_accumulate_flag: if set, indicates this nisa.nc_matmul_mx call is the first instruction in the accumulation group, but the matmul result accumulates onto the existing content of the dst PSUM tile instead of overwriting it.

nisa.nc_matmul_mx calls that are neither the first nor the last instruction of an accumulation group should not set any bit: psum_accumulate_flag=0.
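For instance, a three-call accumulation group over three contraction chunks could set the flag as follows (a sketch; the per-chunk tiles st_k0, mv_k0, sc_st_k0, ... are hypothetical and assumed to be prepared elsewhere):

    import neuronxcc.nki.isa as nisa

    # First call in the group: overwrite the PSUM tile (bit[0] set).
    nisa.nc_matmul_mx(dst, st_k0, mv_k0, sc_st_k0, sc_mv_k0, psum_accumulate_flag=0b001)

    # Middle call: no bits set, accumulate as part of the ongoing group.
    nisa.nc_matmul_mx(dst, st_k1, mv_k1, sc_st_k1, sc_mv_k1, psum_accumulate_flag=0b000)

    # Last call in the group: accumulate and close the group (bit[1] set).
    nisa.nc_matmul_mx(dst, st_k2, mv_k2, sc_st_k2, sc_mv_k2, psum_accumulate_flag=0b010)

The default psum_accumulate_flag=3 sets both bit[0] and bit[1], marking a call as the first and the last instruction of its group, i.e. a standalone matmul that overwrites dst.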

Data types.

The input stationary and moving tiles must be float8_e5m2_x4, float8_e4m3fn_x4, or float4_e2m1fn_x4 (4-packed quantized data types). The stationary_scale and moving_scale tiles must be uint8. The dst tile can be float32 or bfloat16.

The 4-packed data types (float8_e5m2_x4/float8_e4m3fn_x4/float4_e2m1fn_x4) pack four quantized values into a single element. These packed data types are required because four micro-scaling quantized data values share one scale value and must operate together as a compact group.

Layout.

The contraction dimension of the matrix multiplication is along the partition dimension of stationary and moving tensors and also the x4 dimension within each packed data type element (float8_e5m2_x4, float8_e4m3fn_x4, or float4_e2m1fn_x4).

The free dimension of the stationary tile matches the partition dimension of the output dst tile in size, while the free dimension of the moving tile matches the free dimension of the dst tile in size.

The scale tensors follow a special layout requirement. See more details in nisa.quantize_mx API doc.
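In (partition, free) terms, the layout rules above translate to the following shape bookkeeping (a sketch; the concrete sizes are placeholders):

    # Contraction runs along the partition dimension (and the x4 packing within
    # each packed element); M comes from the stationary free dim, N from the
    # moving free dim.
    K_packed, M, N = 32, 128, 512
    stationary_shape = (K_packed, M)   # partition = contraction, free = M
    moving_shape     = (K_packed, N)   # partition = contraction, free = N
    dst_shape        = (M, N)          # partition = M, free = N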

Tile size.

  • The partition dimension size of stationary and moving must be identical and be a multiple of 32, not exceeding 128.

  • The free dimension size of stationary must be even and not exceed 128.

  • The free dimension size of moving must not exceed 512 when dst is in float32 or 1024 when dst is in bfloat16.

  • The scale tensors have partition dimensions that depend on whether the data tensors span multiple quadrants. See more details in nisa.quantize_mx API doc.
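The size constraints above can be restated as a small checker (a documentation aid only, not part of the NKI API):

    def check_nc_matmul_mx_tile_sizes(stationary_shape, moving_shape, dst_dtype="float32"):
        """Validate (partition, free) sizes against the constraints listed above."""
        k_st, m = stationary_shape
        k_mv, n = moving_shape

        assert k_st == k_mv, "stationary/moving partition (contraction) dims must match"
        assert k_st % 32 == 0 and k_st <= 128, "partition dim: multiple of 32, at most 128"
        assert m % 2 == 0 and m <= 128, "stationary free dim: even, at most 128"

        max_n = 512 if dst_dtype == "float32" else 1024   # bfloat16 dst allows up to 1024
        assert n <= max_n, f"moving free dim must not exceed {max_n} for a {dst_dtype} dst"

    # Example: full-size tiles for a float32 destination.
    check_nc_matmul_mx_tile_sizes((128, 128), (128, 512), dst_dtype="float32")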

Parameters:
  • dst – the matrix multiplication output (PSUM tile)

  • stationary – the stationary quantized matrix (SBUF tile)

  • moving – the moving quantized matrix (SBUF tile)

  • stationary_scale – the dequantization scales for stationary matrix (SBUF tile)

  • moving_scale – the dequantization scales for moving matrix (SBUF tile)

  • tile_position – a 2D tuple (start_row, start_column) to control starting row and column in Tensor Engine tiling mode

  • tile_size – a 2D tuple (row_size, column_size) to control row and column tile sizes in Tensor Engine tiling mode

  • psum_accumulate_flag – controls PSUM near-memory accumulation in the dst tile