This document is relevant for: Trn2, Trn3

Quantize MXFP8 Kernel API Reference

Block-wise MXFP8 quantization of BF16 tensors, with a helper that determines if packed scales should be stored for the current tile.

Background

This module provides two entry points for MXFP8 block-wise quantization. should_store_packed_scales determines whether packed scales should be stored for the current tile, and quantize_block_mxfp8_kernel quantizes a BF16 tensor to MXFP8 format.
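As background on what a block-wise scale is, the sketch below computes one shared power-of-two scale per block in plain Python, following the general OCP Microscaling (MX) convention of 32-element blocks and E8M0 scales stored as uint8. It is not the kernel's implementation: the helper names `block_scale_code` and `row_scale_codes` are hypothetical, the handling of all-zero blocks is an assumption, and the kernel's actual block layout, rounding, and scale packing may differ.

```python
import math

# Exponent of the largest normal value in each FP8 element format:
# E4M3 tops out at 448 = 1.75 * 2**8, E5M2 at 57344 = 1.75 * 2**15.
ELEM_EMAX = {"float8_e4m3fn": 8, "float8_e5m2": 15}
E8M0_BIAS = 127   # uint8 scale code = power-of-two exponent + 127
BLOCK_SIZE = 32   # block size from the OCP MX spec (assumed here)


def block_scale_code(block, fp8_dtype):
    """Return the uint8 E8M0 scale code shared by one block of values."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 0  # assumption: all-zero blocks get the smallest scale
    # frexp gives amax = m * 2**e with 0.5 <= m < 1, so floor(log2(amax)) = e - 1.
    _, e = math.frexp(amax)
    shared_exp = (e - 1) - ELEM_EMAX[fp8_dtype]
    # Clamp to the valid E8M0 range; code 255 is reserved for NaN.
    return min(max(shared_exp + E8M0_BIAS, 0), 254)


def row_scale_codes(row, fp8_dtype, block_size=BLOCK_SIZE):
    """Split one row into consecutive blocks and compute a scale per block."""
    if len(row) % block_size != 0:
        raise ValueError("row length must be a multiple of block_size")
    return [
        block_scale_code(row[i:i + block_size], fp8_dtype)
        for i in range(0, len(row), block_size)
    ]


row = [1.0] * 32 + [448.0] + [0.5] * 31
codes = row_scale_codes(row, "float8_e4m3fn")
print(codes)  # prints [119, 127]
```

Each element would then be divided by its block's scale (2 raised to the code minus 127) before being cast to the FP8 element type; that cast is omitted here.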

API Reference

Source code for this kernel API can be found at: quantize_mxfp8.py

should_store_packed_scales

nkilib.experimental.quantize_mxfp8.should_store_packed_scales(tile_k_idx: int, total_num_tiles: int) → bool

Determine if packed scales should be stored for the current tile.

Parameters:
  • tile_k_idx (int) – Current tile index in K dimension (across all tiles including remainder)

  • total_num_tiles (int) – Total number of tiles including remainder tiles

Returns:

True if scales should be stored, False otherwise

Return type:

bool

quantize_block_mxfp8_kernel

nkilib.experimental.quantize_mxfp8.quantize_block_mxfp8_kernel(src_tensor: nl.ndarray, return_fp8_dtype: str, run_with_lnc2: bool = False, enable_scale_packing: bool = True) → tuple[nl.ndarray, nl.ndarray]

Kernel for quantizing BF16 tensor to MXFP8 format with block-wise quantization.

Parameters:
  • src_tensor (nl.ndarray) – [F, K], Input tensor in BF16 format on HBM

  • return_fp8_dtype (str) – FP8 dtype string such as "float8_e4m3fn" or "float8_e5m2"

  • run_with_lnc2 (bool) – Enable LNC2 parallelization along F dimension (default: False)

  • enable_scale_packing (bool) – Enable scale packing optimization (default: True)

Returns:

  • [K // 4, F], Scales in uint8 format on HBM

  • [K // 4, F * INTERLEAVE_FACTOR], Quantized data in FP8 format on HBM

Return type:

tuple[nl.ndarray, nl.ndarray]

Notes:

  • K dimension must be divisible by 512 for MXFP8 quantization

  • F dimension must be divisible by 8 for quantization

  • LNC2 splits work along F dimension; supports uneven splits

Dimensions:

  • F: Feature dimension (rows in input tensor)

  • K: Second dimension (columns in input tensor)
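The shape constraints and output shapes listed above can be checked on the host before launching the kernel. The following is a minimal sketch, not part of nkilib: the helper name `check_quantize_mxfp8_shapes` is hypothetical, and because the value of INTERLEAVE_FACTOR is not stated in this reference, it is passed in as a parameter.

```python
def check_quantize_mxfp8_shapes(f, k, interleave_factor):
    """Validate an [F, K] input shape and return the documented output shapes.

    Returns (scales_shape, data_shape), using the shape formulas from the
    Returns section. interleave_factor stands in for INTERLEAVE_FACTOR,
    whose value is not given in this reference.
    """
    if k % 512 != 0:
        raise ValueError(f"K must be divisible by 512, got K={k}")
    if f % 8 != 0:
        raise ValueError(f"F must be divisible by 8, got F={f}")
    scales_shape = (k // 4, f)                     # uint8 scales
    data_shape = (k // 4, f * interleave_factor)   # FP8 quantized data
    return scales_shape, data_shape


# interleave_factor=4 is a placeholder chosen for illustration only.
scales_shape, data_shape = check_quantize_mxfp8_shapes(128, 1024, 4)
print(scales_shape, data_shape)
```

Raising `ValueError` early gives a clearer message than a shape failure inside the kernel; the check itself only restates the divisibility rules from the Notes.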
