This document is relevant for: Trn2, Trn3
Matmul MXFP8 Kernel API Reference#
Performs matrix multiplication with MXFP8 quantization.
Background#
The matmul_mxfp8 kernel implements efficient matrix multiplication using MXFP8 quantization format, supporting both pre-quantized inputs and automatic quantization from BF16. It uses hardware-optimized tiling and supports LNC2 parallelization for improved throughput.
API Reference#
Source code for this kernel API can be found at: matmul_mxfp8_generic_kernel.py
matmul_mxfp8#
- nkilib.experimental.matmul_mxfp8.matmul_mxfp8(lhs, rhs, TILES_IN_BLOCK_M: int = None, TILES_IN_BLOCK_N: int = None, TILES_IN_BLOCK_K: int = None, TILES_IN_LOAD_M: int = None, TILES_IN_LOAD_N: int = None, lhs_matmul_tile_shape_logical: tuple = None, rhs_matmul_tile_shape_logical: tuple = None, block_loop_order: str = 'mnk', tile_loop_order: str = 'mnk', float8_dtype: str = 'float8_e5m2', output_dtype=nl.float32, run_with_lnc2: bool = True, lnc_2_shard_rhs: bool = True, lhs_scales=None, rhs_scales=None, use_scale_packing: bool = False, spill_reload: bool = False, lhs_is_swizzled: bool = True, rhs_is_swizzled: bool = True) → nl.ndarray#
Performs matrix multiplication with MXFP8 quantization.
- Parameters:
lhs – Left-hand side matrix, either BF16 tensor or tuple (data, scales) for pre-quantized MXFP8
rhs – Right-hand side matrix, either BF16 tensor or tuple (data, scales) for pre-quantized MXFP8
TILES_IN_BLOCK_M (int) – Number of matmul tiles per block in M dimension (auto-generated if None)
TILES_IN_BLOCK_N (int) – Number of matmul tiles per block in N dimension (auto-generated if None)
TILES_IN_BLOCK_K (int) – Number of matmul tiles per block in K dimension (auto-generated if None)
TILES_IN_LOAD_M (int) – Number of tiles to load at once in M dimension (auto-generated if None)
TILES_IN_LOAD_N (int) – Number of tiles to load at once in N dimension (auto-generated if None)
lhs_matmul_tile_shape_logical (tuple) – LHS tile shape (TILE_K, TILE_M) in logical space (auto-generated if None)
rhs_matmul_tile_shape_logical (tuple) – RHS tile shape (TILE_K, TILE_N) in logical space (auto-generated if None)
block_loop_order (str) – Block processing order, default ‘mnk’
tile_loop_order (str) – Tile processing order within blocks, default ‘mnk’
float8_dtype (str) – FP8 dtype for quantization, default “float8_e5m2”
output_dtype – Output data type, default nl.float32
run_with_lnc2 (bool) – Enable LNC2 parallelization across 2 cores, default True
lnc_2_shard_rhs (bool) – When run_with_lnc2=True, shard on N dimension (RHS) if True, or shard on M dimension (LHS) if False. Default True.
lhs_scales – Optional pre-computed scales for LHS
rhs_scales – Optional pre-computed scales for RHS
use_scale_packing (bool) – If True and inputs are pre-quantized, assert that scales are packed, default False
spill_reload (bool) – If True, each quantized block will be written to HBM and on every subsequent load, this spilled block will be reloaded.
lhs_is_swizzled (bool) – Whether LHS BF16 tensor is pre-swizzled [K/4, M*4], default True. If False, expects [M, K] layout.
rhs_is_swizzled (bool) – Whether RHS BF16 tensor is pre-swizzled [K/4, N*4], default True. If False, expects [N, K] layout.
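The pre-swizzled layouts named for lhs_is_swizzled and rhs_is_swizzled can be related to the logical matmul shapes with a little shape arithmetic. The sketch below is plain Python and not part of nkilib: the helper name swizzled_shape is hypothetical, and it assumes K is a multiple of 4. It computes shapes only and does not reproduce the element ordering of the swizzle.

```python
def swizzled_shape(outer: int, k: int) -> tuple:
    """Shape of a pre-swizzled BF16 operand, [K/4, outer*4].

    `outer` is M for the LHS and N for the RHS. Hypothetical helper:
    it assumes K is divisible by 4 and computes the shape only.
    """
    if k % 4 != 0:
        raise ValueError("K must be a multiple of 4 for this sketch")
    return (k // 4, outer * 4)


# Logical problem [M, K] @ [K, N] with M=256, K=2048, N=512
lhs_shape = swizzled_shape(outer=256, k=2048)  # [K/4, M*4]
rhs_shape = swizzled_shape(outer=512, k=2048)  # [K/4, N*4]
print(lhs_shape, rhs_shape)
```

With lhs_is_swizzled=False or rhs_is_swizzled=False, the parameter descriptions above state that the kernel expects the plain [M, K] or [N, K] layout instead, so no such conversion applies.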
Notes:
Supports non-divisible tensor shapes using dynamic slicing (nl.ds)
Auto-generates optimal tiling parameters when not specified
LNC2 mode requires at least 2 blocks in N dimension
Pre-quantized inputs must be in MXFP8 format (data, scales) tuple
When use_scale_packing=True, pre-quantized inputs must have packed scales
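The LNC2 note above can be checked before calling the kernel. The following is a rough sketch in plain Python, not the kernel's own validation: the helper names are hypothetical, and it assumes that one block spans TILES_IN_BLOCK_N * TILE_N logical columns and that block counts use ceiling division.

```python
def num_blocks_n(n: int, tiles_in_block_n: int, tile_n: int) -> int:
    """Number of blocks along N, assuming block width = tiles * tile width."""
    block_n = tiles_in_block_n * tile_n
    return -(-n // block_n)  # ceiling division


def lnc2_has_enough_n_blocks(n: int, tiles_in_block_n: int, tile_n: int) -> bool:
    """True if the N dimension yields at least 2 blocks (assumed LNC2 condition)."""
    return num_blocks_n(n, tiles_in_block_n, tile_n) >= 2


# N=4096 with 2 tiles of 512 columns per block gives 4 blocks
print(lnc2_has_enough_n_blocks(4096, tiles_in_block_n=2, tile_n=512))
# N=1024 with the same tiling gives a single block
print(lnc2_has_enough_n_blocks(1024, tiles_in_block_n=2, tile_n=512))
```

If the check fails, one option under these assumptions is to lower TILES_IN_BLOCK_N or to pass run_with_lnc2=False.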
TODO: Specify intended usage range for optimal performance
Physical vs Logical Dimensions:
Logical: Theoretical tensor dimensions [M, K] @ [K, N] for the matmul operation
Physical: Hardware storage format (depends on quantization and swizzling)
  * Pre-swizzled: [K//4, M*4] or [K//4, N*4]
  * Quantized: [K//4, M] or [K//4, N]
Tiles (smallest processing unit):
Matmul Tile: Hardware matmul operation shape
  * LHS: [128, 128] physical, [512, 128] logical
  * RHS: [128, 512] physical, [512, 512] logical
Load Tile: Data loaded per matmul tile (varies by quantization state)
Quantize Tile: Input shape for quantization to produce one matmul tile
Blocks (collection of tiles):
Group of tiles processed together
Must fit in SBUF (including load, quantize, and output buffers)
Accumulates results across K dimension before storing to HBM
Non-Divisible Shape Handling:
Uses ceiling division for block counts
Applies nl.ds (dynamic slice) for boundary handling at load and store operations
Example::

    import nki.language as nl
    # Import path taken from the documented signature above
    from nkilib.experimental.matmul_mxfp8 import matmul_mxfp8

    # Basic usage with BF16 inputs
    lhs = nl.ndarray((512, 1024), dtype=nl.bfloat16, buffer=nl.hbm)
    rhs = nl.ndarray((512, 2048), dtype=nl.bfloat16, buffer=nl.hbm)
    result = matmul_mxfp8(
        lhs=lhs,
        rhs=rhs,
        TILES_IN_BLOCK_M=2,
        TILES_IN_BLOCK_N=2,
        TILES_IN_BLOCK_K=1,
        TILES_IN_LOAD_M=1,
        TILES_IN_LOAD_N=1,
        lhs_matmul_tile_shape_logical=(512, 128),
        rhs_matmul_tile_shape_logical=(512, 512),
    )

    # Usage with pre-quantized inputs (tuple of data and scales)
    lhs_quantized = (lhs_data, lhs_scales)
    result = matmul_mxfp8(lhs=lhs_quantized, rhs=rhs, …)
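The ceiling division and boundary handling described under Non-Divisible Shape Handling can be illustrated without any hardware. This is a minimal sketch in plain Python, not the kernel's code; block_extents is a hypothetical helper that lists the (start, length) pair each block would cover along one dimension, which is the kind of pair a dynamic slice such as nl.ds(start, size) would take.

```python
def ceil_div(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return -(-a // b)


def block_extents(dim_size: int, block_size: int) -> list:
    """Return (start, length) for each block along one dimension.

    The last block is shortened when dim_size is not a multiple of
    block_size, so no block reaches past the end of the tensor.
    """
    num_blocks = ceil_div(dim_size, block_size)
    extents = []
    for i in range(num_blocks):
        start = i * block_size
        length = min(block_size, dim_size - start)
        extents.append((start, length))
    return extents


# A dimension of 1300 elements split into blocks of 512
extents = block_extents(dim_size=1300, block_size=512)
print(extents)  # [(0, 512), (512, 512), (1024, 276)]
```

Only the final block has a reduced length, so the shortened slice is needed only at the boundary loads and stores.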
Dimensions:
M: Number of rows in left-hand side matrix (output rows)
K: Contraction dimension (columns in LHS, rows in RHS)