Router Top-K Kernel API Reference#
Computes router logits, applies activation functions, and performs top-K selection with expert affinity scattering for Mixture of Experts (MoE) models.
The kernel supports:
Router logits computation (x @ w + bias)
Activation functions (SOFTMAX, SIGMOID)
Top-K expert selection (K ≤ 8)
Expert affinity scattering (one-hot or indirect DMA)
Multiple layout configurations and optimization modes
Column tiling for small token counts
LNC sharding across token dimension
Pre-norm and post-norm activation pipelines
L1 normalization of top-K probabilities
Background#
The Router Top-K kernel is a core component of Mixture of Experts (MoE) models, responsible for routing tokens to the most relevant experts. The kernel computes router logits by multiplying input tokens with a weight matrix, applies activation functions, selects the top-K experts for each token, and scatters the expert affinities to the full expert dimension.
The kernel is optimized for token counts T ≤ 2048, expert counts E ≤ 512, hidden dimensions H that are multiples of 128, and K ≤ 8 top experts per token. It supports both context encoding (CTE) with larger T and token generation (TKG) with T ≤ 128.
Pipeline Configurations:
The kernel supports multiple pipeline configurations:
(topK, ACT2, Scatter): Standard pipeline with post-topK activation
(ACT1, topK): Pre-norm activation before topK selection
(ACT1, topK, Norm, Scatter): Pre-norm with L1 normalization and scatter
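The pre-norm pipeline above can be modeled numerically with a short NumPy sketch. This is an illustrative reference of the math (logits, softmax, top-K, L1 normalization, scatter), not the kernel itself; the function name and the use of argsort are choices made here for clarity.

```python
import numpy as np

def router_topk_reference(x, w, w_bias, k, norm_topk_prob=False):
    """Reference model of the (ACT1, topK, Norm, Scatter) pipeline.

    x: [T, H] tokens, w: [H, E] router weights, w_bias: [E] bias.
    Returns router_logits [T, E], expert_index [T, K], and
    expert_affinities [T, E] (top-K probabilities scattered to E).
    """
    router_logits = x @ w + w_bias                      # [T, E]
    # ACT1: softmax over the expert dimension, applied before top-K
    z = router_logits - router_logits.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    # Top-K selection (the kernel requires K <= 8)
    expert_index = np.argsort(-probs, axis=-1)[:, :k]   # [T, K]
    topk_vals = np.take_along_axis(probs, expert_index, axis=-1)
    if norm_topk_prob:
        # Optional L1 normalization of the selected probabilities
        topk_vals = topk_vals / topk_vals.sum(axis=-1, keepdims=True)
    # Scatter affinities back to the full expert dimension
    expert_affinities = np.zeros_like(probs)
    np.put_along_axis(expert_affinities, expert_index, topk_vals, axis=-1)
    return router_logits, expert_index, expert_affinities
```

With norm_topk_prob=True, each row of expert_affinities has exactly K nonzero entries that sum to 1.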
API Reference#
Source code for this kernel API can be found at: router_topk.py
router_topk#
- nkilib.core.router_topk.router_topk(x, w, w_bias, router_logits, expert_affinities, expert_index, act_fn, k, x_hbm_layout, x_sb_layout, output_in_sbuf=False, router_pre_norm=True, norm_topk_prob=False, use_column_tiling=False, use_indirect_dma_scatter=False, return_eager_affi=False, use_PE_broadcast_w_bias=False, shard_on_tokens=False, skip_store_expert_index=False, skip_store_router_logits=False, x_input_in_sbuf=False, expert_affin_in_sb=False)#
Router top-K kernel for Mixture of Experts (MoE) models.
Computes router logits (x @ w + bias), applies activation functions, performs top-K selection, and scatters expert affinities. Supports multiple layout configurations, sharding strategies, and optimization modes.
- Parameters:
  - x (nl.ndarray) – Input tensor. Shape depends on x_hbm_layout and x_input_in_sbuf. If in HBM: [H, T] or [T, H]. If in SBUF: a permutation of [128, T, H/128].
  - w (nl.ndarray) – Weight tensor with shape [H, E] in HBM.
  - w_bias (nl.ndarray) – Optional bias tensor with shape [1, E] or [E] in HBM.
  - router_logits (nt.mutable_tensor) – Output router logits with shape [T, E] in HBM.
  - expert_affinities (nt.mutable_tensor) – Output expert affinities with shape [T, E] in HBM or SBUF.
  - expert_index (nt.mutable_tensor) – Output expert indices with shape [T, K] in HBM or SBUF.
  - act_fn (common_types.RouterActFnType) – Activation function (SOFTMAX or SIGMOID).
  - k (int) – Number of top experts to select (must be ≤ 8).
  - x_hbm_layout (int) – Layout of x in HBM (0=[H,T], 1=[T,H]).
  - x_sb_layout (int) – Layout of x in SBUF (0-3; see notes for details).
  - output_in_sbuf (bool, optional) – If True, outputs are in SBUF (requires T ≤ 128). Default is False.
  - router_pre_norm (bool, optional) – If True, apply activation before top-K (ACT1 pipeline). Default is True.
  - norm_topk_prob (bool, optional) – If True, normalize top-K probabilities with an L1 norm. Default is False.
  - use_column_tiling (bool, optional) – Enable PE array column tiling for small T. Default is False.
  - use_indirect_dma_scatter (bool, optional) – Use indirect DMA for the expert affinity scatter. Default is False.
  - return_eager_affi (bool, optional) – If True, return the top-K affinities in addition to the scattered affinities. Default is False.
  - use_PE_broadcast_w_bias (bool, optional) – Use the tensor engine for the bias broadcast. Default is False.
  - shard_on_tokens (bool, optional) – Enable LNC sharding across the token dimension. Default is False.
  - skip_store_expert_index (bool, optional) – Skip storing expert indices to HBM. Default is False.
  - skip_store_router_logits (bool, optional) – Skip storing router logits to HBM. Default is False.
  - x_input_in_sbuf (bool, optional) – If True, x is already in SBUF. Default is False.
  - expert_affin_in_sb (bool, optional) – If True, the expert affinities output is in SBUF. Default is False.
- Returns: List of [router_logits, expert_index, expert_affinities, optional: expert_affinities_topk]
- Return type: list
Dimensions:
T: Total number of tokens
H: Hidden dimension size
E: Number of experts
K: Number of top experts to select per token
Constraints:
K must be ≤ 8
E must be ≤ 512 (gemm_moving_fmax)
H must be a multiple of 128
SIGMOID activation requires use_indirect_dma_scatter=True
router_pre_norm requires use_indirect_dma_scatter=True
With use_indirect_dma_scatter, T must be ≤ 128 or a multiple of 128
shard_on_tokens requires n_prgs > 1 and T divisible by 2
output_in_sbuf requires T ≤ 128
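The constraints above can be collected into a hypothetical pre-flight check. This helper is not part of the kernel API; it simply encodes the documented limits so a caller can fail fast before dispatch.

```python
def check_router_topk_constraints(t, h, e, k, act_fn="SOFTMAX",
                                  use_indirect_dma_scatter=False,
                                  output_in_sbuf=False):
    """Hypothetical validation of the documented router_topk constraints."""
    assert k <= 8, "K must be <= 8"
    assert e <= 512, "E must be <= 512 (gemm_moving_fmax)"
    assert h % 128 == 0, "H must be a multiple of 128"
    if act_fn == "SIGMOID":
        assert use_indirect_dma_scatter, \
            "SIGMOID activation requires use_indirect_dma_scatter=True"
    if use_indirect_dma_scatter:
        assert t <= 128 or t % 128 == 0, \
            "With use_indirect_dma_scatter, T must be <= 128 or a multiple of 128"
    if output_in_sbuf:
        assert t <= 128, "output_in_sbuf requires T <= 128"
    return True
```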
SBUF Layout Options (x_sb_layout):
0: [128, T, H/128] - P-dim contains H elements with a stride of H/128
1: [128, T, H/128] - P-dim with H/256 chunk interleaving
2: [128, T, H/128] - P-dim contains consecutive H elements
3: [128, H/128, T] - H-tiles in dim-1, T in dim-2
router_topk_input_x_load#
- nkilib.core.router_topk.router_topk_input_x_load(x, hbm_layout=0, sb_layout=1)#
Load input tensor x from HBM to SBUF with specified layout transformations.
Performs DMA transfer from HBM to SBUF with layout conversion based on hbm_layout and sb_layout parameters. Supports multiple layout combinations optimized for different access patterns in subsequent matmul operations.
- Parameters:
  - x (nl.ndarray) – Input tensor in HBM. Shape [H, T] if hbm_layout=0, [T, H] if hbm_layout=1.
  - hbm_layout (int, optional) – Layout of x in HBM (0=[H,T], 1=[T,H]). Default is 0.
  - sb_layout (int, optional) – Target layout in SBUF (0-3). Default is 1.
- Returns:
Input tensor in SBUF with transformed layout
- Return type:
nl.ndarray
Constraints:
H must be a multiple of 128
Supported combinations: (hbm_layout=0, sb_layout=3) and (hbm_layout=1, sb_layout=0/1/2)
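As a sketch of the (hbm_layout=0, sb_layout=3) combination, the transform from [H, T] to [128, H/128, T] can be modeled with a reshape and transpose. The exact tile ordering is an assumption made here for illustration; consult the kernel source for the authoritative layout.

```python
import numpy as np

def x_load_layout3_reference(x_hbm):
    """Reference of (hbm_layout=0, sb_layout=3): [H, T] -> [128, H/128, T].

    Assumes each of the H/128 tiles holds 128 consecutive H elements,
    with the 128 elements on the partition dimension.
    """
    h, t = x_hbm.shape
    assert h % 128 == 0, "H must be a multiple of 128"
    # [H, T] -> [H/128, 128, T] -> [128, H/128, T]
    return x_hbm.reshape(h // 128, 128, t).transpose(1, 0, 2)
```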
router_topk_input_w_load#
- nkilib.core.router_topk.router_topk_input_w_load(w, x_sb_layout, name='')#
Load weight tensor w from HBM to SBUF with layout matching x tensor.
- Parameters:
  - w (nl.ndarray) – Weight tensor with shape [H, E] in HBM.
  - x_sb_layout (int) – Layout of x in SBUF (determines the w layout).
  - name (str, optional) – Optional name for the tensor. Default is an empty string.
- Returns:
Weight tensor in SBUF with appropriate layout
- Return type:
nl.ndarray
Implementation Details#
The kernel implementation includes several key optimizations:
Tiled Matrix Multiplication: Tiles computation on both H (contraction dimension) and T (token dimension) for efficient memory access and hardware utilization.
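The tiling over H and T can be sketched as a pair of loops accumulating partial products. This is a numerical model of the loop structure only; on hardware the tiles map to 128-wide PE array operations.

```python
import numpy as np

def tiled_router_matmul(x, w, tile=128):
    """Illustrative tiling of x @ w over T (tokens) and H (contraction).

    x: [T, H], w: [H, E]. Partial tiles at the edges are handled by
    NumPy slice clipping.
    """
    t_dim, h_dim = x.shape
    e_dim = w.shape[1]
    out = np.zeros((t_dim, e_dim), dtype=x.dtype)
    for t0 in range(0, t_dim, tile):          # token tiles
        for h0 in range(0, h_dim, tile):      # contraction tiles, accumulated
            out[t0:t0 + tile] += x[t0:t0 + tile, h0:h0 + tile] @ w[h0:h0 + tile]
    return out
```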
PE Array Column Tiling: For small token counts (T < 128), splits the PE array column-wise into multiple tiles (32, 64, or 128 columns) to enable parallel execution of independent matmuls.
LNC Sharding: Supports parallelization across 2 cores by sharding the token dimension. Each core processes T/2 tokens with automatic load balancing for non-divisible token counts.
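One way to picture the token sharding is as contiguous slices whose sizes differ by at most one token; the helper below is a hypothetical model of that load balancing, not the kernel's actual shard assignment.

```python
def shard_tokens(t_total, n_prgs=2):
    """Hypothetical token-dimension sharding: each core gets a
    contiguous (start, size) slice; remainder tokens go to the
    first cores so loads differ by at most one token."""
    base, rem = divmod(t_total, n_prgs)
    shards, start = [], 0
    for core in range(n_prgs):
        size = base + (1 if core < rem else 0)
        shards.append((start, size))
        start += size
    return shards
```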
Bias Broadcasting: Supports two methods for bias application:
Stream shuffle broadcast (default)
Tensor engine matmul with a ones mask (use_PE_broadcast_w_bias=True)
Top-K Selection: Uses hardware-accelerated max8 and nc_find_index8 instructions to efficiently find the top-8 values and their indices.
Expert Affinity Scattering: Supports two scattering methods:
One-hot scatter: Uses mask-based selection with element-wise operations
Indirect DMA scatter: Uses dynamic indexing for efficient scatter to HBM
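The one-hot method can be sketched in NumPy as a mask build followed by element-wise combine; this mirrors the mask-based approach described above in spirit, though the hardware implementation differs.

```python
import numpy as np

def one_hot_scatter(expert_index, topk_vals, num_experts):
    """Mask-based (one-hot) scatter of top-K affinities to [T, E].

    expert_index: [T, K] selected expert ids; topk_vals: [T, K]
    affinities. Builds a one-hot mask per selected expert and
    combines it element-wise with the affinity values.
    """
    t, k = expert_index.shape
    out = np.zeros((t, num_experts), dtype=topk_vals.dtype)
    for j in range(k):
        # [T, E] boolean mask: True where expert e matches selection j
        mask = np.arange(num_experts)[None, :] == expert_index[:, j:j + 1]
        out += mask * topk_vals[:, j:j + 1]
    return out
```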
Activation Pipelines: Supports multiple activation pipeline configurations (ACT1, ACT2) with optional L1 normalization.
Memory Management: Carefully manages SBUF allocations with modular allocation and buffer reuse for intermediate tensors.
See Also#
Router Top-K PyTorch Reference