This document is relevant for: Trn2, Trn3

Build All To All V Metadata Kernel API Reference

Build metadata buffer for all_to_all_v collective from MoE routing decisions.

Computes per-rank send counts and displacements from expert assignments.

Background

The build_all_to_all_v_metadata kernel builds the metadata buffer required by the all_to_all_v collective operation from Mixture-of-Experts (MoE) routing decisions. From the experts assigned to each token, it computes the send count and send displacement for each rank in the replica group.
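The computation can be illustrated with a host-side NumPy sketch. This is not the kernel's implementation, and it rests on assumptions the reference does not state: experts are assumed to be split into contiguous, equal-sized blocks across ranks, every (token, expert) assignment is assumed to count as one sent item, and displacements are assumed to be exclusive prefix sums of the send counts, measured in items. The function name reference_all_to_all_v_metadata is hypothetical. The row layout follows the Returns description in the API reference.

```python
import numpy as np


def reference_all_to_all_v_metadata(expert_index, replica_group_size, E, has_rdispls=False):
    """Host-side NumPy sketch of the metadata layout (hypothetical helper, not the NKI kernel)."""
    expert_index = np.asarray(expert_index, dtype=np.int32)  # shape [T, K]

    # ASSUMPTION: experts are partitioned into contiguous, equal blocks per rank.
    assert E % replica_group_size == 0, "sketch assumes E divides evenly across ranks"
    experts_per_rank = E // replica_group_size
    dest_rank = expert_index // experts_per_rank  # shape [T, K]

    # ASSUMPTION: each (token, expert) assignment is one item sent to the owning rank.
    send_counts = np.bincount(dest_rank.ravel(), minlength=replica_group_size)

    # ASSUMPTION: displacements are the exclusive prefix sum of the send counts.
    send_displs = np.cumsum(send_counts) - send_counts

    # Rows 2 and 3 (recv counts / recv displacements) stay zero, as in the documented layout.
    n_rows = 4 if has_rdispls else 3
    metadata = np.zeros((n_rows, replica_group_size), dtype=np.uint32)
    metadata[0] = send_counts
    metadata[1] = send_displs
    return metadata


# T = 4 tokens, K = 2 experts per token, E = 4 experts, 2 ranks (experts 0-1 and 2-3).
expert_index = np.array([[0, 3], [1, 2], [3, 0], [2, 3]], dtype=np.int32)
metadata = reference_all_to_all_v_metadata(expert_index, replica_group_size=2, E=4)
print(metadata)
# [[3 5]
#  [0 3]
#  [0 0]]
```

In this example three assignments target experts on rank 0 and five target experts on rank 1, so under the stated assumptions the items for rank 1 start at offset 3 in the send buffer.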

API Reference

Source code for this kernel API can be found at: build_all_to_all_v_metadata.py

build_all_to_all_v_metadata

nkilib.experimental.subkernels.build_all_to_all_v_metadata(expert_index: nl.ndarray, replica_group_size: int, E: int, recv_counts_known: bool = False, has_rdispls: bool = False)

Build metadata buffer for all_to_all_v collective from MoE routing decisions.

Parameters:
  • expert_index (nl.ndarray) – [T, K] int32 HBM tensor indicating the K experts each token is routed to.

  • replica_group_size (int) – Size of replica group for all_to_all_v collective.

  • E (int) – Number of global experts.

  • recv_counts_known (bool) – Not currently supported. When True, the metadata includes recv counts. Defaults to False.

  • has_rdispls (bool) – Not currently supported. When True, the metadata includes recv displacements. Defaults to False.

Returns:

[n_rows, replica_group_size] uint32 HBM tensor, where n_rows is 4 when has_rdispls=True and 3 otherwise. The rows are:

  • Row 0: send counts.

  • Row 1: send displacements.

  • Row 2: recv counts (zeros).

  • Row 3 (only when has_rdispls=True): recv displacements (zeros).

Return type:

nl.ndarray

Dimensions:

  • T: Number of tokens.

  • K: Top-K experts per token.
