This document is relevant for: Trn2, Trn3
Build All To All V Metadata Kernel API Reference
Builds the metadata buffer for the all_to_all_v collective from MoE routing decisions.
Computes per-rank send counts and displacements from the expert assignments.
Background
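The build_all_to_all_v_metadata kernel prepares the metadata buffer that the all_to_all_v collective consumes. In an all_to_all_v collective, each rank in the replica group sends a variable-sized chunk of its buffer to every rank. For each destination rank, the collective therefore needs to know how much data to send (the send count) and where that chunk starts in the send buffer (the send displacement). The kernel derives both from the expert assignments in the MoE routing decisions.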
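To make the counting concrete, here is a minimal host-side sketch in plain Python of how send counts and displacements could be derived from `expert_index`. It is not the kernel's implementation. It rests on assumptions that this reference does not state: experts are sharded contiguously across ranks (expert `e` lives on rank `e // (E // replica_group_size)`), `E` is divisible by `replica_group_size`, each (token, expert) pair counts as one unit of data, and displacements are the exclusive prefix sum of the counts. The function name `reference_send_metadata` is hypothetical.

```python
from itertools import accumulate


def reference_send_metadata(expert_index, replica_group_size, E):
    """Hypothetical reference for per-rank send counts and displacements.

    expert_index is a [T, K] nested list of expert ids (stand-in for the HBM tensor).
    ASSUMPTIONS (not confirmed by the kernel docs): experts are sharded
    contiguously, E // replica_group_size experts per rank, and each
    (token, expert) pair is one unit of data to send.
    """
    if E % replica_group_size != 0:
        raise ValueError("this sketch assumes E is divisible by replica_group_size")
    experts_per_rank = E // replica_group_size

    send_counts = [0] * replica_group_size
    for token_experts in expert_index:      # one row per token (T rows)
        for expert in token_experts:        # K routed experts per token
            if not 0 <= expert < E:
                raise ValueError(f"expert id {expert} out of range [0, {E})")
            # Count one unit for the rank that (by assumption) hosts this expert.
            send_counts[expert // experts_per_rank] += 1

    # Exclusive prefix sum: offset at which each rank's chunk starts.
    send_displs = [0] + list(accumulate(send_counts))[:-1]
    return send_counts, send_displs


# T=3 tokens, K=2 experts each, E=8 experts over 4 ranks (2 experts per rank).
counts, displs = reference_send_metadata([[0, 5], [2, 3], [7, 1]], replica_group_size=4, E=8)
print(counts)  # [2, 2, 1, 1]
print(displs)  # [0, 2, 4, 5]
```

If the real kernel sends a token only once per destination rank when several of its K experts live on that rank, or counts elements rather than tokens, its counts would differ from this sketch. build_all_to_all_v_metadata.py is the authority on those details.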
API Reference
Source code for this kernel API can be found at: build_all_to_all_v_metadata.py
build_all_to_all_v_metadata
- nkilib.experimental.subkernels.build_all_to_all_v_metadata(expert_index: nl.ndarray, replica_group_size: int, E: int, recv_counts_known: bool = False, has_rdispls: bool = False)
Builds the metadata buffer for the all_to_all_v collective from MoE routing decisions.
- Parameters:
expert_index (nl.ndarray) – [T, K] int32 HBM tensor holding the indices of the K experts each token is routed to.
replica_group_size (int) – Number of ranks in the replica group for the all_to_all_v collective.
E (int) – Number of global experts.
recv_counts_known (bool) – Not currently supported; when True, the metadata includes receive counts.
has_rdispls (bool) – Not currently supported; when True, the metadata includes receive displacements.
- Returns:
[n_rows, replica_group_size] uint32 HBM tensor, where n_rows is 4 when has_rdispls=True and 3 otherwise. Each row holds one value per rank in the replica group:
Row 0: send counts.
Row 1: send displacements.
Row 2: receive counts (all zeros).
Row 3 (only when has_rdispls=True): receive displacements (all zeros).
- Return type:
nl.ndarray
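The documented row layout can be mirrored on the host with a small helper. This sketch assumes the send counts and displacements are already computed (one value per rank) and uses nested Python lists in place of the uint32 HBM tensor. The name `pack_all_to_all_v_metadata` is hypothetical, and the zero-filled receive rows follow the Returns description above.

```python
def pack_all_to_all_v_metadata(send_counts, send_displs, has_rdispls=False):
    """Hypothetical host-side mirror of the documented [n_rows, replica_group_size] layout.

    n_rows is 4 when has_rdispls is True, 3 otherwise; receive rows are zero-filled.
    """
    if len(send_counts) != len(send_displs):
        raise ValueError("send_counts and send_displs need one entry per rank")
    n_ranks = len(send_counts)
    rows = [
        list(send_counts),  # Row 0: send counts
        list(send_displs),  # Row 1: send displacements
        [0] * n_ranks,      # Row 2: receive counts (zeros)
    ]
    if has_rdispls:
        rows.append([0] * n_ranks)  # Row 3: receive displacements (zeros)
    # The device buffer is uint32, so every value must fit in [0, 2**32).
    if any(not 0 <= v < 2**32 for row in rows for v in row):
        raise OverflowError("metadata values must fit in uint32")
    return rows


buf = pack_all_to_all_v_metadata([2, 2, 1, 1], [0, 2, 4, 5], has_rdispls=True)
print(len(buf), len(buf[0]))  # 4 4
```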
Dimensions:
T: Number of tokens.
K: Top-K experts per token.
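Because T and K come from the shape of `expert_index`, a caller can sanity-check inputs before launching the kernel. The helper below is a hypothetical caller-side sketch, not part of nkilib. It uses nested lists in place of the int32 tensor, checks that `expert_index` is rectangular and that every expert id lies in [0, E), warns when either option documented as not currently supported is set, and returns (T, K).

```python
import warnings


def check_build_metadata_args(expert_index, replica_group_size, E,
                              recv_counts_known=False, has_rdispls=False):
    """Hypothetical pre-launch checks (not an nkilib API); returns (T, K)."""
    T = len(expert_index)                       # number of tokens
    K = len(expert_index[0]) if T else 0        # top-K experts per token
    if any(len(row) != K for row in expert_index):
        raise ValueError("expert_index must be rectangular with shape [T, K]")
    if replica_group_size < 1 or E < 1:
        raise ValueError("replica_group_size and E must be positive")
    # Every routed expert must be a valid global expert id.
    if any(not 0 <= e < E for row in expert_index for e in row):
        raise ValueError(f"expert ids must lie in [0, {E})")
    if recv_counts_known or has_rdispls:
        warnings.warn("recv_counts_known and has_rdispls are documented as "
                      "not currently supported", stacklevel=2)
    return T, K


T, K = check_build_metadata_args([[0, 5], [2, 3], [7, 1]], replica_group_size=4, E=8)
print(T, K)  # 3 2
```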