This document is relevant for: Trn1, Trn2, Trn3
Top-K Reduce Kernel API Reference
Computes MoE Top-K reduction across sparse all_to_all_v() collective output buffer.
The kernel supports:
- Gathering scattered rows by packed global token index
- Reduction along the K dimension
- LNC sharding on the H dimension
Background
The topk_reduce kernel gathers scattered rows by packed global token index and reduces along the K dimension. It is used to recombine expert outputs after an all_to_all_v() collective in Mixture of Experts models.
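The gather-and-reduce semantics can be sketched as a minimal NumPy reference model. This is an illustration, not the kernel implementation: it uses float16 in place of bf16, and it assumes that padded rows carry an out-of-range token index so they are skipped (the kernel's actual padding convention may differ).

```python
import numpy as np

def topk_reduce_reference(inp: np.ndarray, T: int, K: int) -> np.ndarray:
    """Hypothetical NumPy model of topk_reduce semantics.

    inp: [TK_padded, H + 2]; the last two 16-bit columns of each row
    hold one packed int32 global token index. Each real token appears
    in K rows, which are summed to produce its output row.
    """
    H = inp.shape[1] - 2
    # Reinterpret the trailing two 16-bit columns of each row as an int32 index.
    idx = np.ascontiguousarray(inp[:, H:]).view(np.int32).ravel()
    out = np.zeros((T, H), dtype=np.float32)  # accumulate in fp32
    for row, token in enumerate(idx):
        if 0 <= token < T:  # assumption: padded rows use an out-of-range index
            out[token] += inp[row, :H].astype(np.float32)
    return out.astype(inp.dtype)
```

The gather (indexing by `idx`) and the K-wise reduction (repeated accumulation into `out[token]`) happen in one pass, which mirrors how the kernel recombines expert outputs without first sorting the sparse buffer.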
API Reference
Source code for this kernel API can be found at: topk_reduce.py
topk_reduce
- nkilib.experimental.subkernels.topk_reduce(input: nl.ndarray, T: int, K: int)
Compute MoE Top-K reduction across sparse all_to_all_v() collective output buffer.
- Parameters:
input (nl.ndarray) – [TK_padded, H + 2]@HBM, bf16/fp16. Sparse input buffer containing T*K scattered outputs. The global token index is packed as an int32 in the final 2 columns of each row.

T (int) – Total number of input tokens.

K (int) – Number of routed experts per token.
- Returns:
[T, H]@HBM, bf16/fp16. Ordered and reduced output.
- Return type:
nl.ndarray
Dimensions:
TK_padded: n_src_ranks * T, padded input row count
H: Hidden dimension size (must be divisible by LNC)
T: Total number of input tokens (up to 128)
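The packed-index convention (one int32 occupying the final two 16-bit columns of each row) can be illustrated in NumPy. This uses float16 to stand in for the bf16/fp16 columns and is a byte-layout illustration under that assumption, not the kernel's formal specification:

```python
import numpy as np

# One int32 global token index is four bytes, i.e. exactly the
# final two 16-bit (bf16/fp16) columns of an [H + 2]-wide row.
token_index = np.array([42], dtype=np.int32)
packed_cols = token_index.view(np.float16)   # shape (2,): the two extra columns
restored = packed_cols.view(np.int32)[0]     # byte-exact round trip
```

Because the reinterpretation is purely byte-level, no floating-point arithmetic ever touches the packed columns, so the index round-trips exactly.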