This document is relevant for: Trn1, Trn2, Trn3
Fine-Grained All-Gather Kernel API Reference#
Performs fine-grained ring-based all-gather across ranks for TRN2.
The kernel supports:
Ring-based collective permute with double buffering
Both SBUF and HBM communication paths with automatic selection based on tensor sizes
Overlapped communication and data movement
Background#
The fine_grained_allgather kernel performs all-gather on the input tensor across ranks along the row dimension. It uses ring-based collective permute with double buffering to overlap communication and data movement.
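The ring pattern can be illustrated with a plain NumPy simulation. This is a hypothetical sketch of the algorithm's semantics, not the NKI implementation: over `tp_degree - 1` steps, each rank forwards the chunk it received in the previous step to its right neighbor, so every rank eventually holds all row shards.

```python
import numpy as np

def ring_allgather(shards):
    """Simulate ring-based all-gather over a list of per-rank row shards.

    Illustrative only: the real kernel overlaps these steps with data
    movement and double-buffers the in-flight chunk.
    """
    n = len(shards)
    # out[r][i] will hold rank i's shard once it reaches rank r.
    out = [[None] * n for _ in range(n)]
    for r in range(n):
        out[r][r] = shards[r]
    # buf[r] is the (index, data) chunk that rank r forwards next.
    buf = [(r, shards[r]) for r in range(n)]
    for _ in range(n - 1):            # tp_degree - 1 communication steps
        nxt = [None] * n
        for r in range(n):
            dst = (r + 1) % n         # right neighbor in the ring
            idx, data = buf[r]
            out[dst][idx] = data      # neighbor stores the received chunk
            nxt[dst] = (idx, data)    # and forwards it on the next step
        buf = nxt
    # Concatenate chunks in rank order to form each rank's gathered tensor.
    return [np.concatenate(out[r], axis=0) for r in range(n)]
```

In the real kernel the per-step receive and forward are double-buffered, so the collective permute for step `i + 1` proceeds while the chunk from step `i` is being copied into place.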
API Reference#
Source code for this kernel API can be found at: fg_allgather.py
fine_grained_allgather#
- nkilib.experimental.collectives.fine_grained_allgather(lhs: nl.ndarray, tp_degree: int, num_groups: int, force_hbm_cc: bool = False) → nl.ndarray#
Fine-grained ring-based all-gather kernel for TRN2.
- Parameters:
lhs (nl.ndarray) – [m, K], input tensor, row-sharded across ranks.
tp_degree (int) – Tensor parallelism degree (number of ranks). Must be even. Supported values: 4, 8, 16, 32, 64, 128.
num_groups (int) – Number of replica groups for collective communication.
force_hbm_cc (bool) – If True, force the HBM collective communication path even when the SBUF path is feasible.
- Returns:
[RANK_N, …], Fully gathered tensor in shared HBM. Shape depends on communication path (SBUF vs HBM).
- Return type:
nl.ndarray
Notes:
tp_degree must be even.
M must be divisible by (RANK_N * LNC_N * CHANNEL_N).
Platform target is TRN2 only.
Dimensions:
m: Local rows per rank (before all-gather).
M: Total rows after all-gather (m * tp_degree).
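The shape constraints above can be checked with a small helper. This is a hypothetical sketch: RANK_N, LNC_N, and CHANNEL_N are kernel-internal constants, so they are taken as parameters here rather than hard-coded.

```python
def check_allgather_shapes(m, tp_degree, rank_n, lnc_n, channel_n):
    """Validate the documented fine_grained_allgather constraints.

    m          -- local rows per rank (before all-gather)
    tp_degree  -- number of ranks; must be even
    rank_n, lnc_n, channel_n -- kernel-internal tiling constants
    Returns M, the total row count after all-gather (m * tp_degree).
    """
    if tp_degree % 2 != 0:
        raise ValueError("tp_degree must be even")
    M = m * tp_degree  # total rows after all-gather
    if M % (rank_n * lnc_n * channel_n) != 0:
        raise ValueError("M must be divisible by RANK_N * LNC_N * CHANNEL_N")
    return M
```

For example, with m = 128 and tp_degree = 4 the gathered tensor has M = 512 rows, which must divide evenly by the product of the tiling constants.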