Blockwise MM Backward Kernel API Reference#
[Experimental] Computes the backward pass for blockwise matrix multiplication in Mixture of Experts (MoE) layers, producing gradients for all parameters.
The kernel supports:
Gradient computation for hidden states, expert affinities, gate/up weights, and down weights
Optional bias gradient computation
Multiple sharding strategies (hidden dimension, intermediate dimension)
Affinity scaling on hidden or intermediate dimension
Gradient clamping for numerical stability
Various activation functions (SiLU, GELU, Swish)
Dropless MoE with variable block assignments per expert
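The backward pass differentiates through whichever activation the forward used. As a quick numpy sketch (illustrative only, not the kernel's implementation; note that SiLU is Swish with beta = 1, and the GELU shown here is the tanh approximation), the derivatives involved are:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def silu(x):
    # SiLU (equivalently Swish with beta = 1): x * sigmoid(x)
    return x * sigmoid(x)

def silu_grad(x):
    # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))

def gelu_tanh(x):
    # tanh-approximate GELU
    k = np.sqrt(2.0 / np.pi)
    return 0.5 * x * (1.0 + np.tanh(k * (x + 0.044715 * x**3)))

def gelu_tanh_grad(x):
    # chain rule through 0.5 * x * (1 + tanh(u(x)))
    k = np.sqrt(2.0 / np.pi)
    u = k * (x + 0.044715 * x**3)
    du = k * (1.0 + 3.0 * 0.044715 * x**2)
    t = np.tanh(u)
    return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t**2) * du
```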
Background#
The blockwise_mm_bwd kernel is the backward pass companion to the MoE CTE forward kernel. It computes gradients for all learnable parameters in a blockwise MoE layer by reversing the forward computation:
Down projection backward: Compute gradients for down projection weights and intermediate activations
Activation backward: Compute gradients through the activation function using checkpointed activations
Gate/Up projection backward: Compute gradients for gate and up projection weights
Hidden states backward: Compute gradients for input hidden states
Affinity backward: Compute gradients for expert affinities
The kernel uses activation checkpoints saved during the forward pass (gate_up_proj_act_checkpoint_T and down_proj_act_checkpoint) to avoid recomputation.
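The five steps above can be sketched for a single block as a plain numpy reference. Everything here is illustrative, not the kernel's code: the exact per-block forward form, the variable names, and the recomputation of intermediates (which the real kernel instead reads from the checkpointed tensors) are assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def silu(x):
    return x * sigmoid(x)

def silu_grad(x):
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))

def block_backward(x, aff, w_gate, w_up, w_down, d_out):
    """Backward for one block routed to one expert (illustrative).
    x: [B, H] tokens, aff: [B, 1] affinity applied on H,
    w_gate/w_up: [H, I], w_down: [I, H], d_out: [B, H] upstream grad."""
    # Recompute forward intermediates (the kernel loads these from
    # gate_up_proj_act_checkpoint_T / down_proj_act_checkpoint instead).
    gate = x @ w_gate          # [B, I]
    up = x @ w_up              # [B, I]
    act = silu(gate) * up      # [B, I]
    down = act @ w_down        # [B, H]; forward output is aff * down
    # 1. Down projection backward
    d_down = aff * d_out                    # [B, H]
    d_w_down = act.T @ d_down               # [I, H]
    d_act = d_down @ w_down.T               # [B, I]
    # 2. Activation backward
    d_gate = d_act * up * silu_grad(gate)   # [B, I]
    d_up = d_act * silu(gate)               # [B, I]
    # 3. Gate/Up projection backward
    d_w_gate = x.T @ d_gate                 # [H, I]
    d_w_up = x.T @ d_up                     # [H, I]
    # 4. Hidden states backward
    d_x = d_gate @ w_gate.T + d_up @ w_up.T # [B, H]
    # 5. Affinity backward
    d_aff = np.sum(d_out * down, axis=1, keepdims=True)  # [B, 1]
    return d_x, d_aff, d_w_gate, d_w_up, d_w_down
```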
API Reference#
Source code for this kernel API can be found at: blockwise_mm_backward.py
blockwise_mm_bwd#
- nkilib.experimental.moe.bwd.blockwise_mm_bwd(hidden_states: nl.ndarray, expert_affinities_masked: nl.ndarray, gate_up_proj_weight: nl.ndarray, down_proj_weight: nl.ndarray, gate_up_proj_act_checkpoint_T: nl.ndarray, down_proj_act_checkpoint: nl.ndarray, token_position_to_id: nl.ndarray, block_to_expert: nl.ndarray, output_hidden_states_grad: nl.ndarray, block_size: int, skip_dma: SkipMode = None, compute_dtype: nki.dtype = nl.bfloat16, is_tensor_update_accumulating: bool = True, shard_option: ShardOption = ShardOption.SHARD_ON_HIDDEN, affinity_option: AffinityOption = AffinityOption.AFFINITY_ON_H, kernel_type_option: KernelTypeOption = KernelTypeOption.DROPLESS, clamp_limits: ClampLimits = None, bias: bool = False, activation_type: ActFnType = ActFnType.SiLU, block_tile_size: int = None) → tuple#
Compute backward pass for blockwise MoE layer.
Computes gradients for all parameters in a Mixture of Experts layer using blockwise matrix multiplication. Optimized for dropless MoE with variable block assignments per expert.
- Parameters:
  - hidden_states (nl.ndarray) – Input hidden states tensor with shape [T, H] in HBM.
  - expert_affinities_masked (nl.ndarray) – Expert affinities with shape [T * E, 1] in HBM.
  - gate_up_proj_weight (nl.ndarray) – Gate and up projection weights with shape [E, H, 2, I_TP] in HBM.
  - down_proj_weight (nl.ndarray) – Down projection weights with shape [E, I_TP, H] in HBM.
  - gate_up_proj_act_checkpoint_T (nl.ndarray) – Checkpointed gate/up activations from the forward pass with shape [N, 2, I_TP, B].
  - down_proj_act_checkpoint (nl.ndarray) – Checkpointed down projection activations from the forward pass with shape [N, B, H].
  - token_position_to_id (nl.ndarray) – Token position to block mapping with shape [N * B].
  - block_to_expert (nl.ndarray) – Expert index per block with shape [N, 1].
  - output_hidden_states_grad (nl.ndarray) – Upstream gradient from the output with shape [T, H].
  - block_size (int) – Number of tokens per block. Must be one of: 128, 256, 512, 1024.
  - skip_dma (SkipMode, optional) – DMA skip mode for OOB handling. Default: SkipMode(False, False).
  - compute_dtype (nki.dtype) – Computation data type. Default: nl.bfloat16.
  - is_tensor_update_accumulating (bool) – Whether to accumulate into existing gradients. Default: True.
  - shard_option (ShardOption) – Sharding strategy. SHARD_ON_HIDDEN: shard across the hidden dimension. SHARD_ON_INTERMEDIATE: shard across the intermediate dimension. AUTO: auto-select. Default: SHARD_ON_HIDDEN.
  - affinity_option (AffinityOption) – Dimension for affinity scaling. AFFINITY_ON_H: scale on the hidden dimension. AFFINITY_ON_I: scale on the intermediate dimension. Default: AFFINITY_ON_H.
  - kernel_type_option (KernelTypeOption) – Token dropping strategy. DROPLESS: variable blocks per expert. DROPPING: fixed blocks per expert. Default: DROPLESS.
  - clamp_limits (ClampLimits, optional) – Gradient clamping limits for numerical stability. Contains linear_clamp_upper_limit, linear_clamp_lower_limit, non_linear_clamp_upper_limit, and non_linear_clamp_lower_limit.
  - bias (bool) – Whether to compute bias gradients. Default: False.
  - activation_type (ActFnType) – Activation function type. Default: SiLU.
  - block_tile_size (int, optional) – Optional tile size override for block processing.
- Returns:
  Tuple of gradient tensors. When bias=False: (hidden_states_grad, expert_affinities_masked_grad, gate_up_proj_weight_grad, down_proj_weight_grad). When bias=True: the tuple additionally includes gate_and_up_proj_bias_grad and down_proj_bias_grad.
- Return type:
  tuple
Dimensions:
T: Total number of input tokens
H: Hidden dimension size
I_TP: Intermediate size / tensor parallel degree
E: Number of experts
B: Block size (tokens per block)
N: Number of blocks
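Given these dimension symbols, a small helper (hypothetical, not part of the kernel API) can tabulate the documented argument shapes for a configuration, which is handy for sanity-checking tensors before invoking the kernel:

```python
def blockwise_mm_bwd_arg_shapes(T, H, I_TP, E, B, N):
    """Documented HBM shapes of the kernel's array arguments
    (taken from the parameter list above)."""
    return {
        "hidden_states": (T, H),
        "expert_affinities_masked": (T * E, 1),
        "gate_up_proj_weight": (E, H, 2, I_TP),
        "down_proj_weight": (E, I_TP, H),
        "gate_up_proj_act_checkpoint_T": (N, 2, I_TP, B),
        "down_proj_act_checkpoint": (N, B, H),
        "token_position_to_id": (N * B,),
        "block_to_expert": (N, 1),
        "output_hidden_states_grad": (T, H),
    }
```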
Supported Data Types:
Input: bfloat16, float16
Constraints:
block_size must be one of: 128, 256, 512, 1024
H must be divisible by the number of shards for LNC sharding
Currently only supports the DROPLESS kernel type
Requires activation checkpoints from the forward pass (gate_up_proj_act_checkpoint_T and down_proj_act_checkpoint)
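The first two constraints are easy to check up front. A hypothetical pre-flight helper (not part of the kernel API) mirroring them:

```python
def validate_blockwise_mm_bwd_config(block_size, hidden_size, num_shards):
    """Raise ValueError if the documented constraints are violated.
    (Illustrative helper; the kernel performs its own checks.)"""
    if block_size not in (128, 256, 512, 1024):
        raise ValueError(
            f"block_size must be one of 128, 256, 512, 1024; got {block_size}"
        )
    if hidden_size % num_shards != 0:
        raise ValueError(
            f"H ({hidden_size}) must be divisible by the number of "
            f"shards ({num_shards}) for LNC sharding"
        )
```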
Implementation Details#
The kernel implementation includes several key optimizations:
Sharding Strategies: Supports sharding across the hidden dimension (simpler, no H-tiling) or the intermediate dimension (better memory efficiency) for LNC2 parallelism.
Activation Checkpointing: Uses saved activations from the forward pass to avoid recomputation during backward, trading memory for compute.
Blockwise Processing: Processes tokens in blocks matching the forward pass structure, enabling efficient gradient accumulation across experts.
Gradient Clamping: Optional clamping of gradients for numerical stability during training.
Affinity Gradient Computation: Computes gradients for expert routing weights, enabling end-to-end training of the router.
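In effect, the clamping controlled by ClampLimits is an elementwise clip, with one pair of bounds for linear (matmul) gradients and another for non-linear (activation) gradients. Where exactly the kernel applies each pair is internal, so this numpy sketch is only illustrative:

```python
import numpy as np

def clamp_grad(grad, lower, upper):
    # Elementwise clamp to [lower, upper]; passing None disables a bound.
    return np.clip(grad, lower, upper)

# e.g. clamp a matmul gradient with the linear limits
g = np.array([-10.0, -0.5, 0.5, 10.0])
clamped = clamp_grad(g, -1.0, 1.0)
```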