This document is relevant for: Trn2, Trn3

Attention Block TKG Kernel API Reference#

[Experimental] Implements a fully fused attention block optimized for Token Generation (autoregressive decoding), keeping all intermediate tensors in SBUF to minimize HBM traffic.

The kernel supports:

  • Fused multi-stage computation: pre-normalization, QKV projection, RoPE, post-normalization, attention, KV cache update, and output projection

  • Multiple KV cache layouts: flat (transposed/non-transposed) and block-based

  • Grouped-Query Attention (GQA) with configurable Q/KV head ratios

  • Optional RMSNorm at multiple stages (pre-projection on the input, and per-head pre-/post-RoPE on Q and K)

  • Optional Rotary Position Embedding (RoPE) with configurable layouts

  • Flexible quantization support (FP8, FP16, BF16)

  • FP8 KV cache quantization support

  • Configurable softmax scaling factor

  • Batch processing with per-batch cache indexing

  • Single program multiple data (SPMD) sharding for distributed computation

Background#

The attention_block_tkg kernel combines multiple stages of transformer attention computation into a single fused operation that minimizes data movement between HBM and on-chip memory (SBUF).

Fused Operations:

The kernel fuses the following stages in SBUF to avoid HBM round-trips:

  1. Pre-normalization: Optional RMSNorm on input hidden states

  2. QKV Projection: Linear projection to Query, Key, Value tensors

  3. RoPE: Optional Rotary Position Embedding on Q and K

  4. Post-normalization: Optional per-head RMSNorm on Q and K

  5. Attention Computation: Scaled dot-product attention with KV cache

  6. KV Cache Update: Write new K/V tokens to cache

  7. Output Projection: Linear projection of attention output

Performance Benefits:

By keeping intermediate tensors in SBUF throughout the computation, this kernel achieves:

  • Reduced HBM bandwidth consumption

  • Lower latency for token generation

  • Better hardware utilization through operation fusion

API Reference#

Source code for this kernel API can be found at: attention_block_tkg.py

attention_block_tkg#

nkilib.core.attention_block_tkg.attention_block_tkg.attention_block_tkg(X: nl.ndarray, X_hidden_dim_actual: Optional[int], rmsnorm_X_enabled: bool, rmsnorm_X_eps: Optional[float], rmsnorm_X_gamma: Optional[nl.ndarray], W_qkv: nl.ndarray, bias_qkv: Optional[nl.ndarray], quantization_type_qkv: QuantizationType, weight_dequant_scale_qkv: Optional[nl.ndarray], input_dequant_scale_qkv: Optional[nl.ndarray], rmsnorm_QK_pre_rope_enabled: bool, rmsnorm_QK_pre_rope_eps: float, cos: Optional[nl.ndarray], sin: Optional[nl.ndarray], rope_contiguous_layout: bool, rmsnorm_QK_post_rope_enabled: bool, rmsnorm_QK_post_rope_eps: float, rmsnorm_QK_post_rope_W_Q: Optional[nl.ndarray], rmsnorm_QK_post_rope_W_K: Optional[nl.ndarray], K_cache_transposed: bool, active_blocks_table: Optional[nl.ndarray], K_cache: nl.ndarray, V_cache: nl.ndarray, attention_mask: nl.ndarray, sink: Optional[nl.ndarray], softmax_scale: Optional[float] = None, update_cache: bool, kv_cache_update_idx: Optional[nl.ndarray], k_scale: Optional[nl.ndarray] = None, v_scale: Optional[nl.ndarray] = None, W_out: Optional[nl.ndarray], bias_out: Optional[nl.ndarray], quantization_type_out: QuantizationType, weight_dequant_scale_out: Optional[nl.ndarray], input_dequant_scale_out: Optional[nl.ndarray], transposed_out: bool, out_in_sb: bool, sbm: Optional[SbufManager] = None, skip_attention: bool = False)#

Fused Attention Block for Token Generation (TKG).

Performs end-to-end attention block computation optimized for autoregressive decoding: X → [RMSNorm] → QKV Projection → [RMSNorm Q/K] → [RoPE] → [RMSNorm Q/K] → Attention → KV Cache Update → [Output Projection] → Output

All intermediate tensors remain in SBUF to minimize HBM traffic.

Parameters:
  • X (nl.ndarray) – Input hidden states [B, S_tkg, H] @ HBM or [pmax, B*S_tkg, H//pmax] @ SBUF

  • X_hidden_dim_actual (int, optional) – Actual hidden dim if X is padded

  • rmsnorm_X_enabled (bool) – Apply RMSNorm to X before QKV projection

  • rmsnorm_X_eps (float, optional) – RMSNorm epsilon (default 1e-3)

  • rmsnorm_X_gamma (nl.ndarray, optional) – RMSNorm weights [1, H] @ HBM

  • W_qkv (nl.ndarray) – QKV projection weights [H, d_head*(q_heads+2)] @ HBM

  • bias_qkv (nl.ndarray, optional) – QKV bias [1, d_head*(q_heads+2)] @ HBM

  • quantization_type_qkv (QuantizationType) – Quantization type for QKV projection

  • weight_dequant_scale_qkv (nl.ndarray, optional) – Weight dequantization scale for QKV projection

  • input_dequant_scale_qkv (nl.ndarray, optional) – Input dequantization scale for QKV projection

  • rmsnorm_QK_pre_rope_enabled (bool) – Apply RMSNorm to Q/K before RoPE

  • rmsnorm_QK_pre_rope_eps (float) – Pre-RoPE RMSNorm epsilon

  • cos (nl.ndarray, optional) – RoPE cosine embeddings [d_head//2, B, S_tkg] @ HBM (None = skip RoPE)

  • sin (nl.ndarray, optional) – RoPE sine embeddings [d_head//2, B, S_tkg] @ HBM (None = skip RoPE)

  • rope_contiguous_layout (bool) – True for contiguous halves, False for interleaved

  • rmsnorm_QK_post_rope_enabled (bool) – Apply RMSNorm to Q/K after RoPE

  • rmsnorm_QK_post_rope_eps (float) – Post-RoPE RMSNorm epsilon

  • rmsnorm_QK_post_rope_W_Q (nl.ndarray, optional) – Post-RoPE Q weights [1, d_head] @ HBM

  • rmsnorm_QK_post_rope_W_K (nl.ndarray, optional) – Post-RoPE K weights [1, d_head] @ HBM

  • K_cache_transposed (bool) – K cache layout flag

  • active_blocks_table (nl.ndarray, optional) – Block indices for block KV cache [B, num_blocks] @ HBM

  • K_cache (nl.ndarray) – Key cache @ HBM

  • V_cache (nl.ndarray) – Value cache @ HBM

  • attention_mask (nl.ndarray) – Attention mask [S_ctx, B, q_heads, S_tkg] @ HBM

  • sink (nl.ndarray, optional) – Attention sink tokens [H, 1] @ HBM

  • softmax_scale (float, optional) – Scaling factor for attention scores (Q @ K^T * softmax_scale). If None, defaults to 1.0 / sqrt(d_head).

  • update_cache (bool) – Update KV cache with new tokens

  • kv_cache_update_idx (nl.ndarray, optional) – Cache write positions [B, 1] (uint32_max = skip)

  • k_scale (nl.ndarray, optional) – Key quantization scale for FP8 KV cache. Enables FP8 quantization of K values written to cache.

  • v_scale (nl.ndarray, optional) – Value quantization scale for FP8 KV cache. Enables FP8 quantization of V values written to cache.

  • W_out (nl.ndarray, optional) – Output projection weights [q_heads*d_head, H] @ HBM

  • bias_out (nl.ndarray, optional) – Output projection bias [1, H] @ HBM

  • quantization_type_out (QuantizationType) – Quantization type for output projection

  • weight_dequant_scale_out (nl.ndarray, optional) – Weight dequantization scale for output projection

  • input_dequant_scale_out (nl.ndarray, optional) – Input dequantization scale for output projection

  • transposed_out (bool) – Transpose output layout (requires W_out)

  • out_in_sb (bool) – Return output in SBUF instead of HBM

  • sbm (SbufManager, optional) – SBUF memory manager (otherwise auto-allocated)

  • skip_attention (bool) – Skip attention computation (for testing). Default: False.

Returns:

Tuple of (out, K_out, V_out): the attention block output, the updated K cache (or the new K tokens when the cache is not updated), and the updated V cache (or the new V tokens)

Return type:

tuple

Dimensions:

  • B: batch size

  • S_tkg: number of new tokens to generate

  • S_ctx: KV cache sequence length in current bucket

  • S_max_ctx: maximum KV cache capacity of current bucket

  • H: hidden dimension

  • d_head: head dimension (must be even)

  • q_heads: number of query heads

  • kv_heads: 1 (GQA with single KV head)

Supported Data Types:

  • Supports nl.float16 and nl.bfloat16

Constraints:

  • Requires NeuronCore v3+

  • d_head must be even

  • H must be a multiple of 128

  • Requires B * S_tkg * q_heads <= pmax (= 128)
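
The following is a minimal argument-wiring sketch for an unquantized FP16 configuration with a flat, non-transposed KV cache and S_tkg = 1. Shapes follow the parameter table above and satisfy the constraints (B * S_tkg * q_heads = 128). The NumPy arrays stand in for the HBM tensors; the QuantizationType placeholder and the launch mechanism (the kernel is normally called from a surrounding NKI kernel) are assumptions to adapt to your setup.

    import numpy as np
    from nkilib.core.attention_block_tkg.attention_block_tkg import attention_block_tkg

    B, S_tkg, H = 4, 1, 4096                    # batch, new tokens, hidden dim (multiple of 128)
    d_head, q_heads = 128, 32                   # head dim (even), query heads; B*S_tkg*q_heads <= 128
    S_max_ctx = 1024                            # KV cache capacity for this bucket

    dt = np.float16
    X = np.zeros((B, S_tkg, H), dtype=dt)                               # hidden states @ HBM
    gamma = np.ones((1, H), dtype=dt)                                   # RMSNorm weights
    W_qkv = np.zeros((H, d_head * (q_heads + 2)), dtype=dt)             # fused Q/K/V weights
    W_out = np.zeros((q_heads * d_head, H), dtype=dt)                   # output projection weights
    K_cache = np.zeros((B, S_max_ctx, d_head), dtype=dt)                # flat, non-transposed K cache
    V_cache = np.zeros((B, S_max_ctx, d_head), dtype=dt)
    attention_mask = np.zeros((S_max_ctx, B, q_heads, S_tkg), dtype=dt) # additive mask
    cos = np.zeros((d_head // 2, B, S_tkg), dtype=dt)                   # RoPE tables
    sin = np.zeros((d_head // 2, B, S_tkg), dtype=dt)
    kv_cache_update_idx = np.zeros((B, 1), dtype=np.uint32)             # per-batch write position
    no_quant = None  # placeholder: substitute the library's QuantizationType member

    out, K_out, V_out = attention_block_tkg(
        X=X, X_hidden_dim_actual=None,
        rmsnorm_X_enabled=True, rmsnorm_X_eps=1e-3, rmsnorm_X_gamma=gamma,
        W_qkv=W_qkv, bias_qkv=None,
        quantization_type_qkv=no_quant, weight_dequant_scale_qkv=None, input_dequant_scale_qkv=None,
        rmsnorm_QK_pre_rope_enabled=False, rmsnorm_QK_pre_rope_eps=1e-6,
        cos=cos, sin=sin, rope_contiguous_layout=True,
        rmsnorm_QK_post_rope_enabled=False, rmsnorm_QK_post_rope_eps=1e-6,
        rmsnorm_QK_post_rope_W_Q=None, rmsnorm_QK_post_rope_W_K=None,
        K_cache_transposed=False, active_blocks_table=None,
        K_cache=K_cache, V_cache=V_cache,
        attention_mask=attention_mask, sink=None,
        softmax_scale=None,                     # defaults to 1.0 / sqrt(d_head)
        update_cache=True, kv_cache_update_idx=kv_cache_update_idx,
        W_out=W_out, bias_out=None,
        quantization_type_out=no_quant, weight_dequant_scale_out=None, input_dequant_scale_out=None,
        transposed_out=False, out_in_sb=False,
    )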

Implementation Details#

Computation Flow:

The kernel executes the following stages in sequence (an unfused reference sketch follows this list):

  1. Input Pre-normalization (optional):

    • Apply RMSNorm to input hidden states: X_norm = RMSNorm(X, rmsnorm_X_gamma, rmsnorm_X_eps)

    • Computed in FP32, result cast back to input dtype

  2. QKV Projection:

    • Compute QKV = X_norm @ W_qkv using matrix multiplication (W_qkv is laid out [H, d_head*(q_heads+2)], so no transpose is needed)

    • Result shape: [B, S_tkg, (q_heads + 2) * d_head]

    • Supports FP8 quantization with dequantization scales

  3. Q/K Processing (per head group):

    • Extract Q heads: Q = QKV[:, :, :q_heads * d_head]

    • Extract K head: K = QKV[:, :, q_heads * d_head : (q_heads + 1) * d_head]

    • Apply RoPE if enabled: Q, K = RoPE(Q, K, cos, sin, position_ids)

    • Apply per-head RMSNorm if enabled: Q = RMSNorm(Q, rmsnorm_QK_post_rope_W_Q), K = RMSNorm(K, rmsnorm_QK_post_rope_W_K)

  4. V Processing:

    • Extract V head: V = QKV[:, :, (q_heads + 1) * d_head :]

  5. KV Cache Update:

    • Write new K/V tokens to cache at positions specified by kv_cache_update_idx

    • Supports multiple cache layouts (flat, transposed, block-based)

    • Uses indirect addressing for efficient batch processing

  6. Attention Computation:

    • Compute scaled dot-product attention: Attn = softmax((Q @ K_cache.T) * softmax_scale) @ V_cache

    • Apply causal masking based on S_ctx (context lengths)

    • Accumulate in FP32 for mixed-precision (FP16/BF16) inputs

    • Supports Grouped-Query Attention by replicating KV heads

  7. Output Projection:

    • Reshape attention output: Attn_flat = Attn.reshape([B, S_tkg, q_heads * d_head])

    • Compute out = Attn_flat @ W_out

    • Supports FP8 quantization with dequantization scales
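
For reference, a minimal, unfused NumPy analogue of stages 1-7 is sketched below for S_tkg = 1 and a single KV head. It mirrors the documented weight and cache shapes, but omits the optional pre-/post-RoPE RMSNorm, sink tokens, quantization, and the kernel's SBUF tiling, and it uses a simplified [B, q_heads, S_ctx + 1] mask layout; it is illustrative only, not the kernel's implementation.

    import numpy as np

    def rmsnorm(x, gamma, eps=1e-3):
        # RMSNorm over the last axis, accumulated in FP32.
        x32 = x.astype(np.float32)
        inv = 1.0 / np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
        return x32 * inv * gamma

    def rope_contiguous(x, cos, sin):
        # Rotary embedding, contiguous-halves layout (rope_contiguous_layout=True).
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

    def attention_block_reference(X, gamma_X, W_qkv, cos, sin, K_ctx, V_ctx,
                                  mask, W_out, q_heads, d_head, softmax_scale=None):
        # X: [B, H] new-token hidden states; cos/sin: [B, d_head // 2];
        # K_ctx/V_ctx: [B, S_ctx, d_head]; mask: [B, q_heads, S_ctx + 1] additive.
        if softmax_scale is None:
            softmax_scale = 1.0 / np.sqrt(d_head)      # documented default
        B = X.shape[0]
        # Stages 1-2: pre-normalization and fused QKV projection.
        qkv = rmsnorm(X, gamma_X) @ W_qkv              # [B, (q_heads + 2) * d_head]
        Q = qkv[:, :q_heads * d_head].reshape(B, q_heads, d_head)
        K = qkv[:, q_heads * d_head:(q_heads + 1) * d_head]
        V = qkv[:, (q_heads + 1) * d_head:]
        # Stage 3: RoPE on every Q head and on the single K head.
        Q = rope_contiguous(Q, cos[:, None, :], sin[:, None, :])
        K = rope_contiguous(K, cos, sin)
        # Stage 5: cache update -- append the new token's K/V to the context.
        K_all = np.concatenate([K_ctx, K[:, None, :]], axis=1)
        V_all = np.concatenate([V_ctx, V[:, None, :]], axis=1)
        # Stage 6: GQA attention -- the single KV head is shared by all Q heads.
        scores = np.einsum("bqd,bsd->bqs", Q, K_all) * softmax_scale + mask
        probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
        probs /= probs.sum(axis=-1, keepdims=True)
        attn = np.einsum("bqs,bsd->bqd", probs, V_all)  # [B, q_heads, d_head]
        # Stage 7: output projection back to the hidden dimension.
        return attn.reshape(B, q_heads * d_head) @ W_out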

Memory Management:

The kernel uses a custom SBUF memory manager (SbufManager) to efficiently allocate and reuse on-chip memory:

  • Stack-based allocation for temporary tensors

  • Automatic memory reuse after tensor lifetime ends

  • Minimizes SBUF fragmentation

Parallelization:

The kernel supports data parallelism across multiple NeuronCores (see the sharding sketch after this list):

  • Batch dimension (B) can be sharded across cores

  • Each core processes a subset of batch elements independently

  • KV cache updates use per-core indexing
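
A small sketch of the batch sharding, assuming an even split and caches that remain shared because each core writes disjoint batch rows (the actual core assignment is handled by the SPMD launch):

    import numpy as np

    def shard_for_cores(X, kv_cache_update_idx, num_cores):
        # Split the batch dimension evenly across NeuronCores; each core gets its
        # batch slice plus the matching cache-write indices (assumes B % num_cores == 0).
        per_core = X.shape[0] // num_cores
        return [
            (X[c * per_core:(c + 1) * per_core],
             kv_cache_update_idx[c * per_core:(c + 1) * per_core])
            for c in range(num_cores)
        ]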

Cache Layout Support (an indexing sketch follows the list):

  1. Flat Cache (is_block_kv=False, i.e., active_blocks_table is None):

    • K cache: [B, S_max_ctx, d_head] or [B, d_head, S_max_ctx] (transposed)

    • V cache: [B, S_max_ctx, d_head]

    • Direct indexing by batch and sequence position

  2. Block Cache (is_block_kv=True, i.e., active_blocks_table is provided):

    • K/V cache: [num_blocks, block_len, d_head]

    • Indirect indexing via block slot mapping

    • Efficient for variable-length sequences
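
The sketch below contrasts the two addressing schemes when writing one new token per batch (S_tkg = 1). The flat-cache indexing follows the documented layouts; the block-cache mapping of a logical position to a (block slot, offset) pair through active_blocks_table is an assumption about how such tables are typically used.

    import numpy as np

    def write_kv_flat(K_cache, V_cache, k_new, v_new, kv_cache_update_idx, k_transposed=False):
        # Flat cache: direct indexing by (batch, position).
        # K_cache: [B, S_max_ctx, d_head] (or [B, d_head, S_max_ctx] when transposed)
        # k_new/v_new: [B, d_head] -- one new token per batch (S_tkg = 1)
        for b, pos in enumerate(kv_cache_update_idx[:, 0]):
            if pos == np.iinfo(np.uint32).max:   # sentinel: skip this batch
                continue
            if k_transposed:
                K_cache[b, :, pos] = k_new[b]
            else:
                K_cache[b, pos, :] = k_new[b]
            V_cache[b, pos, :] = v_new[b]

    def write_kv_block(K_cache, V_cache, k_new, v_new, kv_cache_update_idx,
                       active_blocks_table, block_len):
        # Block cache: K_cache/V_cache are [num_blocks, block_len, d_head]; the
        # logical position is mapped to (block slot, offset) via the per-batch table.
        for b, pos in enumerate(kv_cache_update_idx[:, 0]):
            if pos == np.iinfo(np.uint32).max:
                continue
            block_slot = active_blocks_table[b, pos // block_len]
            offset = pos % block_len
            K_cache[block_slot, offset, :] = k_new[b]
            V_cache[block_slot, offset, :] = v_new[b]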

Quantization Support:

  • FP8 weights: Provide weight_dequant_scale_qkv and weight_dequant_scale_out for dequantization (see the sketch after this list)

  • Mixed precision: FP32 accumulation with FP16/BF16 inputs

  • Automatic dtype handling throughout the pipeline
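
A hedged sketch of how dequantization scales are typically folded into the FP32 accumulator of a quantized projection (per-tensor scales are assumed here; the kernel's actual scale granularity and application point may differ):

    import numpy as np

    def dequant_projection(x, w_q, weight_dequant_scale, input_dequant_scale=None):
        # Multiply in the quantized domain, accumulate in FP32, then rescale.
        acc = x.astype(np.float32) @ w_q.astype(np.float32)
        if input_dequant_scale is not None:
            acc *= np.float32(input_dequant_scale)
        return (acc * np.float32(weight_dequant_scale)).astype(np.float16)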

Key Implementation Notes:

  1. Grouped-Query Attention: The kernel processes Q heads in groups, where each group shares a single K/V head. This reduces KV cache memory by a factor of q_heads / kv_heads.

  2. RoPE Application: Rotary embeddings are applied using position indices derived from S_ctx (current context length). Supports both contiguous and interleaved layouts.

  3. Causal Masking: Attention scores are masked so that the token at position i can only attend to positions 0 through i in the context. Implemented by adding -inf to masked positions before the softmax (see the mask sketch after this list).

  4. Cache Update Optimization:

    • For S_tkg=1: Uses batched vector DMA with vector_offset for all batches in one operation

    • For S_tkg>1: Uses per-batch scalar DMA with scalar_offset

    • Block cache uses indirect addressing via block slot indices

  5. Memory Efficiency: All intermediate tensors (QKV, Q, K, V, attention scores, attention output) remain in SBUF. Only input X, weights, caches, and final output out reside in HBM.
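
As a worked example of note 3, the sketch below builds an additive mask in the documented [S_ctx, B, q_heads, S_tkg] layout from per-batch context lengths; how callers actually construct the mask may differ.

    import numpy as np

    def build_decode_mask(context_lens, S_ctx, q_heads, S_tkg=1):
        # Each batch element may attend to its first context_lens[b] cache
        # positions; every other position receives -inf before the softmax.
        B = context_lens.shape[0]
        mask = np.full((S_ctx, B, q_heads, S_tkg), -np.inf, dtype=np.float32)
        for b, n in enumerate(context_lens):
            mask[:n, b, :, :] = 0.0
        return mask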

This document is relevant for: Trn2, Trn3