This document is relevant for: Trn1, Trn2, Trn3

Top-K Reduce Kernel API Reference#

Computes the MoE Top-K reduction across a sparse all_to_all_v() collective output buffer.

The kernel supports:

  • Gathering scattered rows by packed global token index

  • Reduction along the K dimension

  • LNC sharding on the H dimension

Background#

The topk_reduce kernel gathers scattered rows by packed global token index and reduces along the K dimension. It is used to recombine expert outputs after an all_to_all_v() collective in Mixture-of-Experts (MoE) models.
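The gather-and-reduce semantics can be sketched with a NumPy reference (a hypothetical illustration only, not the NKI implementation; the function name, the sum reduction, and the use of negative indices to mark padding rows are all assumptions):

```python
import numpy as np

def gather_reduce(rows: np.ndarray, token_idx: np.ndarray, T: int) -> np.ndarray:
    """Reference sketch of the kernel's semantics (assumptions noted above).

    rows:      [TK_padded, H] row values from the all_to_all_v() output.
    token_idx: [TK_padded] int32 global token index per row
               (assumed negative for padding rows).
    Returns a dense [T, H] output where the K expert contributions
    belonging to each token have been summed.
    """
    H = rows.shape[1]
    out = np.zeros((T, H), dtype=np.float32)
    # Keep only rows whose token index addresses a real output token.
    valid = (token_idx >= 0) & (token_idx < T)
    # Scatter-add each valid row into its token's output slot.
    np.add.at(out, token_idx[valid], rows[valid].astype(np.float32))
    return out
```

Accumulating in fp32 before casting back mirrors the usual numerical practice for bf16/fp16 reductions, though the kernel's actual accumulation precision is not specified here.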

API Reference#

Source code for this kernel API can be found at: topk_reduce.py

topk_reduce#

nkilib.experimental.subkernels.topk_reduce(input: nl.ndarray, T: int, K: int)#

Compute the MoE Top-K reduction across a sparse all_to_all_v() collective output buffer.

Parameters:
  • input (nl.ndarray) – [TK_padded, H + 2]@HBM, bf16/fp16. Sparse input buffer containing T*K scattered outputs. The global token index is packed as an int32 in the final two columns of each row.

  • T (int) – Total number of input tokens.

  • K (int) – Number of routed experts per token.

Returns:

[T, H]@HBM, bf16/fp16. Ordered and reduced output.

Return type:

nl.ndarray

Dimensions:

  • TK_padded: n_src_ranks * T, padded input row count

  • H: Hidden dimension size (must be divisible by LNC)

  • T: Total number of input tokens (up to 128)
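To make the layout above concrete, the following sketch bit-packs an int32 global token index into two 16-bit columns and recovers it, as the `[TK_padded, H + 2]` input buffer implies. The little-endian half ordering is an assumption; the kernel's actual byte layout may differ:

```python
import numpy as np

def pack_index(tok: int) -> np.ndarray:
    """Split an int32 token index into two 16-bit halves and
    bit-reinterpret them as two fp16 columns (assumed layout:
    low half first)."""
    halves = np.array([tok & 0xFFFF, (tok >> 16) & 0xFFFF], dtype=np.uint16)
    return halves.view(np.float16)

def unpack_index(cols: np.ndarray) -> int:
    """Recover the int32 token index from the final two fp16
    columns of a row."""
    halves = cols.copy().view(np.uint16).astype(np.uint32)
    combined = halves[0] | (halves[1] << np.uint32(16))
    # Reinterpret the 32-bit pattern as signed, matching the
    # int32 index described in the parameter list.
    return int(combined.astype(np.int32))
```

The round trip also preserves negative values (e.g. a sentinel for padding rows, if the kernel uses one), since only bit patterns are moved.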
