nki.isa.quantize_mx

nki.isa.quantize_mx(dst, src, dst_scale, name=None)

Quantize FP16/BF16 data to MXFP8 tensors (both data and scales) using Vector Engine.

Note

Available only on NeuronCore-v4 and beyond.

The resulting MXFP8 tensors, dst and dst_scale, are as defined in the OCP Microscaling standard. This instruction calculates the required scale for each group of 32 values in src, divides the values in each group by that group's scale, and casts the results to the target MXFP8 data type. The output layout is suitable for direct consumption by the nisa.nc_matmul_mx API running on Tensor Engine.
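
The per-group scale calculation described above can be sketched in plain Python. This is a host-side illustration only, not the hardware implementation: it assumes the power-of-two shared-scale rule from the OCP Microscaling specification (shared exponent = floor(log2(max |v|)) minus the largest exponent of the element type), and the helper name mx_scale_for_group and the handling of an all-zero group are hypothetical choices made for this example. The final cast to FP8 is left out.

```python
import math

# Largest power-of-two exponent of each FP8 element type, as listed in the
# OCP Microscaling specification (E4M3 -> 8, E5M2 -> 15).
ELEM_EMAX = {"float8_e4m3fn": 8, "float8_e5m2": 15}
E8M0_BIAS = 127  # bias of the 8-bit power-of-two scale format


def mx_scale_for_group(values, elem_dtype="float8_e4m3fn"):
    """Return (biased_scale_byte, scaled_values) for one 32-value group.

    Hypothetical helper for illustration; not part of NKI.
    """
    assert len(values) == 32, "an MX group holds exactly 32 values"
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        # Assumption: an all-zero group gets the smallest scale, 2**-127.
        shared_exp = -E8M0_BIAS
    else:
        # math.frexp gives amax = m * 2**e with 0.5 <= m < 1,
        # so floor(log2(amax)) equals e - 1.
        shared_exp = (math.frexp(amax)[1] - 1) - ELEM_EMAX[elem_dtype]
        shared_exp = max(-E8M0_BIAS, min(E8M0_BIAS, shared_exp))
    scale = 2.0 ** shared_exp
    # The scaled values would next be cast to FP8; that step is omitted.
    return shared_exp + E8M0_BIAS, [v / scale for v in values]
```

Because the scale is a power of two, dividing by it only shifts exponents, so the scaled values lose no precision before the FP8 cast.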

Memory types.

All input src and output tiles (dst and dst_scale) must be in SBUF.

Data types.

The input src tile must be float16 or bfloat16. The output dst tile must be float8_e5m2_x4 or float8_e4m3fn_x4 (4-packed FP8 data types). The dst_scale tile must be uint8.

The 4-packed data types (float8_e5m2_x4/float8_e4m3fn_x4) are 32-bit data types that pack four 8-bit float8_e5m2/float8_e4m3fn values.
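
To make the 4-packed idea concrete, the sketch below packs four 8-bit FP8 bit patterns into one 32-bit word and unpacks them again. The function names are hypothetical, and the byte order (element 0 in the least significant byte) is an assumption for illustration; the text above does not state how the hardware orders the four values.

```python
def pack_fp8_x4(b0, b1, b2, b3):
    """Pack four FP8 bit patterns (0..255 each) into one 32-bit word.

    Assumption: element 0 occupies the least significant byte.
    """
    for b in (b0, b1, b2, b3):
        assert 0 <= b <= 0xFF, "each FP8 value is exactly one byte"
    return b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)


def unpack_fp8_x4(word):
    """Split a 32-bit word back into its four FP8 bit patterns."""
    return tuple((word >> (8 * i)) & 0xFF for i in range(4))


# 0x38, 0x40, 0x44, 0x48 are the E4M3 encodings of 1.0, 2.0, 3.0, 4.0.
packed = pack_fp8_x4(0x38, 0x40, 0x44, 0x48)
```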

Layout.

The quantization operates on groups of 32 elements from the input src tile, where each group consists of 8 partitions × 4 elements per partition. For each 32-element group, the instruction produces:

  • Quantized FP8 data in dst

  • One shared scale value in dst_scale per group
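
The 8 partitions × 4 elements grouping can be expressed as simple index arithmetic. The helpers below are hypothetical and only model the grouping described above: they map an element position in src to its group, and list the 32 positions that share one scale.

```python
GROUP_PARTITIONS = 8  # partitions spanned by one group
GROUP_FREE = 4        # elements per partition within one group


def group_of(partition, free_idx):
    """Return the (group_row, group_col) that an src element belongs to."""
    return partition // GROUP_PARTITIONS, free_idx // GROUP_FREE


def group_members(group_row, group_col):
    """List the 32 (partition, free_idx) positions of one group."""
    return [
        (group_row * GROUP_PARTITIONS + p, group_col * GROUP_FREE + f)
        for p in range(GROUP_PARTITIONS)
        for f in range(GROUP_FREE)
    ]
```

For example, elements (0, 0) and (7, 3) fall in the same group, while (8, 4) starts the next group along both dimensions.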

Logically, dst should have the same shape as src if dst is interpreted as a pure FP8 data type. However, in NKI, dst uses a custom 4-packed data type that packs four contiguous FP8 elements into a single float8_e5m2_x4/float8_e4m3fn_x4 element. Therefore, dst has one quarter of the element count per partition compared to that of src.

Logically, dst_scale should have 1/32 the element count of src due to the microscaling group size of 32. Physically, the dst_scale tensor follows a special SBUF quadrant (32 partitions) distribution pattern, where scale values are distributed across multiple SBUF quadrants while maintaining the same partition offset in each quadrant. Within each SBUF quadrant, a 32-partition slice of the src tile produces 32//8 = 4 partitions' worth of scales, where the 8 comes from each group consisting of 8 partitions × 4 elements per partition. The number of scales per partition is 1/4 of the free dimension size of the src tile. Scales for different SBUF quadrants are produced in parallel, with the scales written to the first (or second) 8 partitions of each SBUF quadrant. In other words, dst_scale must be placed in the first 16 partitions of each SBUF quadrant. The dst_scale tile declaration must always occupy a multiple of 32 partitions, even though nisa.quantize_mx cannot fill all of those partitions with scale values.
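
One way to picture the quadrant distribution is to list which partitions of dst_scale would receive scale values. The sketch below is an interpretation of the paragraph above, not a statement of the hardware behavior: the function name and the offset parameter (the partition offset within each quadrant, assumed to be shared by all quadrants and to keep the scales within the first 16 partitions) are hypothetical.

```python
QUADRANT = 32            # partitions per SBUF quadrant
GROUP_PARTITIONS = 8     # partitions spanned by one quantization group


def scale_partitions(src_partitions, offset=0):
    """List the dst_scale partitions that would hold scale values.

    Assumption: every quadrant uses the same partition offset, and the
    scales of one quadrant occupy consecutive partitions.
    """
    assert src_partitions % QUADRANT == 0 and 0 < src_partitions <= 128
    per_quadrant = QUADRANT // GROUP_PARTITIONS  # 4 partitions of scales
    assert 0 <= offset and offset + per_quadrant <= 16
    return [
        q * QUADRANT + offset + i
        for q in range(src_partitions // QUADRANT)
        for i in range(per_quadrant)
    ]
```

Under these assumptions, a 128-partition src tile yields scales in partitions 0-3, 32-35, 64-67, and 96-99, which is why the declared dst_scale tile spans far more partitions than are actually written.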

Tile size.

  • The partition dimension size of src must be a multiple of 32 and must not exceed 128.

  • The free dimension size of src must be a multiple of 4 and must not exceed the physical size of each SBUF partition.

  • The dst tile has the same partition dimension size as src but a free dimension size that is 1/4 of src free dimension size due to the special 4-packed FP8 data types.

  • The dst_scale tile partition dimension depends on whether src spans multiple SBUF quadrants.
    • If src occupies only 32 partitions, dst_scale will occupy 4 partitions.

    • Otherwise, dst_scale will occupy the same number of partitions as src.
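
The tile size rules above can be collected into a small shape calculator. This is a hypothetical helper for planning tile declarations, not an NKI function. It takes the src shape in elements and returns the dst shape (in 4-packed elements) and the dst_scale shape, and it does not check the physical SBUF partition size because that limit is not given here.

```python
def quantize_mx_shapes(src_partitions, src_free):
    """Return (dst_shape, dst_scale_shape) for a given src tile shape.

    Hypothetical helper that encodes the tile size rules listed above.
    """
    if src_partitions % 32 != 0 or not 0 < src_partitions <= 128:
        raise ValueError("src partition size must be a multiple of 32, at most 128")
    if src_free <= 0 or src_free % 4 != 0:
        raise ValueError("src free size must be a positive multiple of 4")

    # dst keeps the partition size; four FP8 values pack into one element.
    dst_shape = (src_partitions, src_free // 4)

    # A single-quadrant src yields 4 partitions of scales; otherwise
    # dst_scale spans the same number of partitions as src.
    scale_partitions = 4 if src_partitions == 32 else src_partitions
    # One scale per 4 free-dimension elements of src.
    dst_scale_shape = (scale_partitions, src_free // 4)
    return dst_shape, dst_scale_shape
```

For instance, a 32 × 512 src tile maps to a 32 × 128 dst tile and a 4 × 128 dst_scale tile, while a 128 × 64 src tile maps to 128 × 16 for both outputs.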

Parameters:
  • dst – the quantized MXFP8 output tile

  • src – the input FP16/BF16 tile to be quantized

  • dst_scale – the output scale tile