This document is relevant for: Trn2, Trn3
Permute Routed Tokens Kernel API Reference
Sort tokens by expert and pack hidden states, affinities, and token indices into a [T*K, n_output_cols] buffer.
Background
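The permute_routed_tokens kernel sorts tokens by their assigned expert and packs hidden states, affinities, and token indices into a contiguous [T*K, n_output_cols] buffer for efficient mixture-of-experts (MoE) dispatch. Each token is routed to K experts, so it appears K times in the output (once per token/expert pair), which gives T*K rows.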
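To make the sort-and-pack semantics concrete, here is a minimal pure-Python reference sketch of the logical result. It is not the kernel's implementation. The function name permute_routed_tokens_reference is hypothetical, and two behaviors are assumptions not stated on this page: the affinity stored in each row is expert_affinities_masked[t][e] for the routed expert e, and rows routed to the same expert keep their original token order (a stable sort). The sketch also keeps affinity and token index as plain numbers instead of bit-packing them into hidden_input's dtype.

```python
def permute_routed_tokens_reference(hidden_input, expert_index, expert_affinities_masked):
    """Return T*K rows of [*hidden_state, affinity, token_index] sorted by expert.

    Hypothetical reference only. Assumes ties are broken by token order and that
    the affinity is looked up at the routed expert's column.
    """
    T = len(expert_index)
    K = len(expert_index[0]) if T else 0
    # One (expert, token, k) entry per routed token/expert pair, in token order.
    routed = [(expert_index[t][k], t, k) for t in range(T) for k in range(K)]
    # list.sort is stable, so entries for the same expert keep token order.
    routed.sort(key=lambda entry: entry[0])
    rows = []
    for expert, t, _k in routed:
        affinity = expert_affinities_masked[t][expert]
        rows.append(list(hidden_input[t]) + [affinity, t])
    return rows

# Toy example: T=3 tokens, H=2, K=2, E=3 experts.
hidden = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
experts = [[2, 0], [1, 2], [0, 1]]
affinities = [
    [0.7, 0.0, 0.3],  # token 0 routed to experts 2 and 0
    [0.0, 0.6, 0.4],  # token 1 routed to experts 1 and 2
    [0.5, 0.5, 0.0],  # token 2 routed to experts 0 and 1
]
rows = permute_routed_tokens_reference(hidden, experts, affinities)
print([row[-1] for row in rows])  # [0, 2, 1, 2, 0, 1]
```

Grouping rows by expert this way lets each expert's tokens be processed as one contiguous slice of the buffer.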
API Reference
Source code for this kernel API can be found at: permute_routed_tokens.py
permute_routed_tokens
- nkilib.experimental.subkernels.permute_routed_tokens(hidden_input: nl.ndarray, expert_index: nl.ndarray, expert_affinities_masked: nl.ndarray)
Sort tokens by expert and pack hidden states, affinities, and token indices into a [T*K, n_output_cols] buffer.
- Parameters:
hidden_input (nl.ndarray) – [T, n_input_cols], bf16 or fp8 HBM tensor of hidden states. When hidden states are fp8, each row also contains packed quantization scales.
expert_index (nl.ndarray) – [T, K], int32 HBM tensor of top-K expert indices per token.
expert_affinities_masked (nl.ndarray) – [T, E], bf16 HBM tensor of expert affinities, with zeros for non-routed token/expert pairs.
- Returns:
[T*K, n_output_cols] bf16 or fp8 HBM tensor (same dtype as hidden_input) in which each row is [hidden_state, affinity, token_index], with rows sorted by expert index.
- Return type:
nl.ndarray
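The following sketch shows what one packed output row could look like at the bit level in the bf16 case, treating each column as a 16-bit word. It is an illustration under assumptions, not the kernel's code: the helper names are hypothetical, the int32 token index is assumed to be split into two 16-bit columns in little-endian word order, and float-to-bf16 conversion here truncates rather than rounds.

```python
import struct

def float_to_bf16_bits(x):
    # bf16 is the upper 16 bits of an IEEE-754 float32; this truncates
    # rather than rounds, which is exact for the values used below.
    (bits32,) = struct.unpack("<I", struct.pack("<f", x))
    return bits32 >> 16

def bf16_bits_to_float(bits):
    (value,) = struct.unpack("<f", struct.pack("<I", bits << 16))
    return value

def pack_row_bf16(hidden_bits, affinity, token_index):
    """Append [affinity, token_index] to a row of bf16 bit patterns.

    The affinity fills one bf16 column; the int32 token index is viewed as
    two 16-bit columns (little-endian word order assumed).
    """
    low, high = struct.unpack("<2H", struct.pack("<i", token_index))
    return list(hidden_bits) + [float_to_bf16_bits(affinity), low, high]

def unpack_row_bf16(row, n_input_cols):
    """Split a packed row back into (hidden_bits, affinity, token_index)."""
    hidden_bits = row[:n_input_cols]
    affinity = bf16_bits_to_float(row[n_input_cols])
    (token_index,) = struct.unpack(
        "<i", struct.pack("<2H", row[n_input_cols + 1], row[n_input_cols + 2])
    )
    return hidden_bits, affinity, token_index

hidden = [float_to_bf16_bits(v) for v in (1.5, -2.0)]  # n_input_cols = 2
row = pack_row_bf16(hidden, 0.75, 37)
print(len(row), unpack_row_bf16(row, 2))  # 5 ([16320, 49152], 0.75, 37)
```

A consumer of the buffer would slice off the first n_input_cols columns as the hidden state and reinterpret the trailing columns to recover the affinity and token index.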
Notes:
Requires T*K ≤ 128 (pmax, the maximum partition dimension) and K ∈ {1, 2, 4, 8}.
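A caller can validate these limits on the host before launching the kernel. The helper below is a hypothetical sketch, not part of nkilib; it encodes only the two constraints stated above and reports the largest token count allowed for each K.

```python
PMAX = 128                # value of pmax given in the notes
ALLOWED_K = (1, 2, 4, 8)  # supported top-K values

def check_permute_constraints(T, K):
    """Raise ValueError if (T, K) violates the documented limits; else return T*K."""
    if K not in ALLOWED_K:
        raise ValueError(f"K must be one of {ALLOWED_K}, got {K}")
    if T * K > PMAX:
        raise ValueError(f"T*K must be <= {PMAX}, got {T * K}")
    return T * K  # number of rows in the permuted output

def max_tokens(K):
    """Largest T accepted for a given K."""
    check_permute_constraints(1, K)
    return PMAX // K

print([max_tokens(k) for k in ALLOWED_K])  # [128, 64, 32, 16]
```

With K=8, for example, at most 16 tokens can be permuted in a single call.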
Dimensions:
T: Number of tokens.
H: Hidden dimension size.
n_input_cols: Number of columns in hidden_input. When hidden_input is bf16, n_input_cols = H. When hidden_input is fp8, n_input_cols includes the H hidden columns and may include additional columns for quantization scales.
n_concat_cols: Number of columns occupied by the affinity (bf16) and token index (int32) when reinterpreted in the hidden_input dtype.
n_output_cols: n_input_cols + n_concat_cols
K: Top-K experts per token.
E: Number of experts.
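The n_concat_cols arithmetic follows from dtype widths. This sketch assumes that "viewed as hidden_input dtype" means a plain bitcast with no padding or alignment, so a 2-byte bf16 affinity plus a 4-byte int32 index occupy 6 bytes; the function names are hypothetical.

```python
DTYPE_BYTES = {"bf16": 2, "fp8": 1, "int32": 4}

def n_concat_cols(hidden_dtype):
    """Columns that one bf16 affinity plus one int32 token index occupy
    when reinterpreted in units of hidden_dtype (no padding assumed)."""
    concat_bytes = DTYPE_BYTES["bf16"] + DTYPE_BYTES["int32"]  # 6 bytes
    width = DTYPE_BYTES[hidden_dtype]
    if concat_bytes % width:
        raise ValueError(f"{hidden_dtype} does not evenly divide {concat_bytes} bytes")
    return concat_bytes // width

def n_output_cols(n_input_cols, hidden_dtype):
    return n_input_cols + n_concat_cols(hidden_dtype)

print(n_concat_cols("bf16"), n_concat_cols("fp8"))  # 3 6
print(n_output_cols(4096, "bf16"))                  # 4099
```

In the bf16 example, H = 4096 so n_input_cols = 4096; for fp8 inputs, n_input_cols would also count any packed scale columns.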