This document is relevant for: Trn1, Trn2, Trn3

Transformer TKG Kernel API Reference#

Implements the transformer token generation forward pass as a single megakernel.

The kernel supports:

  • Configurable number of transformer layers

  • Per-layer attention block (RMSNorm + QKV + RoPE + Attention + Output Projection)

  • Per-layer MLP block (RMSNorm + Gate/Up + Activation + Down Projection)

  • All-reduce collective communication after the attention and MLP blocks of each layer

  • Residual connections

  • Optional FP8 quantization with per-layer weight scales

  • SBUF residual path with SB2SB all-reduce

Background#

The transformer_tkg kernel executes multiple transformer layers in a single kernel invocation for token generation. Within each layer it runs the attention block, an all-reduce, the MLP block, and a second all-reduce, with residual connections around both blocks. Fusing the layers into one kernel reduces launch overhead and enables cross-layer optimizations.
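
Conceptually, each layer follows the schedule sketched below. This is a simplified, single-device NumPy reference of the flow described above, not the kernel's implementation: RoPE, the KV cache, attention masking, the GQA head layout, and FP8 quantization are omitted, the all-reduce is a stand-in, and SiLU is assumed as the MLP activation.

    import numpy as np

    def softmax(x):
        e = np.exp(x - x.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)

    def silu(x):
        return x / (1.0 + np.exp(-x))

    def rmsnorm(x, gamma, eps=1e-6):
        return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * gamma

    def all_reduce(x):
        # Single-device stand-in: the kernel issues a real all-reduce across the
        # replica group here (an SB2SB all-reduce when sbuf_residual_and_cc=True).
        return x

    def transformer_layer(hidden, p):
        # Attention block: RMSNorm -> QKV projection -> attention -> output projection.
        q, k, v = np.split(rmsnorm(hidden, p["gamma_qkv"]) @ p["w_qkv"], 3, axis=-1)
        attn = softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v
        hidden = hidden + all_reduce(attn @ p["w_out"])       # residual connection

        # MLP block: RMSNorm -> gate/up projections -> activation -> down projection.
        x = rmsnorm(hidden, p["gamma_mlp"])
        mlp = (silu(x @ p["w_gate"]) * (x @ p["w_up"])) @ p["w_down"]
        return hidden + all_reduce(mlp)                       # residual connection

    def forward(hidden, layer_params):
        # All layers run inside a single kernel launch in transformer_tkg.
        for p in layer_params:
            hidden = transformer_layer(hidden, p)
        return hidden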

API Reference#

Source code for this kernel API can be found at: transformer_tkg.py

transformer_tkg#

nkilib.experimental.transformer.transformer_tkg(X: nl.ndarray, W_qkvs: List[nl.ndarray], W_outs: List[nl.ndarray], W_gates: List[nl.ndarray], W_ups: List[nl.ndarray], W_downs: List[nl.ndarray], W_gamma_qkvs: List[nl.ndarray], W_gamma_mlps: List[nl.ndarray], K_caches: List[nl.ndarray], V_caches: List[nl.ndarray], RoPE_cos: nl.ndarray, RoPE_sin: nl.ndarray, mask_cache: nl.ndarray, mask_active: nl.ndarray, position_ids: Optional[nl.ndarray], num_layers: int, eps: float = 1e-06, replica_groups: Optional[List[List[int]]] = None, sbuf_residual_and_cc: bool = False, clamp_bound: float = 0.0, W_gate_scales: Optional[List[nl.ndarray]] = None, W_up_scales: Optional[List[nl.ndarray]] = None, W_down_scales: Optional[List[nl.ndarray]] = None)#

Transformer token generation forward pass megakernel.

Parameters:
  • X (nl.ndarray) – [B, S_tkg, H], Input hidden states on HBM

  • W_qkvs (List[nl.ndarray]) – Per-layer QKV projection weights

  • W_outs (List[nl.ndarray]) – Per-layer output projection weights

  • W_gates (List[nl.ndarray]) – Per-layer MLP gate projection weights

  • W_ups (List[nl.ndarray]) – Per-layer MLP up projection weights

  • W_downs (List[nl.ndarray]) – Per-layer MLP down projection weights

  • W_gamma_qkvs (List[nl.ndarray]) – Per-layer RMSNorm gamma for QKV

  • W_gamma_mlps (List[nl.ndarray]) – Per-layer RMSNorm gamma for MLP

  • K_caches (List[nl.ndarray]) – Per-layer K caches on HBM

  • V_caches (List[nl.ndarray]) – Per-layer V caches on HBM

  • RoPE_cos (nl.ndarray) – [d_head//2, B, S_tkg], RoPE cosine embeddings

  • RoPE_sin (nl.ndarray) – [d_head//2, B, S_tkg], RoPE sine embeddings

  • mask_cache (nl.ndarray) – Attention mask for cached KV context

  • mask_active (nl.ndarray) – Attention mask for active tokens

  • position_ids (Optional[nl.ndarray]) – [B, 1], KV cache write positions (None = skip cache update)

  • num_layers (int) – Number of transformer layers to execute

  • eps (float) – RMSNorm epsilon (default 1e-6)

  • replica_groups (Optional[List[List[int]]]) – Replica groups for collective communication

  • sbuf_residual_and_cc (bool) – Use SBUF residual path with SB2SB all-reduce (default False)

  • clamp_bound (float) – FP8 quantization clipping boundary (default 0.0; a value of 0 disables clipping)

  • W_gate_scales (Optional[List[nl.ndarray]]) – Per-layer FP8 gate weight scales

  • W_up_scales (Optional[List[nl.ndarray]]) – Per-layer FP8 up weight scales

  • W_down_scales (Optional[List[nl.ndarray]]) – Per-layer FP8 down weight scales

Returns:

[B, S_tkg, H], Final hidden states after all transformer layers

Return type:

nl.ndarray

Dimensions:

  • B: Batch size

  • S_tkg: Token generation sequence length (number of new tokens)

  • H: Hidden dimension (must be multiple of 128)

  • H0: Partition tile size (pmax = 128)

  • H1: Number of partition tiles (H // H0)
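
For example, with H = 4096 the hidden dimension tiles into H1 = 4096 // 128 = 32 partitions of size H0 = 128. The snippet below illustrates this tiling arithmetic and the documented HBM input shapes; the sizes and dtypes are placeholder assumptions, and the per-layer weight and cache layouts are defined in transformer_tkg.py.

    import numpy as np

    B, S_tkg, H, d_head = 1, 1, 4096, 128    # illustrative sizes
    H0 = 128                                 # partition tile size (pmax)
    H1 = H // H0                             # 4096 // 128 = 32 partition tiles
    assert H % H0 == 0, "H must be a multiple of 128"

    # Documented HBM input shapes (placeholder dtypes):
    X = np.zeros((B, S_tkg, H), dtype=np.float16)                  # input hidden states
    RoPE_cos = np.zeros((d_head // 2, B, S_tkg), dtype=np.float32)
    RoPE_sin = np.zeros((d_head // 2, B, S_tkg), dtype=np.float32)
    position_ids = np.zeros((B, 1), dtype=np.int32)                # pass None to skip the KV cache update

    # The kernel's output has the same [B, S_tkg, H] shape as X.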

This document is relevant for: Trn1, Trn2, Trn3