This document is relevant for: Trn1, Trn2, Trn3
Transformer TKG Kernel API Reference#
Implements the transformer token generation forward pass as a single megakernel.
The kernel supports:
- Configurable number of transformer layers
- Per-layer attention block (RMSNorm + QKV + RoPE + Attention + Output Projection)
- Per-layer MLP block (RMSNorm + Gate/Up + Activation + Down Projection)
- All-reduce collective communication between layers
- Residual connections
- Optional FP8 quantization with per-layer weight scales
- SBUF residual path with SB2SB all-reduce
Background#
The transformer_tkg kernel executes multiple transformer layers in a single kernel invocation for token generation. Within each layer it runs the attention block, an all-reduce, the MLP block, a second all-reduce, and the residual connections. This single-kernel design reduces kernel launch overhead and enables cross-layer optimizations.
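The per-layer dataflow can be summarized with the simplified sketch below. This is an illustrative NumPy sketch only, not the NKI implementation: it assumes a single device (so all_reduce is a placeholder), collapses QKV, RoPE, and attention into one projection, and assumes a SiLU activation in the MLP; none of the helper names come from the kernel source.

```python
# Illustrative NumPy sketch of the per-layer dataflow only -- not the NKI kernel.
# Assumptions: single device (all_reduce is a placeholder), QKV/RoPE/attention
# collapsed into one projection, SiLU activation in the MLP.
import numpy as np

def rmsnorm(x, gamma, eps=1e-6):
    # Normalize by the root-mean-square over the hidden dimension, then scale by gamma.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * gamma

def all_reduce(x):
    # Placeholder: the kernel issues a real all-reduce collective here.
    return x

def silu(x):
    return x / (1.0 + np.exp(-x))

def layer_dataflow(x, p):
    # Attention block: RMSNorm -> QKV -> RoPE -> attention -> output projection,
    # collapsed into a single projection for brevity.
    attn = rmsnorm(x, p["gamma_qkv"]) @ p["w_attn"]
    x = x + all_reduce(attn)          # all-reduce, then residual add

    # MLP block: RMSNorm -> gate/up projections -> activation -> down projection.
    h = rmsnorm(x, p["gamma_mlp"])
    mlp = (silu(h @ p["w_gate"]) * (h @ p["w_up"])) @ p["w_down"]
    x = x + all_reduce(mlp)           # second all-reduce, then residual add
    return x
```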
API Reference#
Source code for this kernel API can be found at: transformer_tkg.py
transformer_tkg#
- nkilib.experimental.transformer.transformer_tkg(X: nl.ndarray, W_qkvs: List[nl.ndarray], W_outs: List[nl.ndarray], W_gates: List[nl.ndarray], W_ups: List[nl.ndarray], W_downs: List[nl.ndarray], W_gamma_qkvs: List[nl.ndarray], W_gamma_mlps: List[nl.ndarray], K_caches: List[nl.ndarray], V_caches: List[nl.ndarray], RoPE_cos: nl.ndarray, RoPE_sin: nl.ndarray, mask_cache: nl.ndarray, mask_active: nl.ndarray, position_ids: Optional[nl.ndarray], num_layers: int, eps: float = 1e-06, replica_groups: Optional[List[List[int]]] = None, sbuf_residual_and_cc: bool = False, clamp_bound: float = 0.0, W_gate_scales: Optional[List[nl.ndarray]] = None, W_up_scales: Optional[List[nl.ndarray]] = None, W_down_scales: Optional[List[nl.ndarray]] = None)#
Transformer token generation forward pass megakernel.
- Parameters:
  - X (nl.ndarray) – [B, S_tkg, H], Input hidden states on HBM
  - W_qkvs (List[nl.ndarray]) – Per-layer QKV projection weights
  - W_outs (List[nl.ndarray]) – Per-layer output projection weights
  - W_gates (List[nl.ndarray]) – Per-layer MLP gate projection weights
  - W_ups (List[nl.ndarray]) – Per-layer MLP up projection weights
  - W_downs (List[nl.ndarray]) – Per-layer MLP down projection weights
  - W_gamma_qkvs (List[nl.ndarray]) – Per-layer RMSNorm gamma for QKV
  - W_gamma_mlps (List[nl.ndarray]) – Per-layer RMSNorm gamma for MLP
  - K_caches (List[nl.ndarray]) – Per-layer K caches on HBM
  - V_caches (List[nl.ndarray]) – Per-layer V caches on HBM
  - RoPE_cos (nl.ndarray) – [d_head//2, B, S_tkg], RoPE cosine embeddings
  - RoPE_sin (nl.ndarray) – [d_head//2, B, S_tkg], RoPE sine embeddings
  - mask_cache (nl.ndarray) – Attention mask for cached KV context
  - mask_active (nl.ndarray) – Attention mask for active tokens
  - position_ids (Optional[nl.ndarray]) – [B, 1], KV cache write positions (None = skip cache update)
  - num_layers (int) – Number of transformer layers to execute
  - eps (float) – RMSNorm epsilon (default 1e-6)
  - replica_groups (Optional[List[List[int]]]) – Replica groups for collective communication
  - sbuf_residual_and_cc (bool) – Use SBUF residual path with SB2SB all-reduce (default False)
  - clamp_bound (float) – FP8 quantization clipping boundary (default 0.0; 0 = no clipping)
  - W_gate_scales (Optional[List[nl.ndarray]]) – Per-layer FP8 gate weight scales
  - W_up_scales (Optional[List[nl.ndarray]]) – Per-layer FP8 up weight scales
  - W_down_scales (Optional[List[nl.ndarray]]) – Per-layer FP8 down weight scales
- Returns:
[B, S_tkg, H], Final hidden states after all transformer layers
- Return type:
nl.ndarray
Dimensions:
- B: Batch size
- S_tkg: Token generation sequence length (number of new tokens)
- H: Hidden dimension (must be a multiple of 128)
- H0: Partition tile size (pmax = 128)
- H1: H // H0
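The call below sketches how the parameters line up; each per-layer list contains num_layers tensors. The tensor shapes, how the lists are assembled, and the choice to leave the FP8 scale arguments as None are assumptions for illustration; transformer_tkg.py documents the exact layouts the kernel expects.

```python
# Hypothetical invocation sketch. X, the weight lists, caches, masks, RoPE tables,
# and replica_groups are assumed to be prepared by the caller and placed on HBM.
from nkilib.experimental.transformer import transformer_tkg

output = transformer_tkg(
    X=X,                              # [B, S_tkg, H] input hidden states on HBM
    W_qkvs=W_qkvs,                    # num_layers QKV projection weights
    W_outs=W_outs,                    # num_layers output projection weights
    W_gates=W_gates,
    W_ups=W_ups,
    W_downs=W_downs,
    W_gamma_qkvs=W_gamma_qkvs,        # RMSNorm gammas for the attention blocks
    W_gamma_mlps=W_gamma_mlps,        # RMSNorm gammas for the MLP blocks
    K_caches=K_caches,                # num_layers K caches on HBM
    V_caches=V_caches,                # num_layers V caches on HBM
    RoPE_cos=RoPE_cos,                # [d_head//2, B, S_tkg]
    RoPE_sin=RoPE_sin,                # [d_head//2, B, S_tkg]
    mask_cache=mask_cache,
    mask_active=mask_active,
    position_ids=position_ids,        # [B, 1]; pass None to skip the KV cache update
    num_layers=num_layers,
    eps=1e-6,
    replica_groups=replica_groups,    # e.g. one group spanning all participating ranks
    sbuf_residual_and_cc=False,       # True selects the SBUF residual + SB2SB all-reduce path
    clamp_bound=0.0,                  # 0.0 disables FP8 clipping
    W_gate_scales=None,               # per-layer FP8 scales; only needed for FP8 weights
    W_up_scales=None,
    W_down_scales=None,
)
# output: [B, S_tkg, H] hidden states after all num_layers layers
```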
This document is relevant for: Trn1, Trn2, Trn3