This document is relevant for: Trn2, Trn3

MLP Backward MXFP8 Kernel API Reference#

Kernels for the MXFP8 SwiGLU MLP backward pass, with LNC2 sharding and activation checkpointing support.

Background#

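This module implements the MXFP8 SwiGLU MLP backward pass in four phases: (1) the gradient through the down projection and SwiGLU gate, (2) the gradient w.r.t. the input hidden states, (3) the gradient w.r.t. the gate and up weight matrices, and (4) the gradient w.r.t. the down projection weight matrix. The get_program_sharding_info helper returns the LNC2 sharding configuration (num_cores, shard_id), which the backward pass uses to distribute computation across logical cores.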

API Reference#

Source code for this kernel API can be found at: mlp_bwd_mxfp8_kernel.py

get_program_sharding_info#

nkilib.experimental.mlp_mxfp8.mlp_bwd_mxfp8.get_program_sharding_info(run_with_lnc2: bool) tuple#

Return (num_cores, shard_id) for LNC2 sharding.
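As an illustration of how the returned pair might be consumed, the sketch below derives a per-core row offset (such as the s_base argument of the phase functions) from (num_cores, shard_id). The helper name shard_row_range and the even-split assumption are hypothetical and are not taken from the kernel source.

```python
def shard_row_range(total_rows: int, num_cores: int, shard_id: int) -> tuple:
    """Return (base, rows) for one core, assuming rows split evenly across cores."""
    if total_rows % num_cores != 0:
        raise ValueError("total_rows must be divisible by num_cores in this sketch")
    rows = total_rows // num_cores
    return shard_id * rows, rows


# Hypothetical LNC2 case: two cores share S = 4096 rows.
for shard_id in range(2):
    s_base, rows = shard_row_range(4096, num_cores=2, shard_id=shard_id)
    print(f"shard {shard_id}: rows [{s_base}, {s_base + rows})")
```

With run_with_lnc2=False the same helper would be called with num_cores=1, so the single core covers every row.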

compute_phase1_down_proj_mm_grad_mxfp8#

nkilib.experimental.mlp_mxfp8.mlp_bwd_mxfp8.compute_phase1_down_proj_mm_grad_mxfp8(output_grad_td: TensorDescriptor, gate_pre_td: TensorDescriptor, gate_act_td: TensorDescriptor, up_td: TensorDescriptor, d_gate_td: TensorDescriptor, d_up_td: TensorDescriptor, scratch_td: TensorDescriptor, down_weight_td: TensorDescriptor, s_base: int, dtype: type, fp8_x4_dtype: type, TILES_IN_BLOCK_M: int = 8, TILES_IN_BLOCK_N: int = 1, TILES_IN_BLOCK_K: int = 8, spill_reload: bool = True, use_scale_packing: bool = True, run_with_lnc2: bool = True) None#

Phase 1: Compute gradient through the down projection and SwiGLU gate.

Parameters:
  • output_grad_td (TensorDescriptor) – [S, H], incoming gradient (is_f_by_k=True).

  • gate_pre_td (TensorDescriptor) – [S, I], checkpointed gate pre-activation.

  • gate_act_td (TensorDescriptor) – [S, I], checkpointed gate post-activation.

  • up_td (TensorDescriptor) – [S, I], checkpointed up projection.

  • d_gate_td (TensorDescriptor) – [S, I], output: gate gradient.

  • d_up_td (TensorDescriptor) – [S, I], output: up gradient.

  • scratch_td (TensorDescriptor) – [2I, S], output: transposed d_gate || d_up.

  • down_weight_td (TensorDescriptor) – [I, H], transposed down projection weights (is_f_by_k=True).

  • s_base (int) – Row offset into the full [S, …] tensors for this LNC core.

  • dtype (type) – Data type for computation (nl.bfloat16).

  • fp8_x4_dtype (type) – MXFP8 quantized data type (e.g. float8_e4m3fn_x4).

  • TILES_IN_BLOCK_M (int) – Number of M tiles per block.

  • TILES_IN_BLOCK_N (int) – Number of N tiles per block.

  • TILES_IN_BLOCK_K (int) – Number of K tiles to accumulate in PSUM.
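The math that phase 1 covers can be sketched as a full-precision NumPy reference. This is not the kernel: it skips MXFP8 quantization, tiling, and sharding, and it assumes the standard SwiGLU form hidden = SiLU(gate_pre) * up followed by output = hidden @ down_weight_T. The ordering of d_gate before d_up in the scratch buffer is also an assumption. The function name phase1_reference is hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def phase1_reference(output_grad, gate_pre, gate_act, up, down_weight_T):
    """Full-precision reference for phase 1 (no quantization, no tiling).

    output_grad:          [S, H]
    gate_pre/gate_act/up: [S, I]
    down_weight_T:        [I, H]
    """
    # Gradient w.r.t. the gated intermediate, hidden = gate_act * up.
    d_hidden = output_grad @ down_weight_T.T            # [S, I]
    # Product rule through the elementwise gate.
    d_up = d_hidden * gate_act
    sig = sigmoid(gate_pre)
    silu_grad = sig * (1.0 + gate_pre * (1.0 - sig))    # d SiLU(x) / dx
    d_gate = d_hidden * up * silu_grad
    # Assumed layout: d_gate rows first, then d_up rows, transposed to [2I, S].
    scratch = np.concatenate([d_gate, d_up], axis=1).T
    return d_gate, d_up, scratch


rng = np.random.default_rng(0)
S, H, I = 4, 6, 5  # toy sizes, chosen only for illustration
output_grad = rng.standard_normal((S, H))
gate_pre = rng.standard_normal((S, I))
gate_act = gate_pre * sigmoid(gate_pre)  # SiLU(gate_pre)
up = rng.standard_normal((S, I))
down_weight_T = rng.standard_normal((I, H))

d_gate, d_up, scratch = phase1_reference(output_grad, gate_pre, gate_act, up, down_weight_T)
print(d_gate.shape, d_up.shape, scratch.shape)
```

A reference like this is mainly useful for comparing kernel outputs within a quantization-aware tolerance.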

compute_phase2_hidden_states_grad_mxfp8#

nkilib.experimental.mlp_mxfp8.mlp_bwd_mxfp8.compute_phase2_hidden_states_grad_mxfp8(hidden_states_grad_td: TensorDescriptor, gate_weight_td: TensorDescriptor, up_weight_td: TensorDescriptor, d_gate_td: TensorDescriptor, d_up_td: TensorDescriptor, s_base: int, dtype: type, fp8_x4_dtype: type, TILES_IN_BLOCK_M: int = 8, TILES_IN_BLOCK_N: int = 1, TILES_IN_BLOCK_K: int = 8, spill_reload: bool = True, use_scale_packing: bool = True, run_with_lnc2: bool = True) None#

Phase 2: Compute gradient w.r.t. input hidden states.

Parameters:
  • hidden_states_grad_td (TensorDescriptor) – [S, H], output: dL/d_hidden.

  • gate_weight_td (TensorDescriptor) – [H, I], transposed gate projection weights (is_f_by_k=True).

  • up_weight_td (TensorDescriptor) – [H, I], transposed up projection weights (is_f_by_k=True).

  • d_gate_td (TensorDescriptor) – [S, I], gate gradient (is_f_by_k=True).

  • d_up_td (TensorDescriptor) – [S, I], up gradient (is_f_by_k=True).

  • s_base (int) – Row offset for this LNC core’s shard.

  • dtype (type) – Data type for computation (nl.bfloat16).

  • fp8_x4_dtype (type) – MXFP8 quantized data type (e.g. float8_e4m3fn_x4).

  • TILES_IN_BLOCK_M (int) – Number of M tiles per block.

  • TILES_IN_BLOCK_N (int) – Number of N tiles per block.

  • TILES_IN_BLOCK_K (int) – Number of K tiles to accumulate in PSUM.
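For orientation, the shapes above imply the following full-precision relation, sketched in NumPy. It assumes the forward pass computes gate_pre = hidden_states @ gate_weight_T and up = hidden_states @ up_weight_T with [H, I] weights; quantization, tiling, and sharding are omitted, and phase2_reference is a hypothetical name.

```python
import numpy as np


def phase2_reference(d_gate, d_up, gate_weight_T, up_weight_T):
    """d_gate, d_up: [S, I]; gate_weight_T, up_weight_T: [H, I]. Returns [S, H]."""
    return d_gate @ gate_weight_T.T + d_up @ up_weight_T.T


rng = np.random.default_rng(0)
S, H, I = 4, 6, 5  # toy sizes, chosen only for illustration
d_gate = rng.standard_normal((S, I))
d_up = rng.standard_normal((S, I))
gate_weight_T = rng.standard_normal((H, I))
up_weight_T = rng.standard_normal((H, I))

hidden_states_grad = phase2_reference(d_gate, d_up, gate_weight_T, up_weight_T)

# The sum of two matmuls equals one matmul whose contraction runs over 2I.
fused = (
    np.concatenate([d_gate, d_up], axis=1)
    @ np.concatenate([gate_weight_T, up_weight_T], axis=1).T
)
print(hidden_states_grad.shape, np.allclose(hidden_states_grad, fused))
```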

compute_phase3_gate_up_weight_grad_mxfp8#

nkilib.experimental.mlp_mxfp8.mlp_bwd_mxfp8.compute_phase3_gate_up_weight_grad_mxfp8(weight_grad_td: TensorDescriptor, hidden_states_T_td: TensorDescriptor, grad_T_td: TensorDescriptor, dtype: type, fp8_x4_dtype: type, TILES_IN_BLOCK_M: int = 4, TILES_IN_BLOCK_N: int = 1, TILES_IN_BLOCK_K: int = 8, spill_reload: bool = True, use_scale_packing: bool = True, run_with_lnc2: bool = True) None#

Phase 3: Compute gradient w.r.t. gate and up weight matrices as a single matmul.

Parameters:
  • weight_grad_td (TensorDescriptor) – [2I, H], output: [dW_gate; dW_up].

  • hidden_states_T_td (TensorDescriptor) – [H, S], transposed input hidden states (is_f_by_k=True).

  • grad_T_td (TensorDescriptor) – [2I, S], transposed gate+up gradients (is_f_by_k=True, is_col_parallel_sharded=True for LNC2).

  • dtype (type) – Data type for computation (nl.bfloat16).

  • fp8_x4_dtype (type) – MXFP8 quantized data type.

  • TILES_IN_BLOCK_M (int) – Number of M tiles per block.

  • TILES_IN_BLOCK_N (int) – Number of N tiles per block.

  • TILES_IN_BLOCK_K (int) – Number of K tiles to accumulate in PSUM.

Dimensions:

  • S: Sequence length.

  • H: Hidden dimension size.

  • I: Intermediate dimension size.
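The "single matmul" in phase 3 can be illustrated with a full-precision NumPy sketch. It assumes grad_T stacks the transposed gate gradient on top of the transposed up gradient, so the first I rows of the result are dW_gate and the last I rows are dW_up. MXFP8 quantization and LNC2 sharding are left out, and phase3_reference is a hypothetical name.

```python
import numpy as np


def phase3_reference(grad_T, hidden_states_T):
    """grad_T: [2I, S]; hidden_states_T: [H, S]. Returns [2I, H]."""
    return grad_T @ hidden_states_T.T


rng = np.random.default_rng(0)
S, H, I = 4, 6, 5  # toy sizes, chosen only for illustration
hidden_states = rng.standard_normal((S, H))
d_gate = rng.standard_normal((S, I))
d_up = rng.standard_normal((S, I))

# Assumed layout of the transposed gradient buffer: d_gate rows, then d_up rows.
grad_T = np.concatenate([d_gate, d_up], axis=1).T       # [2I, S]
weight_grad = phase3_reference(grad_T, hidden_states.T)  # [2I, H]

dW_gate, dW_up = weight_grad[:I], weight_grad[I:]
print(dW_gate.shape, dW_up.shape)
```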

compute_phase4_down_weight_grad_mxfp8#

nkilib.experimental.mlp_mxfp8.mlp_bwd_mxfp8.compute_phase4_down_weight_grad_mxfp8(down_weight_grad_td: TensorDescriptor, output_grad_T_td: TensorDescriptor, hidden_T_td: TensorDescriptor, h_base: int, dtype: type, fp8_x4_dtype: type, TILES_IN_BLOCK_M: int = 4, TILES_IN_BLOCK_N: int = 1, TILES_IN_BLOCK_K: int = 8, spill_reload: bool = True, use_scale_packing: bool = True, run_with_lnc2: bool = True) None#

Phase 4: Compute gradient w.r.t. down projection weight matrix.

Parameters:
  • down_weight_grad_td (TensorDescriptor) – [H, I], output: dW_down.

  • output_grad_T_td (TensorDescriptor) – [H, S], transposed output gradient (is_f_by_k=True).

  • hidden_T_td (TensorDescriptor) – [I, S], transposed intermediate activations (is_f_by_k=True).

  • h_base (int) – Row offset into the H dimension for this LNC core.

  • dtype (type) – Data type for computation (nl.bfloat16).

  • fp8_x4_dtype (type) – MXFP8 quantized data type.

  • TILES_IN_BLOCK_M (int) – Number of M tiles per block.

  • TILES_IN_BLOCK_N (int) – Number of N tiles per block.

  • TILES_IN_BLOCK_K (int) – Number of K tiles to accumulate in PSUM.
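A full-precision NumPy sketch of the phase 4 relation follows, including one way the h_base offset could select a core's rows of dW_down. The even split over H and the h_rows argument are assumptions made for illustration; the sketch omits MXFP8 quantization and tiling, and phase4_reference is a hypothetical name.

```python
import numpy as np


def phase4_reference(output_grad_T, hidden_T, h_base=0, h_rows=None):
    """output_grad_T: [H, S]; hidden_T: [I, S].

    Returns rows [h_base, h_base + h_rows) of dW_down, which has shape [H, I].
    """
    if h_rows is None:
        h_rows = output_grad_T.shape[0] - h_base
    return output_grad_T[h_base:h_base + h_rows] @ hidden_T.T


rng = np.random.default_rng(0)
S, H, I = 4, 6, 5  # toy sizes, chosen only for illustration
output_grad_T = rng.standard_normal((H, S))
hidden_T = rng.standard_normal((I, S))

full = phase4_reference(output_grad_T, hidden_T)  # [H, I]

# Hypothetical two-core split: each core owns H // 2 rows of the result.
num_cores = 2
shards = [
    phase4_reference(output_grad_T, hidden_T, h_base=c * (H // num_cores), h_rows=H // num_cores)
    for c in range(num_cores)
]
print(full.shape, np.allclose(np.concatenate(shards, axis=0), full))
```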

mlp_backward_mxfp8_base_nki#

nkilib.experimental.mlp_mxfp8.mlp_bwd_mxfp8.mlp_backward_mxfp8_base_nki(output_grad_td: TensorDescriptor, hidden_states_td: TensorDescriptor, gate_pre_td: TensorDescriptor, gate_act_td: TensorDescriptor, up_td: TensorDescriptor, hidden_td: TensorDescriptor, gate_weight_T_td: TensorDescriptor, up_weight_T_td: TensorDescriptor, down_weight_T_td: TensorDescriptor, d_gate_td: TensorDescriptor, d_up_td: TensorDescriptor, hidden_states_T_td: TensorDescriptor, output_grad_T_td: TensorDescriptor, hidden_T_td: TensorDescriptor, scratch_td: TensorDescriptor, hidden_states_grad_td: TensorDescriptor, weight_grad_td: TensorDescriptor, down_weight_grad_td: TensorDescriptor, run_with_lnc2: bool = True, phase1_tiles_m: int = 8, phase1_tiles_n: int = 1, phase1_tiles_k: int = 8, phase2_tiles_m: int = 8, phase2_tiles_n: int = 1, phase2_tiles_k: int = 8, phase3_tiles_m: int = 4, phase3_tiles_n: int = 1, phase3_tiles_k: int = 8, phase4_tiles_m: int = 4, phase4_tiles_n: int = 1, phase4_tiles_k: int = 8, fp8_x4_dtype: type = float8_e4m3fn_x4, spill_reload: bool = True, use_scale_packing: bool = True) tuple#

MXFP8 SwiGLU MLP backward pass (base kernel).

Parameters:
  • output_grad_td (TensorDescriptor) – [S, H], incoming gradient dL/d_output (is_f_by_k=True).

  • hidden_states_td (TensorDescriptor) – [S, H], original input (for phase 3 weight grad).

  • gate_pre_td (TensorDescriptor) – [S, I], gate pre-activation (before SiLU).

  • gate_act_td (TensorDescriptor) – [S, I], gate post-activation (SiLU(gate_pre)).

  • up_td (TensorDescriptor) – [S, I], up projection (hidden @ W_up.T).

  • hidden_td (TensorDescriptor) – [S, I], gated intermediate (gate_act * up, for phase 4).

  • gate_weight_T_td (TensorDescriptor) – [H, I], transposed gate projection weights.

  • up_weight_T_td (TensorDescriptor) – [H, I], transposed up projection weights.

  • down_weight_T_td (TensorDescriptor) – [I, H], transposed down projection weights.

  • d_gate_td (TensorDescriptor) – [S, I], scratch: gate gradient.

  • d_up_td (TensorDescriptor) – [S, I], scratch: up gradient.

  • hidden_states_T_td (TensorDescriptor) – [H, S], pre-transposed input hidden states.

  • output_grad_T_td (TensorDescriptor) – [H, S], pre-transposed output gradient.

  • hidden_T_td (TensorDescriptor) – [I, S], pre-transposed intermediate activations.

  • scratch_td (TensorDescriptor) – [2I, S], scratch: transposed d_gate || d_up.

  • hidden_states_grad_td (TensorDescriptor) – [S, H], output: dL/d_hidden.

  • weight_grad_td (TensorDescriptor) – [2I, H], output: fused [dW_gate; dW_up].

  • down_weight_grad_td (TensorDescriptor) – [H, I], output: dL/dW_down.

  • run_with_lnc2 (bool) – Whether to shard across 2 LNC cores.

  • phase1_tiles_m (int) – M blocking for phase 1.

  • phase1_tiles_n (int) – N blocking for phase 1.

  • phase1_tiles_k (int) – K blocking for phase 1.

  • phase2_tiles_m (int) – M blocking for phase 2.

  • phase2_tiles_n (int) – N blocking for phase 2.

  • phase2_tiles_k (int) – K blocking for phase 2.

  • phase3_tiles_m (int) – M blocking for phase 3.

  • phase3_tiles_n (int) – N blocking for phase 3.

  • phase3_tiles_k (int) – K blocking for phase 3.

  • phase4_tiles_m (int) – M blocking for phase 4.

  • phase4_tiles_n (int) – N blocking for phase 4.

  • phase4_tiles_k (int) – K blocking for phase 4.

  • fp8_x4_dtype (type) – MXFP8 quantized data type.

Returns:

(hidden_states_grad [S, H], gate_up_weight_grad [2I, H], down_weight_grad [H, I]).

Return type:

tuple of nl.ndarray

Dimensions:

  • S: Sequence length.

  • H: Hidden dimension size.

  • I: Intermediate dimension size.
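Because the base kernel takes eighteen tensor descriptors, a shape check before launch can catch mismatched buffers early. The sketch below collects the shapes from the parameter list above into a dictionary. The helpers expected_shapes and check_shapes are hypothetical and operate on anything with a .shape attribute; they do not construct TensorDescriptor objects.

```python
import numpy as np


def expected_shapes(S, H, I):
    """Map each tensor-descriptor argument to the shape given in the parameter list."""
    return {
        "output_grad_td": (S, H),
        "hidden_states_td": (S, H),
        "gate_pre_td": (S, I),
        "gate_act_td": (S, I),
        "up_td": (S, I),
        "hidden_td": (S, I),
        "gate_weight_T_td": (H, I),
        "up_weight_T_td": (H, I),
        "down_weight_T_td": (I, H),
        "d_gate_td": (S, I),
        "d_up_td": (S, I),
        "hidden_states_T_td": (H, S),
        "output_grad_T_td": (H, S),
        "hidden_T_td": (I, S),
        "scratch_td": (2 * I, S),
        "hidden_states_grad_td": (S, H),
        "weight_grad_td": (2 * I, H),
        "down_weight_grad_td": (H, I),
    }


def check_shapes(arrays, S, H, I):
    """Raise ValueError on the first buffer that is missing or has the wrong shape."""
    for name, shape in expected_shapes(S, H, I).items():
        if name not in arrays:
            raise ValueError(f"missing buffer: {name}")
        if tuple(arrays[name].shape) != shape:
            raise ValueError(f"{name}: expected {shape}, got {tuple(arrays[name].shape)}")


S, H, I = 8, 16, 32  # toy sizes, chosen only for illustration
buffers = {name: np.zeros(shape) for name, shape in expected_shapes(S, H, I).items()}
check_shapes(buffers, S, H, I)  # passes silently
```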

mlp_backward_mxfp8_nki#

nkilib.experimental.mlp_mxfp8.mlp_bwd_mxfp8.mlp_backward_mxfp8_nki(output_hidden_states_grad: nl.ndarray, hidden_states: nl.ndarray, gate_proj_weight_T: nl.ndarray, up_proj_weight_T: nl.ndarray, down_proj_weight_T: nl.ndarray, gate_up_weights: nl.ndarray, d_gate_scratch: nl.ndarray, d_up_scratch: nl.ndarray, hidden_states_T: nl.ndarray, output_grad_T: nl.ndarray, hidden_T: nl.ndarray, silu_up_mul_gate_grad_T_scratch: nl.ndarray, gate_pre_scratch: nl.ndarray, gate_act_scratch: nl.ndarray, up_scratch: nl.ndarray, hidden_scratch: nl.ndarray, gate_pre: nl.ndarray = None, gate_act: nl.ndarray = None, up: nl.ndarray = None, hidden: nl.ndarray = None, run_with_lnc2: bool = True, phase1_tiles_m: int = 8, phase1_tiles_n: int = 1, phase1_tiles_k: int = 8, phase2_tiles_m: int = 8, phase2_tiles_n: int = 1, phase2_tiles_k: int = 8, phase3_tiles_m: int = 4, phase3_tiles_n: int = 1, phase3_tiles_k: int = 8, phase4_tiles_m: int = 4, phase4_tiles_n: int = 1, phase4_tiles_k: int = 8, recompute_tiles_m: int = 8, recompute_tiles_n: int = 1, recompute_tiles_k: int = 8, fp8_x4_dtype: type = float8_e4m3fn_x4, spill_reload: bool = True, use_scale_packing: bool = True, output_grad_scales: nl.ndarray = None, output_grad_is_swizzled: bool = False, down_weight_scales: nl.ndarray = None, down_weight_is_swizzled: bool = False, gate_weight_scales: nl.ndarray = None, gate_weight_is_swizzled: bool = False, up_weight_scales: nl.ndarray = None, up_weight_is_swizzled: bool = False, hidden_states_T_scales: nl.ndarray = None, hidden_states_T_is_swizzled: bool = False, hidden_states_scales: nl.ndarray = None, hidden_states_is_swizzled: bool = False, gate_up_weights_scales: nl.ndarray = None, gate_up_weights_is_swizzled: bool = False, output_grad_T_scales: nl.ndarray = None, output_grad_T_is_swizzled: bool = False, hidden_T_scales: nl.ndarray = None, hidden_T_is_swizzled: bool = False) tuple#

MXFP8 SwiGLU MLP backward pass with activation checkpointing support.

Parameters:
  • output_hidden_states_grad (nl.ndarray) – [S, H], incoming gradient dL/d_output.

  • hidden_states (nl.ndarray) – [S, H], original input (for recompute + phase 3).

  • gate_proj_weight_T (nl.ndarray) – [H, I], transposed gate projection weights (phase 2).

  • up_proj_weight_T (nl.ndarray) – [H, I], transposed up projection weights (phase 2).

  • down_proj_weight_T (nl.ndarray) – [I, H], transposed down projection weights (phase 1).

  • gate_up_weights (nl.ndarray) – [2I, H], fused gate+up weights (for recompute).

  • d_gate_scratch (nl.ndarray) – [S, I], scratch: gate gradient.

  • d_up_scratch (nl.ndarray) – [S, I], scratch: up gradient.

  • hidden_states_T (nl.ndarray) – [H, S], pre-transposed input hidden states.

  • output_grad_T (nl.ndarray) – [H, S], pre-transposed output gradient.

  • hidden_T (nl.ndarray) – [I, S], pre-transposed intermediate activations.

  • silu_up_mul_gate_grad_T_scratch (nl.ndarray) – [2I, S], scratch: transposed d_gate || d_up.

  • gate_pre_scratch (nl.ndarray) – [S, I], scratch buffer for gate_pre.

  • gate_act_scratch (nl.ndarray) – [S, I], scratch buffer for gate_act.

  • up_scratch (nl.ndarray) – [S, I], scratch buffer for up.

  • hidden_scratch (nl.ndarray) – [S, I], scratch buffer for hidden.

  • gate_pre (nl.ndarray) – [S, I], checkpointed gate pre-activation, or None.

  • gate_act (nl.ndarray) – [S, I], checkpointed SiLU(gate_pre), or None.

  • up (nl.ndarray) – [S, I], checkpointed up projection, or None.

  • hidden (nl.ndarray) – [S, I], checkpointed gate_act * up, or None.

Pre-swizzled/pre-quantized input support: Each matmul operand tensor accepts an optional *_scales (nl.ndarray) and *_is_swizzled (bool) pair. When both are left at their defaults (None/False), the tensor is treated as unswizzled BF16, identical to prior behavior.

  • output_grad_scales, output_grad_is_swizzled: Phase 1 LHS (output_hidden_states_grad).

  • down_weight_scales, down_weight_is_swizzled: Phase 1 RHS (down_proj_weight_T).

  • gate_weight_scales, gate_weight_is_swizzled: Phase 2 RHS (gate_proj_weight_T).

  • up_weight_scales, up_weight_is_swizzled: Phase 2 RHS (up_proj_weight_T).

  • hidden_states_T_scales, hidden_states_T_is_swizzled: Phase 3 RHS (hidden_states_T).

  • hidden_states_scales, hidden_states_is_swizzled: Recompute LHS (hidden_states).

  • gate_up_weights_scales, gate_up_weights_is_swizzled: Recompute RHS (gate_up_weights).

  • output_grad_T_scales, output_grad_T_is_swizzled: Phase 4 LHS (output_grad_T).

  • hidden_T_scales, hidden_T_is_swizzled: Phase 4 RHS (hidden_T).

Returns:

(hidden_states_grad [S, H], gate_up_weight_grad [2I, H], down_proj_weight_grad [H, I]).

Return type:

tuple of nl.ndarray

Dimensions:

  • S: Sequence length.

  • H: Hidden dimension size.

  • I: Intermediate dimension size.
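The parameter descriptions suggest that when the checkpointed activations (gate_pre, gate_act, up, hidden) are None, they are recomputed from hidden_states and the fused gate_up_weights. A full-precision NumPy sketch of that recompute is shown below. It assumes the gate weights occupy the first I rows of gate_up_weights and the up weights the last I rows. It does not model MXFP8 quantization or the scratch buffers, and recompute_activations is a hypothetical name.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def recompute_activations(hidden_states, gate_up_weights):
    """hidden_states: [S, H]; gate_up_weights: [2I, H].

    Returns (gate_pre, gate_act, up, hidden), each of shape [S, I].
    """
    I = gate_up_weights.shape[0] // 2
    gate_up = hidden_states @ gate_up_weights.T  # [S, 2I]
    # Assumed layout: gate half first, up half second.
    gate_pre, up = gate_up[:, :I], gate_up[:, I:]
    gate_act = silu(gate_pre)
    hidden = gate_act * up
    return gate_pre, gate_act, up, hidden


rng = np.random.default_rng(0)
S, H, I = 4, 6, 5  # toy sizes, chosen only for illustration
hidden_states = rng.standard_normal((S, H))
gate_up_weights = rng.standard_normal((2 * I, H))

gate_pre, gate_act, up, hidden = recompute_activations(hidden_states, gate_up_weights)
print(gate_pre.shape, gate_act.shape, up.shape, hidden.shape)
```

Passing the checkpointed tensors avoids this extra matmul at the cost of keeping four [S, I] activations alive between the forward and backward passes.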
