This document is relevant for: Trn2, Trn3
MLP Backward MXFP8 Kernel API Reference
Return (num_cores, shard_id) for LNC2 sharding.
Background
The get_program_sharding_info kernel returns the LNC2 sharding configuration (num_cores, shard_id), used by the MXFP8 MLP backward pass to distribute computation across logical cores.
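As a minimal sketch of how a (num_cores, shard_id) pair could be turned into a per-core row offset such as s_base or h_base, the helper below splits a dimension evenly across cores. The name shard_rows and the even-split rule are assumptions for illustration; they are not part of the kernel API.

```python
def shard_rows(total_rows: int, num_cores: int, shard_id: int) -> tuple[int, int]:
    """Return (row_base, row_count) for one shard of an evenly split dimension.

    Hypothetical helper: assumes the dimension divides evenly across cores.
    """
    if total_rows % num_cores != 0:
        raise ValueError("total_rows must divide evenly across cores")
    rows_per_core = total_rows // num_cores
    return shard_id * rows_per_core, rows_per_core


# Illustrative LNC2 case: two logical cores, sequence length 4096.
S = 4096
offsets = [shard_rows(S, num_cores=2, shard_id=i) for i in range(2)]
```

Under these assumptions, each core would work on a contiguous block of rows starting at its own offset.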
API Reference
Source code for this kernel API can be found at: mlp_bwd_mxfp8_kernel.py
get_program_sharding_info
compute_phase1_down_proj_mm_grad_mxfp8
- nkilib.experimental.mlp_mxfp8.mlp_bwd_mxfp8.compute_phase1_down_proj_mm_grad_mxfp8(output_grad_td: TensorDescriptor, gate_pre_td: TensorDescriptor, gate_act_td: TensorDescriptor, up_td: TensorDescriptor, d_gate_td: TensorDescriptor, d_up_td: TensorDescriptor, scratch_td: TensorDescriptor, down_weight_td: TensorDescriptor, s_base: int, dtype: type, fp8_x4_dtype: type, TILES_IN_BLOCK_M: int = 8, TILES_IN_BLOCK_N: int = 1, TILES_IN_BLOCK_K: int = 8, spill_reload: bool = True, use_scale_packing: bool = True, run_with_lnc2: bool = True) -> None
Phase 1: Compute gradient through the down projection and SwiGLU gate.
- Parameters:
output_grad_td (TensorDescriptor) – [S, H], incoming gradient (is_f_by_k=True).
gate_pre_td (TensorDescriptor) – [S, I], checkpointed gate pre-activation.
gate_act_td (TensorDescriptor) – [S, I], checkpointed gate post-activation.
up_td (TensorDescriptor) – [S, I], checkpointed up projection.
d_gate_td (TensorDescriptor) – [S, I], output: gate gradient.
d_up_td (TensorDescriptor) – [S, I], output: up gradient.
scratch_td (TensorDescriptor) – [2I, S], output: transposed d_gate || d_up.
down_weight_td (TensorDescriptor) – [I, H], transposed down projection weights (is_f_by_k=True).
s_base (int) – Row offset into the full [S, …] tensors for this LNC core.
dtype (type) – Data type for computation (nl.bfloat16).
fp8_x4_dtype (type) – MXFP8 quantized data type (e.g. float8_e4m3fn_x4).
TILES_IN_BLOCK_M (int) – Number of M tiles per block.
TILES_IN_BLOCK_N (int) – Number of N tiles per block.
TILES_IN_BLOCK_K (int) – Number of K tiles to accumulate in PSUM.
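The math that phase 1 covers can be sketched as a full-precision reference in plain Python. This is not the kernel's implementation: it ignores MXFP8 quantization, tiling, and LNC sharding, and it assumes the standard SwiGLU definition hidden = SiLU(gate_pre) * up. The names phase1_reference, matmul, transpose, and silu_grad are illustrative.

```python
import math


def matmul(a, b):
    """Multiply two row-major matrices stored as lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def silu_grad(x):
    """Derivative of SiLU(x) = x * sigmoid(x)."""
    s = 1.0 / (1.0 + math.exp(-x))
    return s * (1.0 + x * (1.0 - s))


def phase1_reference(output_grad, gate_pre, gate_act, up, down_weight_T):
    # d_hidden [S, I] = output_grad [S, H] @ down_weight_T^T [H, I]
    d_hidden = matmul(output_grad, transpose(down_weight_T))
    # hidden = gate_act * up, so the product rule gives one term per branch.
    d_up = [[dh * ga for dh, ga in zip(r_dh, r_ga)]
            for r_dh, r_ga in zip(d_hidden, gate_act)]
    d_gate = [[dh * u * silu_grad(gp) for dh, u, gp in zip(r_dh, r_u, r_gp)]
              for r_dh, r_u, r_gp in zip(d_hidden, up, gate_pre)]
    # scratch [2I, S]: d_gate and d_up concatenated along I, then transposed.
    scratch = transpose([g + u for g, u in zip(d_gate, d_up)])
    return d_gate, d_up, scratch


# Tiny example with S = 2, H = 2, I = 1.
output_grad = [[1.0, 0.0], [0.0, 1.0]]
gate_pre = [[0.0], [0.0]]
gate_act = [[0.0], [0.0]]  # SiLU(0) = 0
up = [[2.0], [4.0]]
down_weight_T = [[3.0, 5.0]]
d_gate, d_up, scratch = phase1_reference(output_grad, gate_pre, gate_act, up, down_weight_T)
```

A reference of this kind is mainly useful for checking kernel outputs within a quantization tolerance.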
compute_phase3_gate_up_weight_grad_mxfp8
- nkilib.experimental.mlp_mxfp8.mlp_bwd_mxfp8.compute_phase3_gate_up_weight_grad_mxfp8(weight_grad_td: TensorDescriptor, hidden_states_T_td: TensorDescriptor, grad_T_td: TensorDescriptor, dtype: type, fp8_x4_dtype: type, TILES_IN_BLOCK_M: int = 4, TILES_IN_BLOCK_N: int = 1, TILES_IN_BLOCK_K: int = 8, spill_reload: bool = True, use_scale_packing: bool = True, run_with_lnc2: bool = True) -> None
Phase 3: Compute gradient w.r.t. gate and up weight matrices as a single matmul.
- Parameters:
weight_grad_td (TensorDescriptor) – [2I, H], output: [dW_gate; dW_up].
hidden_states_T_td (TensorDescriptor) – [H, S], transposed input hidden states (is_f_by_k=True).
grad_T_td (TensorDescriptor) – [2I, S], transposed gate+up gradients (is_f_by_k=True, is_col_parallel_sharded=True for LNC2).
dtype (type) – Data type for computation (nl.bfloat16).
fp8_x4_dtype (type) – MXFP8 quantized data type.
TILES_IN_BLOCK_M (int) – Number of M tiles per block.
TILES_IN_BLOCK_N (int) – Number of N tiles per block.
TILES_IN_BLOCK_K (int) – Number of K tiles to accumulate in PSUM.
Dimensions:
S: Sequence length.
H: Hidden dimension size.
I: Intermediate dimension size.
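The single matmul in phase 3 can be sketched in plain Python as weight_grad [2I, H] = grad_T [2I, S] @ hidden_states_T^T [S, H]. This is a full-precision reference under the shapes listed above, not the MXFP8 kernel; phase3_reference and the helper functions are illustrative names.

```python
def matmul(a, b):
    """Multiply two row-major matrices stored as lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def phase3_reference(grad_T, hidden_states_T):
    # Contract over S: [2I, S] @ [S, H] -> [2I, H].
    return matmul(grad_T, transpose(hidden_states_T))


# Tiny example with I = 1, S = 2, H = 2.
grad_T = [[1.0, 2.0],   # row 0: d_gate^T
          [3.0, 4.0]]   # row 1: d_up^T
hidden_states_T = [[1.0, 1.0],
                   [0.0, 2.0]]
weight_grad = phase3_reference(grad_T, hidden_states_T)
dW_gate, dW_up = weight_grad[:1], weight_grad[1:]  # split the fused result at I
```

Because d_gate and d_up are stacked along the first axis, one matmul yields both weight gradients, which are then split at row I.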
compute_phase4_down_weight_grad_mxfp8
- nkilib.experimental.mlp_mxfp8.mlp_bwd_mxfp8.compute_phase4_down_weight_grad_mxfp8(down_weight_grad_td: TensorDescriptor, output_grad_T_td: TensorDescriptor, hidden_T_td: TensorDescriptor, h_base: int, dtype: type, fp8_x4_dtype: type, TILES_IN_BLOCK_M: int = 4, TILES_IN_BLOCK_N: int = 1, TILES_IN_BLOCK_K: int = 8, spill_reload: bool = True, use_scale_packing: bool = True, run_with_lnc2: bool = True) -> None
Phase 4: Compute gradient w.r.t. down projection weight matrix.
- Parameters:
down_weight_grad_td (TensorDescriptor) – [H, I], output: dW_down.
output_grad_T_td (TensorDescriptor) – [H, S], transposed output gradient (is_f_by_k=True).
hidden_T_td (TensorDescriptor) – [I, S], transposed intermediate activations (is_f_by_k=True).
h_base (int) – Row offset into the H dimension for this LNC core.
dtype (type) – Data type for computation (nl.bfloat16).
fp8_x4_dtype (type) – MXFP8 quantized data type.
TILES_IN_BLOCK_M (int) – Number of M tiles per block.
TILES_IN_BLOCK_N (int) – Number of N tiles per block.
TILES_IN_BLOCK_K (int) – Number of K tiles to accumulate in PSUM.
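A plain-Python reference for phase 4 is dW_down [H, I] = output_grad_T [H, S] @ hidden_T^T [S, I]. The sketch below also shows one way h_base could select a core's rows of H; that row-slicing interpretation, and the rows argument, are assumptions for illustration rather than documented kernel behavior.

```python
def matmul(a, b):
    """Multiply two row-major matrices stored as lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def phase4_reference(output_grad_T, hidden_T, h_base=0, rows=None):
    # Assumed sharding: a core handles `rows` rows of H starting at h_base.
    rows = len(output_grad_T) - h_base if rows is None else rows
    shard = output_grad_T[h_base:h_base + rows]
    # Contract over S: [rows, S] @ [S, I] -> [rows, I].
    return matmul(shard, transpose(hidden_T))


# Tiny example with H = 2, S = 2, I = 1.
output_grad_T = [[1.0, 2.0],
                 [3.0, 4.0]]
hidden_T = [[1.0, 1.0]]
full = phase4_reference(output_grad_T, hidden_T)
core0 = phase4_reference(output_grad_T, hidden_T, h_base=0, rows=1)
core1 = phase4_reference(output_grad_T, hidden_T, h_base=1, rows=1)
```

Under this assumption, stacking the per-core results in h_base order reproduces the full [H, I] gradient.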
mlp_backward_mxfp8_base_nki
- nkilib.experimental.mlp_mxfp8.mlp_bwd_mxfp8.mlp_backward_mxfp8_base_nki(output_grad_td: TensorDescriptor, hidden_states_td: TensorDescriptor, gate_pre_td: TensorDescriptor, gate_act_td: TensorDescriptor, up_td: TensorDescriptor, hidden_td: TensorDescriptor, gate_weight_T_td: TensorDescriptor, up_weight_T_td: TensorDescriptor, down_weight_T_td: TensorDescriptor, d_gate_td: TensorDescriptor, d_up_td: TensorDescriptor, hidden_states_T_td: TensorDescriptor, output_grad_T_td: TensorDescriptor, hidden_T_td: TensorDescriptor, scratch_td: TensorDescriptor, hidden_states_grad_td: TensorDescriptor, weight_grad_td: TensorDescriptor, down_weight_grad_td: TensorDescriptor, run_with_lnc2: bool = True, phase1_tiles_m: int = 8, phase1_tiles_n: int = 1, phase1_tiles_k: int = 8, phase2_tiles_m: int = 8, phase2_tiles_n: int = 1, phase2_tiles_k: int = 8, phase3_tiles_m: int = 4, phase3_tiles_n: int = 1, phase3_tiles_k: int = 8, phase4_tiles_m: int = 4, phase4_tiles_n: int = 1, phase4_tiles_k: int = 8, fp8_x4_dtype: type = float8_e4m3fn_x4, spill_reload: bool = True, use_scale_packing: bool = True) -> tuple
MXFP8 SwiGLU MLP backward pass (base kernel).
- Parameters:
output_grad_td (TensorDescriptor) – [S, H], incoming gradient dL/d_output (is_f_by_k=True).
hidden_states_td (TensorDescriptor) – [S, H], original input (for phase 3 weight grad).
gate_pre_td (TensorDescriptor) – [S, I], gate pre-activation (before SiLU).
gate_act_td (TensorDescriptor) – [S, I], gate post-activation (SiLU(gate_pre)).
up_td (TensorDescriptor) – [S, I], up projection (hidden @ W_up.T).
hidden_td (TensorDescriptor) – [S, I], gated intermediate (gate_act * up, for phase 4).
gate_weight_T_td (TensorDescriptor) – [H, I], transposed gate projection weights.
up_weight_T_td (TensorDescriptor) – [H, I], transposed up projection weights.
down_weight_T_td (TensorDescriptor) – [I, H], transposed down projection weights.
d_gate_td (TensorDescriptor) – [S, I], scratch: gate gradient.
d_up_td (TensorDescriptor) – [S, I], scratch: up gradient.
hidden_states_T_td (TensorDescriptor) – [H, S], pre-transposed input hidden states.
output_grad_T_td (TensorDescriptor) – [H, S], pre-transposed output gradient.
hidden_T_td (TensorDescriptor) – [I, S], pre-transposed intermediate activations.
scratch_td (TensorDescriptor) – [2I, S], scratch: transposed d_gate || d_up.
hidden_states_grad_td (TensorDescriptor) – [S, H], output: dL/d_hidden.
weight_grad_td (TensorDescriptor) – [2I, H], output: fused [dW_gate; dW_up].
down_weight_grad_td (TensorDescriptor) – [H, I], output: dL/dW_down.
run_with_lnc2 (bool) – Whether to shard across 2 LNC cores.
phase1_tiles_m (int) – M blocking for phase 1.
phase1_tiles_n (int) – N blocking for phase 1.
phase1_tiles_k (int) – K blocking for phase 1.
phase2_tiles_m (int) – M blocking for phase 2.
phase2_tiles_n (int) – N blocking for phase 2.
phase2_tiles_k (int) – K blocking for phase 2.
phase3_tiles_m (int) – M blocking for phase 3.
phase3_tiles_n (int) – N blocking for phase 3.
phase3_tiles_k (int) – K blocking for phase 3.
phase4_tiles_m (int) – M blocking for phase 4.
phase4_tiles_n (int) – N blocking for phase 4.
phase4_tiles_k (int) – K blocking for phase 4.
fp8_x4_dtype (type) – MXFP8 quantized data type.
- Returns:
(hidden_states_grad [S, H], gate_up_weight_grad [2I, H], down_weight_grad [H, I]).
- Return type:
tuple of nl.ndarray
Dimensions:
S: Sequence length.
H: Hidden dimension size.
I: Intermediate dimension size.
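Phase 2 has no standalone entry on this page, but the shapes above suggest it produces hidden_states_grad [S, H] from d_gate, d_up, and the transposed gate and up weights. A minimal full-precision sketch, assuming gate_pre = hidden_states @ gate_weight_T and up = hidden_states @ up_weight_T, is given below; phase2_reference is an illustrative name and the formula is inferred from those shapes, not taken from the kernel source.

```python
def matmul(a, b):
    """Multiply two row-major matrices stored as lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def phase2_reference(d_gate, d_up, gate_weight_T, up_weight_T):
    # Each branch contracts over I: [S, I] @ [I, H] -> [S, H].
    from_gate = matmul(d_gate, transpose(gate_weight_T))
    from_up = matmul(d_up, transpose(up_weight_T))
    # The input feeds both projections, so the two gradients add.
    return [[g + u for g, u in zip(r_g, r_u)] for r_g, r_u in zip(from_gate, from_up)]


# Tiny example with S = 1, I = 1, H = 2.
d_gate = [[2.0]]
d_up = [[3.0]]
gate_weight_T = [[1.0], [0.5]]  # [H, I]
up_weight_T = [[0.0], [2.0]]    # [H, I]
hidden_states_grad = phase2_reference(d_gate, d_up, gate_weight_T, up_weight_T)
```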
mlp_backward_mxfp8_nki
- nkilib.experimental.mlp_mxfp8.mlp_bwd_mxfp8.mlp_backward_mxfp8_nki(output_hidden_states_grad: nl.ndarray, hidden_states: nl.ndarray, gate_proj_weight_T: nl.ndarray, up_proj_weight_T: nl.ndarray, down_proj_weight_T: nl.ndarray, gate_up_weights: nl.ndarray, d_gate_scratch: nl.ndarray, d_up_scratch: nl.ndarray, hidden_states_T: nl.ndarray, output_grad_T: nl.ndarray, hidden_T: nl.ndarray, silu_up_mul_gate_grad_T_scratch: nl.ndarray, gate_pre_scratch: nl.ndarray, gate_act_scratch: nl.ndarray, up_scratch: nl.ndarray, hidden_scratch: nl.ndarray, gate_pre: nl.ndarray = None, gate_act: nl.ndarray = None, up: nl.ndarray = None, hidden: nl.ndarray = None, run_with_lnc2: bool = True, phase1_tiles_m: int = 8, phase1_tiles_n: int = 1, phase1_tiles_k: int = 8, phase2_tiles_m: int = 8, phase2_tiles_n: int = 1, phase2_tiles_k: int = 8, phase3_tiles_m: int = 4, phase3_tiles_n: int = 1, phase3_tiles_k: int = 8, phase4_tiles_m: int = 4, phase4_tiles_n: int = 1, phase4_tiles_k: int = 8, recompute_tiles_m: int = 8, recompute_tiles_n: int = 1, recompute_tiles_k: int = 8, fp8_x4_dtype: type = float8_e4m3fn_x4, spill_reload: bool = True, use_scale_packing: bool = True, output_grad_scales: nl.ndarray = None, output_grad_is_swizzled: bool = False, down_weight_scales: nl.ndarray = None, down_weight_is_swizzled: bool = False, gate_weight_scales: nl.ndarray = None, gate_weight_is_swizzled: bool = False, up_weight_scales: nl.ndarray = None, up_weight_is_swizzled: bool = False, hidden_states_T_scales: nl.ndarray = None, hidden_states_T_is_swizzled: bool = False, hidden_states_scales: nl.ndarray = None, hidden_states_is_swizzled: bool = False, gate_up_weights_scales: nl.ndarray = None, gate_up_weights_is_swizzled: bool = False, output_grad_T_scales: nl.ndarray = None, output_grad_T_is_swizzled: bool = False, hidden_T_scales: nl.ndarray = None, hidden_T_is_swizzled: bool = False) -> tuple
MXFP8 SwiGLU MLP backward pass with activation checkpointing support.
- Parameters:
output_hidden_states_grad (nl.ndarray) – [S, H], incoming gradient dL/d_output.
hidden_states (nl.ndarray) – [S, H], original input (for recompute + phase 3).
gate_proj_weight_T (nl.ndarray) – [H, I], transposed gate projection weights (phase 2).
up_proj_weight_T (nl.ndarray) – [H, I], transposed up projection weights (phase 2).
down_proj_weight_T (nl.ndarray) – [I, H], transposed down projection weights (phase 1).
gate_up_weights (nl.ndarray) – [2I, H], fused gate+up weights (for recompute).
d_gate_scratch (nl.ndarray) – [S, I], scratch: gate gradient.
d_up_scratch (nl.ndarray) – [S, I], scratch: up gradient.
hidden_states_T (nl.ndarray) – [H, S], pre-transposed input hidden states.
output_grad_T (nl.ndarray) – [H, S], pre-transposed output gradient.
hidden_T (nl.ndarray) – [I, S], pre-transposed intermediate activations.
silu_up_mul_gate_grad_T_scratch (nl.ndarray) – [2I, S], scratch: transposed d_gate || d_up.
gate_pre_scratch (nl.ndarray) – [S, I], scratch buffer for gate_pre.
gate_act_scratch (nl.ndarray) – [S, I], scratch buffer for gate_act.
up_scratch (nl.ndarray) – [S, I], scratch buffer for up.
hidden_scratch (nl.ndarray) – [S, I], scratch buffer for hidden.
gate_pre (nl.ndarray) – [S, I], checkpointed gate pre-activation, or None.
gate_act (nl.ndarray) – [S, I], checkpointed SiLU(gate_pre), or None.
up (nl.ndarray) – [S, I], checkpointed up projection, or None.
hidden (nl.ndarray) – [S, I], checkpointed gate_act * up, or None.

Pre-swizzled/pre-quantized input support: Each matmul operand tensor accepts an optional *_scales (nl.ndarray) and *_is_swizzled (bool) pair. When both are default (None/False), the tensor is treated as unswizzled BF16; this is identical to prior behavior.
output_grad_scales, output_grad_is_swizzled: Phase 1 LHS (output_hidden_states_grad).
down_weight_scales, down_weight_is_swizzled: Phase 1 RHS (down_proj_weight_T).
gate_weight_scales, gate_weight_is_swizzled: Phase 2 RHS (gate_proj_weight_T).
up_weight_scales, up_weight_is_swizzled: Phase 2 RHS (up_proj_weight_T).
hidden_states_T_scales, hidden_states_T_is_swizzled: Phase 3 RHS (hidden_states_T).
hidden_states_scales, hidden_states_is_swizzled: Recompute LHS (hidden_states).
gate_up_weights_scales, gate_up_weights_is_swizzled: Recompute RHS (gate_up_weights).
output_grad_T_scales, output_grad_T_is_swizzled: Phase 4 LHS (output_grad_T).
hidden_T_scales, hidden_T_is_swizzled: Phase 4 RHS (hidden_T).
- Returns:
(hidden_states_grad [S, H], gate_up_weight_grad [2I, H], down_proj_weight_grad [H, I]).
- Return type:
tuple of nl.ndarray
Dimensions:
S: Sequence length.
H: Hidden dimension size.
I: Intermediate dimension size.
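The checkpoint-or-recompute choice implied by the optional gate_pre, gate_act, up, and hidden arguments can be sketched in plain Python as follows. This is an illustration only: resolve_activations is a hypothetical helper, it assumes the four tensors are either all supplied or all recomputed, and it assumes the fused gate_up_weights stores the gate rows first and the up rows second.

```python
import math


def matmul(a, b):
    """Multiply two row-major matrices stored as lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def silu(x):
    return x / (1.0 + math.exp(-x))


def resolve_activations(hidden_states, gate_up_weights,
                        gate_pre=None, gate_act=None, up=None, hidden=None):
    """Return (gate_pre, gate_act, up, hidden), recomputing them if not checkpointed."""
    if all(t is not None for t in (gate_pre, gate_act, up, hidden)):
        return gate_pre, gate_act, up, hidden  # use the checkpointed tensors as-is
    inter = len(gate_up_weights) // 2  # I, assuming gate rows come first
    # Recompute: [S, H] @ [H, 2I] -> [S, 2I], then split at I.
    fused = matmul(hidden_states, transpose(gate_up_weights))
    gate_pre = [row[:inter] for row in fused]
    up = [row[inter:] for row in fused]
    gate_act = [[silu(v) for v in row] for row in gate_pre]
    hidden = [[g * u for g, u in zip(r_g, r_u)] for r_g, r_u in zip(gate_act, up)]
    return gate_pre, gate_act, up, hidden


# Tiny example with S = 1, H = 2, I = 1 and no checkpoints supplied.
hidden_states = [[1.0, 0.0]]
gate_up_weights = [[0.0, 1.0],   # gate row (assumed ordering)
                   [2.0, 3.0]]   # up row
recomputed = resolve_activations(hidden_states, gate_up_weights)
```

Skipping the checkpoints trades the memory needed to store four [S, I] tensors for one extra [S, H] by [H, 2I] matmul in the backward pass.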