This document is relevant for: Trn2, Trn3
Quantize Mxfp8 Kernel API Reference#
Determine if packed scales should be stored for the current tile.
Background#
The should_store_packed_scales kernel determines whether packed scales should be stored for the current tile during MXFP8 block-wise quantization.
API Reference#
Source code for this kernel API can be found at: quantize_mxfp8.py
should_store_packed_scales#
- nkilib.experimental.quantize_mxfp8.should_store_packed_scales(tile_k_idx: int, total_num_tiles: int) bool#
Determine if packed scales should be stored for the current tile.
- Parameters:
tile_k_idx (int) – Current tile index in K dimension (across all tiles including remainder)
total_num_tiles (int) – Total number of tiles including remainder tiles
- Returns:
True if scales should be stored, False otherwise
- Return type:
bool
quantize_block_mxfp8_kernel#
- nkilib.experimental.quantize_mxfp8.quantize_block_mxfp8_kernel(src_tensor: nl.ndarray, return_fp8_dtype: str, run_with_lnc2: bool = False, enable_scale_packing: bool = True) tuple[nl.ndarray, nl.ndarray]#
Kernel for quantizing BF16 tensor to MXFP8 format with block-wise quantization.
- Parameters:
src_tensor (nl.ndarray) – [F, K], Input tensor in BF16 format on HBM
return_fp8_dtype (str) – FP8 dtype string like “float8_e4m3fn” or “float8_e5m2”
run_with_lnc2 (bool) – Enable LNC2 parallelization along F dimension (default: False)
enable_scale_packing (bool) – Enable scale packing optimization (default: True)
- Returns:
A tuple of two tensors:
[K // 4, F], Scales in uint8 format on HBM
[K // 4, F * INTERLEAVE_FACTOR], Quantized data in FP8 format on HBM
- Return type:
tuple[nl.ndarray, nl.ndarray]
Notes:
K dimension must be divisible by 512 for MXFP8 quantization
F dimension must be divisible by 8 for quantization
LNC2 splits work along F dimension; supports uneven splits
Dimensions:
F: Feature dimension (rows in input tensor)
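The shape constraints and output shapes listed above can be checked on the host before launching the kernel. The following is a minimal sketch in plain Python, not part of nkilib: the helper name `mxfp8_output_shapes` is hypothetical, and because the value of `INTERLEAVE_FACTOR` is not stated in this reference, it is passed in as an argument.

```python
def mxfp8_output_shapes(f: int, k: int, interleave_factor: int):
    """Hypothetical host-side helper (not part of nkilib).

    Validates the documented divisibility constraints for an [F, K] input
    and returns the documented output shapes as (scales_shape, data_shape).
    """
    # Documented constraint: K must be divisible by 512.
    if k % 512 != 0:
        raise ValueError(f"K={k} must be divisible by 512")
    # Documented constraint: F must be divisible by 8.
    if f % 8 != 0:
        raise ValueError(f"F={f} must be divisible by 8")

    # Documented output shapes: scales are [K // 4, F],
    # quantized data is [K // 4, F * INTERLEAVE_FACTOR].
    scales_shape = (k // 4, f)
    data_shape = (k // 4, f * interleave_factor)
    return scales_shape, data_shape


# Example with F=8, K=512. The interleave factor of 4 is a placeholder
# chosen for illustration, not a value confirmed by this reference.
scales_shape, data_shape = mxfp8_output_shapes(8, 512, 4)
print(scales_shape, data_shape)  # (128, 8) (128, 32)
```

Running a check like this before the kernel call turns an unsupported shape into an early `ValueError` on the host rather than a failure inside the kernel.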