This document is relevant for: Trn2, Trn3

MLP Forward MXFP8 Kernel API Reference

Fused gate/up + SiLU + multiply + down projection using TensorDescriptors.

The intermediate tensor is stored in HBM. All loads use load_and_quantize_tile with TileLocation/TensorDescriptor. For each M-block of TILES_IN_BLOCK_M tiles:

  • Phase 1: gate/up matmul -> SiLU(gate) * up -> write the intermediate to HBM.

  • Phase 2: read the intermediate back from HBM via DGT, then matmul with the down weights.
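The two-phase loop can be modeled in plain Python to show why an HBM scratch buffer is enough. Each row of the intermediate depends only on the same row of the hidden states, so Phase 2 for an M-block can run as soon as Phase 1 has written that block. This is an unquantized sketch under stated assumptions. Python lists stand in for HBM and on-chip tiles, and rows_per_block stands in for TILES_IN_BLOCK_M times the tile height. The names silu, matmul_nt, and blocked_mlp are hypothetical helpers, not nkilib APIs. The sketch does not model load_and_quantize_tile, MXFP8 rounding, or the DGT read.

```python
import math

def silu(v):
    # Numerically stable SiLU(v) = v * sigmoid(v)
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)

def matmul_nt(a, b):
    # a: [R, K], b: [C, K]; returns a @ b^T with shape [R, C]
    return [[sum(x * y for x, y in zip(row, col)) for col in b] for row in a]

def blocked_mlp(hidden, gate_up, down_w, rows_per_block):
    """Hypothetical two-phase blocked SwiGLU MLP (illustration only).

    hidden:  [S, H]
    gate_up: [2I, H], rows [0:I] = W_gate, rows [I:2I] = W_up
    down_w:  [H, I]
    rows_per_block models TILES_IN_BLOCK_M * (rows per tile), an assumption.
    """
    S = len(hidden)
    I = len(gate_up) // 2
    intermediate = [None] * S  # models the [S, I] HBM scratch buffer
    out = [None] * S
    for start in range(0, S, rows_per_block):
        rows = range(start, min(start + rows_per_block, S))
        # Phase 1: gate/up matmul -> SiLU(gate) * up -> write the block to scratch
        gu = matmul_nt([hidden[r] for r in rows], gate_up)  # [block, 2I]
        for r, v in zip(rows, gu):
            intermediate[r] = [silu(g) * u for g, u in zip(v[:I], v[I:])]
        # Phase 2: read the same block back and multiply by the down weights
        down = matmul_nt([intermediate[r] for r in rows], down_w)  # [block, H]
        for r, v in zip(rows, down):
            out[r] = v
    return out
```

Because rows are independent, changing rows_per_block leaves the values unchanged in this model. Only the loop schedule differs.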

Background

The compute_fused_gate_up_down_mxfp8 kernel implements a fused SwiGLU MLP forward pass (gate/up projection, SiLU activation, element-wise multiply, and down projection) using MXFP8 quantized matmuls with TensorDescriptors.
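In equation form, let X be the [S, H] hidden states. Let W_gate and W_up be the two [I, H] halves of the fused gate/up matrix, and let W_down be the [H, I] down weights. The unquantized computation that the kernel approximates with MXFP8 matmuls is then as follows. This restates the shapes documented in the parameter lists and ignores quantization error. Z is what the [S, I] intermediate buffer holds.

```latex
\begin{aligned}
G &= X W_{\text{gate}}^{\top} \in \mathbb{R}^{S \times I}, \qquad
U = X W_{\text{up}}^{\top} \in \mathbb{R}^{S \times I}, \\
Z &= \operatorname{SiLU}(G) \odot U, \qquad
\operatorname{SiLU}(g) = \frac{g}{1 + e^{-g}}, \\
Y &= Z W_{\text{down}}^{\top} \in \mathbb{R}^{S \times H}.
\end{aligned}
```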

API Reference

Source code for this kernel API can be found at: mlp_fwd_mxfp8_kernel.py

compute_fused_gate_up_down_mxfp8

nkilib.experimental.mlp_mxfp8.mlp_fwd_mxfp8.compute_fused_gate_up_down_mxfp8(hidden_td: TensorDescriptor, gate_up_td: TensorDescriptor, down_w_td: TensorDescriptor, int_td: TensorDescriptor, output_td: TensorDescriptor, s_base_offset: int, dtype, TILES_IN_BLOCK_M: int = 8, TILES_IN_BLOCK_N_GU: int = 1, TILES_IN_BLOCK_K_GU: int = 8, TILES_IN_BLOCK_M_DOWN: int = 8, TILES_IN_BLOCK_N_DOWN: int = 1, TILES_IN_BLOCK_K_DOWN: int = 8, save_gate_pre_td: TensorDescriptor = None, save_gate_act_td: TensorDescriptor = None, save_up_td: TensorDescriptor = None, spill_reload: bool = True, use_scale_packing: bool = True, run_with_lnc2: bool = True)

Fused gate/up + SiLU + multiply + down projection using TensorDescriptors.

Parameters:
  • hidden_td (TensorDescriptor) – [S, H], input hidden states (is_f_by_k=True).

  • gate_up_td (TensorDescriptor) – [2I, H], fused gate+up weight matrix (is_f_by_k=True).

  • down_w_td (TensorDescriptor) – [H, I], down projection weights (is_f_by_k=True).

  • int_td (TensorDescriptor) – [S, I], scratch buffer for gated intermediate activations (is_f_by_k=True).

  • output_td (TensorDescriptor) – [S_local, H], output buffer (may be a slice for LNC sharding).

  • s_base_offset (int) – Row offset into the full [S, …] tensors for this LNC core.

  • dtype – Output data type (e.g. nl.bfloat16).

  • TILES_IN_BLOCK_M (int) – Number of M tiles per block for gate/up phase.

  • TILES_IN_BLOCK_N_GU (int) – Number of N tiles per block for gate/up phase.

  • TILES_IN_BLOCK_K_GU (int) – Number of K tiles per block for gate/up phase.

  • TILES_IN_BLOCK_M_DOWN (int) – Number of M tiles per block for down phase.

  • TILES_IN_BLOCK_N_DOWN (int) – Number of N tiles per block for down phase.

  • TILES_IN_BLOCK_K_DOWN (int) – Number of K tiles per block for down phase.

  • save_gate_pre_td (TensorDescriptor) – [S, I], optional TD to checkpoint gate pre-activation, or None.

  • save_gate_act_td (TensorDescriptor) – [S, I], optional TD to checkpoint SiLU(gate_pre), or None.

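  • save_up_td (TensorDescriptor) – [S, I], optional TD to checkpoint the up projection, or None.

  • spill_reload (bool) – Whether to spill quantized operands to HBM for reload across K-blocks.

  • use_scale_packing (bool) – Whether to pack MXFP8 scales into a compact format.

  • run_with_lnc2 (bool) – Whether to shard across 2 LNC cores.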
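output_td may be an [S_local, H] slice, and s_base_offset gives this LNC core's row offset into the full [S, …] tensors. The following is a minimal sketch of how S rows could be divided between two cores to produce those two values. It assumes contiguous shards, with any remainder going to the last core. shard_rows is a hypothetical helper, and the kernel's actual split may differ.

```python
def shard_rows(S, num_cores=2):
    """Hypothetical contiguous row split.

    Returns a list of (s_base_offset, S_local) pairs, one per core.
    """
    base = S // num_cores
    shards = []
    for core in range(num_cores):
        offset = core * base
        # Assumption of this sketch: the last core absorbs any remainder rows.
        local = S - offset if core == num_cores - 1 else base
        shards.append((offset, local))
    return shards

# Each core reads rows [offset, offset + local) of the [S, ...] inputs
# and writes its own [local, H] output slice.
```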

Dimensions:

  • S: Sequence length (number of tokens).

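  • H: Hidden dimension size.

  • I: Intermediate dimension size (gate_up_td stacks the gate and up weights as [2I, H]).

  • S_local: Number of rows of S processed by this LNC core, starting at s_base_offset.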
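The outer loop visits one M-block per TILES_IN_BLOCK_M tiles of rows. The following is a worked count of how many M-blocks a core processes for a given S_local. The tile height is not stated on this page. The tile_rows=128 default is an assumption for illustration, and m_block_count is a hypothetical helper.

```python
def m_block_count(s_local, tiles_in_block_m=8, tile_rows=128):
    """Number of M-blocks the outer loop would visit (illustrative arithmetic).

    tile_rows=128 is an assumed tile height, not a value taken from the kernel.
    """
    rows_per_block = tiles_in_block_m * tile_rows
    return -(-s_local // rows_per_block)  # ceiling division

# With the defaults, each block covers 8 * 128 = 1024 rows.
```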

mlp_forward_mxfp8_nki

nkilib.experimental.mlp_mxfp8.mlp_fwd_mxfp8.mlp_forward_mxfp8_nki(hidden: nl.ndarray, gate_up_weights: nl.ndarray, down_weights: nl.ndarray, intermediate_hbm: nl.ndarray, run_with_lnc2: bool = True, gate_up_tiles_m: int = 8, gate_up_tiles_n: int = 1, gate_up_tiles_k: int = 8, down_tiles_m: int = 8, down_tiles_n: int = 1, down_tiles_k: int = 8, fp8_x4_dtype=float8_e4m3fn_x4, save_gate_pre: nl.ndarray = None, save_gate_act: nl.ndarray = None, save_up: nl.ndarray = None, save_hidden: nl.ndarray = None, dtype=nl.bfloat16, spill_reload: bool = True, use_scale_packing: bool = True, hidden_scales: nl.ndarray = None, gate_up_scales: nl.ndarray = None, down_scales: nl.ndarray = None, hidden_is_swizzled: bool = False, gate_up_is_swizzled: bool = False, down_is_swizzled: bool = False) → nl.ndarray

MXFP8 SwiGLU MLP forward pass with optional activation checkpointing.

Parameters:
  • hidden (nl.ndarray) – [S, H], input hidden states.

  • gate_up_weights (nl.ndarray) – [2I, H], fused weight matrix — rows [0:I] = W_gate, rows [I:2I] = W_up.

  • down_weights (nl.ndarray) – [H, I], down projection weights (W_down).

  • intermediate_hbm (nl.ndarray) – [S, I], scratch buffer for gated intermediate activations.

  • run_with_lnc2 (bool) – Whether to shard across 2 LNC cores.

  • gate_up_tiles_m (int) – Number of M tiles per block for gate/up phase.

  • gate_up_tiles_n (int) – Number of N tiles per block for gate/up phase.

  • gate_up_tiles_k (int) – Number of K tiles per block for gate/up phase.

  • down_tiles_m (int) – Number of M tiles per block for down phase.

  • down_tiles_n (int) – Number of N tiles per block for down phase.

  • down_tiles_k (int) – Number of K tiles per block for down phase.

  • fp8_x4_dtype – MXFP8 quantized data type for nc_matmul_mx.

  • save_gate_pre (nl.ndarray) – [S, I], HBM buffer to checkpoint gate pre-activation, or None.

  • save_gate_act (nl.ndarray) – [S, I], HBM buffer to checkpoint SiLU(gate_pre), or None.

  • save_up (nl.ndarray) – [S, I], HBM buffer to checkpoint up projection, or None.

  • save_hidden (nl.ndarray) – [S, I], HBM buffer to checkpoint gate_act * up, or None. It holds the same data as intermediate_hbm but is kept as a separate named output for clarity in the forward/backward contract.

  • dtype – Output data type (e.g. nl.bfloat16).

  • spill_reload (bool) – Whether to spill quantized operands to HBM for reload across K-blocks.

  • use_scale_packing (bool) – Whether to pack MXFP8 scales into compact format.

  • hidden_scales (nl.ndarray) – MXFP8 scales for a pre-quantized hidden, or None if hidden is raw BF16.

  • gate_up_scales (nl.ndarray) – MXFP8 scales for pre-quantized gate_up_weights, or None.

  • down_scales (nl.ndarray) – MXFP8 scales for pre-quantized down_weights, or None.

  • hidden_is_swizzled (bool) – True if hidden is pre-swizzled into a [K/4, F*4] BF16 layout.

  • gate_up_is_swizzled (bool) – True if gate_up_weights is pre-swizzled.

  • down_is_swizzled (bool) – True if down_weights is pre-swizzled.
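hidden_scales, gate_up_scales, and down_scales carry MXFP8 scales for pre-quantized operands. The following sketch shows the generic scaling recipe from the OCP Microscaling (MX) v1.0 specification: blocks of 32 values share one power-of-two scale, and elements are stored as FP8 E4M3 (largest finite magnitude 448, largest power-of-two exponent 8). This page does not document the kernel's exact scale selection, packing, or swizzled layout, so treat this as background rather than the kernel's method. The helper names are hypothetical. Scaled values are clamped to the E4M3 range but not rounded to the FP8 grid. An all-zero block gets scale 1.0 by this sketch's own convention.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 magnitude
E4M3_EMAX = 8     # exponent of the largest E4M3 power of two (2**8 = 256)
BLOCK = 32        # OCP MX block size: 32 elements share one scale

def mx_block_scale(block):
    """Shared power-of-two scale for one block (OCP MX v1.0 recipe)."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0  # convention of this sketch for all-zero blocks
    exp = math.floor(math.log2(amax)) - E4M3_EMAX
    return 2.0 ** exp

def mx_quantize(values):
    """Split values into 32-element blocks and return (scales, scaled_values).

    Scaled values are clamped to [-448, 448] but NOT rounded to FP8.
    """
    scales, scaled = [], []
    for i in range(0, len(values), BLOCK):
        blk = values[i:i + BLOCK]
        s = mx_block_scale(blk)
        scales.append(s)
        scaled.extend(max(-E4M3_MAX, min(E4M3_MAX, v / s)) for v in blk)
    return scales, scaled
```

Because the shared exponent comes from floor(log2(amax)), the block maximum lands in [256, 512) after scaling. Values above 448 saturate, which is why the clamp is needed.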

Returns:

[S, H], MLP output hidden states.

Return type:

nl.ndarray

Dimensions:

  • S: Sequence length (number of tokens).

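  • H: Hidden dimension size.

  • I: Intermediate dimension size (rows [0:I] of gate_up_weights hold W_gate, rows [I:2I] hold W_up).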
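The optional save_* buffers checkpoint intermediate tensors for a backward pass. The following is a small unquantized reference that shows which tensor each buffer corresponds to and how the fused gate_up_weights rows are split. This is a plain-Python sketch for checking shapes and semantics. The kernel's MXFP8 results will differ by quantization error. mlp_forward_reference and its helpers are hypothetical and not part of nkilib.

```python
import math

def silu(v):
    # Numerically stable SiLU(v) = v * sigmoid(v)
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)

def matmul_nt(a, b):
    # a: [R, K], b: [C, K] -> a @ b^T with shape [R, C]
    return [[sum(x * y for x, y in zip(row, col)) for col in b] for row in a]

def mlp_forward_reference(hidden, gate_up_weights, down_weights):
    """Unquantized reference: returns ([S, H] output, dict of checkpoint tensors)."""
    I = len(gate_up_weights) // 2
    w_gate, w_up = gate_up_weights[:I], gate_up_weights[I:]  # rows [0:I] and [I:2I]
    gate_pre = matmul_nt(hidden, w_gate)                     # [S, I]
    up = matmul_nt(hidden, w_up)                             # [S, I]
    gate_act = [[silu(g) for g in row] for row in gate_pre]  # [S, I]
    # gate_act * up: the contents of intermediate_hbm / save_hidden
    inter = [[a * u for a, u in zip(ra, ru)] for ra, ru in zip(gate_act, up)]
    out = matmul_nt(inter, down_weights)                     # [S, H]
    checkpoints = {
        "save_gate_pre": gate_pre,
        "save_gate_act": gate_act,
        "save_up": up,
        "save_hidden": inter,
    }
    return out, checkpoints
```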
