This document is relevant for: Trn1, Trn2, Trn3

FGCC (All-Gather + Matmul) Kernel API Reference#

Performs fused all-gather and matrix multiplication (FGCC) for TRN2.

The kernel supports:

  • All-gather on left-hand side tensor across ranks

  • Matrix multiplication with column-sharded right-hand side tensor

  • Ring-based collective permute overlapped with compute

  • Both SBUF and HBM communication paths with automatic selection

Background#

The allgather_compute_matmul kernel performs all-gather on the left-hand side tensor across ranks, then computes matrix multiplication with a column-sharded right-hand side tensor. Communication is overlapped with compute using ring-based collective permute.

API Reference#

Source code for this kernel API can be found at: fgcc.py

allgather_compute_matmul#

nkilib.experimental.collectives.allgather_compute_matmul(lhs: nl.ndarray, rhs: nl.ndarray, tp_degree: int, num_groups: int, force_hbm_cc: bool = False) nl.ndarray#

Fine grained all-gather and matrix multiplication (FGCC) kernel for TRN2.

Parameters:
  • lhs (nl.ndarray) – [m, K], Left-hand side tensor, row-sharded across ranks.

  • rhs (nl.ndarray) – [K, N], Right-hand side tensor, column-sharded per rank.

  • tp_degree (int) – Tensor parallelism degree (number of ranks). Must be even.

  • num_groups (int) – Number of replica groups for collective communication.

  • force_hbm_cc (bool) – If True, force HBM collective communication path even when SBUF path is feasible.

Returns:

[RANK_N, …], Column-sharded result tensor in shared HBM. Shape depends on communication path (SBUF vs HBM).

Return type:

nl.ndarray

Notes:

  • tp_degree must be even.

  • lhs and rhs must have matching K dimension.

  • M must be divisible by (RANK_N * LNC_N * CHANNEL_N).

  • Platform target is TRN2 only.

Dimensions:

  • m: Local rows per rank (before all-gather).

  • M: Total rows after all-gather (m * tp_degree).

  • K: Shared (contraction) dimension.

This document is relevant for: Trn1, Trn2, Trn3