This document is relevant for: Trn2, Trn3
MLP Forward MXFP8 Kernel API Reference
Fused gate/up + SiLU + multiply + down projection using TensorDescriptors.
The intermediate tensor is stored in HBM. All loads use load_and_quantize_tile with a TileLocation/TensorDescriptor. For each M-block of TILES_IN_BLOCK_M tiles:
- Phase 1: gate/up matmul, then SiLU(gate) * up, then write the intermediate to HBM.
- Phase 2: read the intermediate back from HBM via DGT and multiply it with the down weights.
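The two-phase loop above can be sketched as an unquantized plain-Python reference. This is not the kernel: it assumes `rows_per_block` stands in for TILES_IN_BLOCK_M tiles of rows, skips MXFP8 quantization and the N/K tiling, and uses a Python list in place of the HBM scratch buffer. The names `mlp_forward_reference`, `matmul_bt`, and `silu` are hypothetical helpers for illustration.

```python
import math


def silu(x):
    # SiLU(x) = x * sigmoid(x), written in a form that avoids overflow for large |x|.
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def matmul_bt(a, b):
    # Computes a @ b^T for row-major lists: a is [M, K], b is [N, K] -> [M, N].
    return [[sum(x * y for x, y in zip(row, col)) for col in b] for row in a]


def mlp_forward_reference(hidden, gate_up, down_w, rows_per_block=128):
    """Hypothetical unquantized reference of the blocked two-phase SwiGLU MLP.

    hidden: [S, H]; gate_up: [2I, H] with rows [0:I] = W_gate and rows [I:2I] = W_up;
    down_w: [H, I]. Returns the [S, H] output.
    """
    inter_dim = len(gate_up) // 2
    w_gate, w_up = gate_up[:inter_dim], gate_up[inter_dim:]
    seq_len = len(hidden)
    intermediate = [None] * seq_len  # stands in for the [S, I] HBM scratch buffer
    output = [None] * seq_len
    for start in range(0, seq_len, rows_per_block):
        block = hidden[start:start + rows_per_block]
        # Phase 1: gate/up projection, SiLU(gate) * up, write the block to scratch.
        gate = matmul_bt(block, w_gate)
        up = matmul_bt(block, w_up)
        for r, (g_row, u_row) in enumerate(zip(gate, up)):
            intermediate[start + r] = [silu(g) * u for g, u in zip(g_row, u_row)]
        # Phase 2: read the block back from scratch and apply the down projection.
        inter_block = intermediate[start:start + rows_per_block]
        for r, out_row in enumerate(matmul_bt(inter_block, down_w)):
            output[start + r] = out_row
    return output
```

A reference like this is mainly useful as ground truth when testing the kernel; the kernel's output will differ from it by MXFP8 quantization error, so comparisons need a tolerance.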
Background
The compute_fused_gate_up_down_mxfp8 kernel implements a fused SwiGLU MLP forward pass (gate/up projection, SiLU activation, element-wise multiply, and down projection) using MXFP8 quantized matmuls with TensorDescriptors.
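Reading the shapes from the parameter lists, the forward pass corresponds to the equations below, with X = hidden ([S, H]), W_gate and W_up the row halves [0:I] and [I:2I] of gate_up_weights ([2I, H]), and W_down = down_weights ([H, I]). A is the [S, I] intermediate written to HBM. The kernel evaluates the matmuls on MXFP8-quantized operands, so its results match these equations only up to quantization error.

```latex
\begin{aligned}
G &= X W_{\mathrm{gate}}^{\top} \in \mathbb{R}^{S \times I}, \qquad
U = X W_{\mathrm{up}}^{\top} \in \mathbb{R}^{S \times I}, \\
A &= \operatorname{SiLU}(G) \odot U, \qquad
\operatorname{SiLU}(z) = z\,\sigma(z) = \frac{z}{1 + e^{-z}}, \\
Y &= A\, W_{\mathrm{down}}^{\top} \in \mathbb{R}^{S \times H}.
\end{aligned}
```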
API Reference
Source code for this kernel API can be found at: mlp_fwd_mxfp8_kernel.py
compute_fused_gate_up_down_mxfp8
- nkilib.experimental.mlp_mxfp8.mlp_fwd_mxfp8.compute_fused_gate_up_down_mxfp8(hidden_td: TensorDescriptor, gate_up_td: TensorDescriptor, down_w_td: TensorDescriptor, int_td: TensorDescriptor, output_td: TensorDescriptor, s_base_offset: int, dtype, TILES_IN_BLOCK_M: int = 8, TILES_IN_BLOCK_N_GU: int = 1, TILES_IN_BLOCK_K_GU: int = 8, TILES_IN_BLOCK_M_DOWN: int = 8, TILES_IN_BLOCK_N_DOWN: int = 1, TILES_IN_BLOCK_K_DOWN: int = 8, save_gate_pre_td: TensorDescriptor = None, save_gate_act_td: TensorDescriptor = None, save_up_td: TensorDescriptor = None, spill_reload: bool = True, use_scale_packing: bool = True, run_with_lnc2: bool = True)
Fused gate/up + SiLU + multiply + down projection using TensorDescriptors.
- Parameters:
hidden_td (TensorDescriptor) – [S, H], input hidden states (is_f_by_k=True).
gate_up_td (TensorDescriptor) – [2I, H], fused gate+up weight matrix (is_f_by_k=True).
down_w_td (TensorDescriptor) – [H, I], down projection weights (is_f_by_k=True).
int_td (TensorDescriptor) – [S, I], scratch buffer for gated intermediate activations (is_f_by_k=True).
output_td (TensorDescriptor) – [S_local, H], output buffer (may be a slice for LNC sharding).
s_base_offset (int) – Row offset into the full [S, …] tensors for this LNC core.
dtype – Output data type (e.g. nl.bfloat16).
TILES_IN_BLOCK_M (int) – Number of M tiles per block for the gate/up phase.
TILES_IN_BLOCK_N_GU (int) – Number of N tiles per block for the gate/up phase.
TILES_IN_BLOCK_K_GU (int) – Number of K tiles per block for the gate/up phase.
TILES_IN_BLOCK_M_DOWN (int) – Number of M tiles per block for the down phase.
TILES_IN_BLOCK_N_DOWN (int) – Number of N tiles per block for the down phase.
TILES_IN_BLOCK_K_DOWN (int) – Number of K tiles per block for the down phase.
save_gate_pre_td (TensorDescriptor) – [S, I], optional TensorDescriptor to checkpoint the gate pre-activation, or None.
save_gate_act_td (TensorDescriptor) – [S, I], optional TensorDescriptor to checkpoint SiLU(gate_pre), or None.
save_up_td (TensorDescriptor) – [S, I], optional TensorDescriptor to checkpoint the up projection, or None.
spill_reload (bool) – Whether to spill quantized operands to HBM for reload across K-blocks.
use_scale_packing (bool) – Whether to pack MXFP8 scales into a compact format.
run_with_lnc2 (bool) – Whether to shard across 2 LNC cores.
Dimensions:
S: Sequence length (number of tokens).
H: Hidden dimension size.
I: Intermediate dimension size.
S_local: Number of rows of S handled by this LNC core.
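The relationship between s_base_offset, S_local, and the output slice can be illustrated with a small emulation. Because each output row of the MLP depends only on the matching input row, splitting S across cores needs no cross-core reduction. This sketch assumes an even, contiguous split of S across cores; the kernel's actual partitioning policy is not documented here, and `lnc_row_shards` and `run_sharded` are hypothetical names.

```python
def lnc_row_shards(seq_len, num_cores=2):
    """Hypothetical (s_base_offset, s_local) per core under an even contiguous split."""
    if seq_len % num_cores != 0:
        raise ValueError("this sketch assumes S divides evenly across cores")
    s_local = seq_len // num_cores
    return [(core * s_local, s_local) for core in range(num_cores)]


def run_sharded(rows, per_row_fn, num_cores=2):
    """Emulate per-core execution and stitch the [S_local, ...] slices back together."""
    output = []
    for s_base_offset, s_local in lnc_row_shards(len(rows), num_cores):
        # Each core indexes the full [S, ...] input starting at s_base_offset ...
        local_rows = rows[s_base_offset:s_base_offset + s_local]
        # ... and produces only its own [S_local, ...] output slice.
        output.extend(per_row_fn(r) for r in local_rows)
    return output
```

For example, with S = 8 and two cores, the sketch gives core 0 rows [0, 4) and core 1 rows [4, 8).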
mlp_forward_mxfp8_nki
- nkilib.experimental.mlp_mxfp8.mlp_fwd_mxfp8.mlp_forward_mxfp8_nki(hidden: nl.ndarray, gate_up_weights: nl.ndarray, down_weights: nl.ndarray, intermediate_hbm: nl.ndarray, run_with_lnc2: bool = True, gate_up_tiles_m: int = 8, gate_up_tiles_n: int = 1, gate_up_tiles_k: int = 8, down_tiles_m: int = 8, down_tiles_n: int = 1, down_tiles_k: int = 8, fp8_x4_dtype=float8_e4m3fn_x4, save_gate_pre: nl.ndarray = None, save_gate_act: nl.ndarray = None, save_up: nl.ndarray = None, save_hidden: nl.ndarray = None, dtype=nl.bfloat16, spill_reload: bool = True, use_scale_packing: bool = True, hidden_scales: nl.ndarray = None, gate_up_scales: nl.ndarray = None, down_scales: nl.ndarray = None, hidden_is_swizzled: bool = False, gate_up_is_swizzled: bool = False, down_is_swizzled: bool = False) → nl.ndarray
MXFP8 SwiGLU MLP forward pass with optional activation checkpointing.
- Parameters:
hidden (nl.ndarray) – [S, H], input hidden states.
gate_up_weights (nl.ndarray) – [2I, H], fused weight matrix: rows [0:I] = W_gate, rows [I:2I] = W_up.
down_weights (nl.ndarray) – [H, I], down projection weights (W_down).
intermediate_hbm (nl.ndarray) – [S, I], scratch buffer for gated intermediate activations.
run_with_lnc2 (bool) – Whether to shard across 2 LNC cores.
gate_up_tiles_m (int) – Number of M tiles per block for the gate/up phase.
gate_up_tiles_n (int) – Number of N tiles per block for the gate/up phase.
gate_up_tiles_k (int) – Number of K tiles per block for the gate/up phase.
down_tiles_m (int) – Number of M tiles per block for the down phase.
down_tiles_n (int) – Number of N tiles per block for the down phase.
down_tiles_k (int) – Number of K tiles per block for the down phase.
fp8_x4_dtype – MXFP8 quantized data type for nc_matmul_mx.
save_gate_pre (nl.ndarray) – [S, I], HBM buffer to checkpoint the gate pre-activation, or None.
save_gate_act (nl.ndarray) – [S, I], HBM buffer to checkpoint SiLU(gate_pre), or None.
save_up (nl.ndarray) – [S, I], HBM buffer to checkpoint the up projection, or None.
save_hidden (nl.ndarray) – [S, I], HBM buffer to checkpoint gate_act * up, or None. This holds the same data as intermediate_hbm but is kept as a separately named output for clarity in the fwd/bwd contract.
dtype – Output data type (e.g. nl.bfloat16).
spill_reload (bool) – Whether to spill quantized operands to HBM for reload across K-blocks.
use_scale_packing (bool) – Whether to pack MXFP8 scales into a compact format.
hidden_scales (nl.ndarray) – MXFP8 scales for a pre-quantized hidden, or None if hidden is raw BF16.
gate_up_scales (nl.ndarray) – MXFP8 scales for a pre-quantized gate_up_weights, or None.
down_scales (nl.ndarray) – MXFP8 scales for a pre-quantized down_weights, or None.
hidden_is_swizzled (bool) – True if hidden is pre-swizzled into a [K/4, F*4] BF16 layout.
gate_up_is_swizzled (bool) – True if gate_up_weights is pre-swizzled.
down_is_swizzled (bool) – True if down_weights is pre-swizzled.
- Returns:
[S, H], MLP output hidden states.
- Return type:
nl.ndarray
Dimensions:
S: Sequence length (number of tokens).
H: Hidden dimension size.
I: Intermediate dimension size.
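The hidden_scales, gate_up_scales, and down_scales parameters hold MXFP8 block scales. The plain-Python sketch below illustrates the OCP Microscaling (MX) idea these scales come from: each block of 32 consecutive values shares one power-of-two (E8M0) scale, and the scaled elements are rounded to FP8 E4M3 (the element type suggested by the float8_e4m3fn_x4 default). It follows the OCP MX v1.0 shared-exponent rule as a sketch only; the block axis, rounding mode, saturation behavior, and the compact layout used when use_scale_packing=True are assumptions not documented here, so load_and_quantize_tile may differ. All function names are hypothetical.

```python
import math

BLOCK_SIZE = 32     # elements sharing one scale in the OCP MX format
E4M3_MAX = 448.0    # largest finite FP8 E4M3 value
E4M3_EMAX = 8       # exponent of the largest E4M3 normal (448 = 1.75 * 2**8)


def round_to_e4m3(v):
    """Round a float to the nearest FP8 E4M3 value (ties to even, saturating)."""
    if v == 0.0:
        return 0.0
    sign = -1.0 if v < 0 else 1.0
    a = abs(v)
    _, e = math.frexp(a)        # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)        # floor(log2(a)), clamped to the subnormal range
    step = 2.0 ** (exp - 3)     # spacing of representable values (3 mantissa bits)
    q = round(a / step) * step  # Python round() breaks ties to even
    return sign * min(q, E4M3_MAX)


def quantize_mxfp8_block(values):
    """Return (shared scale exponent, E4M3 values) for one block of <= 32 floats."""
    amax = max((abs(v) for v in values), default=0.0)
    if amax == 0.0:
        return 0, [0.0] * len(values)
    _, e = math.frexp(amax)
    # Shared exponent = floor(log2(amax)) - emax_elem, clamped to the E8M0 range.
    scale_exp = max(-127, min(127, (e - 1) - E4M3_EMAX))
    scale = 2.0 ** scale_exp
    return scale_exp, [round_to_e4m3(v / scale) for v in values]


def quantize_mxfp8(row):
    """Quantize a 1-D row in consecutive 32-element blocks."""
    return [quantize_mxfp8_block(row[i:i + BLOCK_SIZE])
            for i in range(0, len(row), BLOCK_SIZE)]


def dequantize_mxfp8(blocks):
    """Expand (scale_exp, values) blocks back to a flat list of floats."""
    return [v * 2.0 ** s for s, vals in blocks for v in vals]
```

Ties-to-even rounding and saturation to ±448 are choices made for this sketch; the OCP specification leaves some of this behavior to the implementation.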