This document is relevant for: Trn1, Trn2, Trn3

Fine-Grained All-Gather Kernel API Reference

Performs fine-grained ring-based all-gather across ranks for TRN2.

The kernel supports:

  • Ring-based collective permute with double buffering

  • Both SBUF and HBM communication paths with automatic selection based on tensor sizes

  • Overlapped communication and data movement

Background

The fine_grained_allgather kernel performs all-gather on the input tensor across ranks along the row dimension. It uses ring-based collective permute with double buffering to overlap communication and data movement.
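To make the ring pattern concrete, the following is a minimal host-side NumPy sketch of a ring all-gather along the row dimension. It is not the kernel's implementation: the function name `ring_allgather_reference` is hypothetical, ranks are simulated as list entries, and it models neither double buffering nor the SBUF/HBM paths. It also assumes a plain row-concatenated result, whereas the kernel's actual output shape depends on the communication path.

```python
import numpy as np


def ring_allgather_reference(shards):
    """Simulate a ring all-gather over a list of per-rank [m, K] shards.

    Hypothetical reference helper: returns one [n * m, K] array per rank.
    """
    n = len(shards)
    m, k = shards[0].shape

    # Each simulated rank starts with only its own shard placed in its output.
    outputs = [np.zeros((n * m, k), dtype=shards[0].dtype) for _ in range(n)]
    for rank in range(n):
        outputs[rank][rank * m:(rank + 1) * m] = shards[rank]

    # in_flight[rank] is the (source rank, block) that `rank` forwards next.
    in_flight = [(rank, shards[rank]) for rank in range(n)]

    # After n - 1 ring steps every rank has seen every block.
    for _ in range(n - 1):
        # Rank r receives whatever its left neighbor (r - 1) is forwarding.
        received = [in_flight[(rank - 1) % n] for rank in range(n)]
        for rank, (src, block) in enumerate(received):
            outputs[rank][src * m:(src + 1) * m] = block
        # The block just received is the one forwarded in the next step.
        in_flight = received

    return outputs


# Four simulated ranks, each holding a [2, 3] shard filled with its rank id.
shards = [np.full((2, 3), rank, dtype=np.float32) for rank in range(4)]
gathered = ring_allgather_reference(shards)
```

Every entry of `gathered` equals `np.concatenate(shards, axis=0)`, which is the property a ring all-gather is expected to satisfy.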

API Reference

Source code for this kernel API can be found at: fg_allgather.py

fine_grained_allgather

nkilib.experimental.collectives.fine_grained_allgather(lhs: nl.ndarray, tp_degree: int, num_groups: int, force_hbm_cc: bool = False) -> nl.ndarray

Fine-grained ring-based all-gather kernel for TRN2.

Parameters:
  • lhs (nl.ndarray) – [m, K], Input tensor, row-sharded across ranks.

  • tp_degree (int) – Tensor parallelism degree (number of ranks). Must be even. Supported values: 4, 8, 16, 32, 64, 128.

  • num_groups (int) – Number of replica groups for collective communication.

  • force_hbm_cc (bool) – If True, force HBM collective communication path even when SBUF path is feasible.

Returns:

[RANK_N, …], Fully gathered tensor in shared HBM. Shape depends on communication path (SBUF vs HBM).

Return type:

nl.ndarray

Notes:

  • tp_degree must be even.

  • M must be divisible by (RANK_N * LNC_N * CHANNEL_N).

  • Platform target is TRN2 only.

Dimensions:

  • m: Local rows per rank (before all-gather).

  • M: Total rows after all-gather (m * tp_degree).

  • K: Number of columns of lhs.
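The constraints listed in the notes can be checked on the host before launching the kernel. The sketch below is one way to do that under stated assumptions: `check_allgather_shapes` is a hypothetical helper, not part of nkilib, and because this page does not define RANK_N, LNC_N, or CHANNEL_N, the caller has to supply them explicitly.

```python
# Supported tensor-parallel degrees as listed in the parameter description.
SUPPORTED_TP_DEGREES = (4, 8, 16, 32, 64, 128)


def check_allgather_shapes(m, tp_degree, rank_n, lnc_n, channel_n):
    """Validate the documented constraints and return M = m * tp_degree.

    Hypothetical helper: rank_n, lnc_n, and channel_n stand in for the
    RANK_N, LNC_N, and CHANNEL_N values, which this page does not define.
    """
    if tp_degree not in SUPPORTED_TP_DEGREES:
        raise ValueError(
            f"tp_degree={tp_degree} is not one of {SUPPORTED_TP_DEGREES}"
        )

    total_rows = m * tp_degree  # M: total rows after all-gather
    divisor = rank_n * lnc_n * channel_n
    if total_rows % divisor != 0:
        raise ValueError(
            f"M={total_rows} must be divisible by "
            f"RANK_N * LNC_N * CHANNEL_N = {divisor}"
        )
    return total_rows


# Example with illustrative values only: m = 64 rows per rank, 4 ranks.
total_rows = check_allgather_shapes(m=64, tp_degree=4, rank_n=4, lnc_n=2, channel_n=2)
```

All supported `tp_degree` values are even, so the membership check also covers the "must be even" requirement.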
