Attention Block TKG Kernel API Reference#
[Experimental] Implements a fully fused attention block optimized for Token Generation (autoregressive decoding), keeping all intermediate tensors in SBUF to minimize HBM traffic.
The kernel supports:
Fused multi-stage computation: pre-normalization, QKV projection, RoPE, post-normalization, attention, KV cache update, and output projection
Multiple KV cache layouts: flat (transposed/non-transposed) and block-based
Grouped-Query Attention (GQA) with configurable Q/KV head ratios
Optional RMSNorm at multiple stages (pre-projection, post-projection per-head)
Optional Rotary Position Embedding (RoPE) with configurable layouts
Flexible quantization support (FP8, FP16, BF16)
FP8 KV cache quantization support
Configurable softmax scaling factor
Batch processing with per-batch cache indexing
Single program multiple data (SPMD) sharding for distributed computation
Background#
The attention_block_tkg kernel combines multiple stages of transformer attention computation into a single fused operation that minimizes data movement between HBM and on-chip memory (SBUF).
Fused Operations:
The kernel fuses the following stages in SBUF to avoid HBM round-trips:
Pre-normalization: Optional RMSNorm on input hidden states
QKV Projection: Linear projection to Query, Key, Value tensors
RoPE: Optional Rotary Position Embedding on Q and K
Post-normalization: Optional per-head RMSNorm on Q and K
Attention Computation: Scaled dot-product attention with KV cache
KV Cache Update: Write new K/V tokens to cache
Output Projection: Linear projection of attention output
Performance Benefits:
By keeping intermediate tensors in SBUF throughout the computation, this kernel achieves:
Reduced HBM bandwidth consumption
Lower latency for token generation
Better hardware utilization through operation fusion
API Reference#
Source code for this kernel API can be found at: attention_block_tkg.py
attention_block_tkg#
- nkilib.core.attention_block_tkg.attention_block_tkg.attention_block_tkg(X: nl.ndarray, X_hidden_dim_actual: Optional[int], rmsnorm_X_enabled: bool, rmsnorm_X_eps: Optional[float], rmsnorm_X_gamma: Optional[nl.ndarray], W_qkv: nl.ndarray, bias_qkv: Optional[nl.ndarray], quantization_type_qkv: QuantizationType, weight_dequant_scale_qkv: Optional[nl.ndarray], input_dequant_scale_qkv: Optional[nl.ndarray], rmsnorm_QK_pre_rope_enabled: bool, rmsnorm_QK_pre_rope_eps: float, cos: Optional[nl.ndarray], sin: Optional[nl.ndarray], rope_contiguous_layout: bool, rmsnorm_QK_post_rope_enabled: bool, rmsnorm_QK_post_rope_eps: float, rmsnorm_QK_post_rope_W_Q: Optional[nl.ndarray], rmsnorm_QK_post_rope_W_K: Optional[nl.ndarray], K_cache_transposed: bool, active_blocks_table: Optional[nl.ndarray], K_cache: nl.ndarray, V_cache: nl.ndarray, attention_mask: nl.ndarray, sink: Optional[nl.ndarray], softmax_scale: Optional[float] = None, update_cache: bool, kv_cache_update_idx: Optional[nl.ndarray], k_scale: Optional[nl.ndarray] = None, v_scale: Optional[nl.ndarray] = None, W_out: Optional[nl.ndarray], bias_out: Optional[nl.ndarray], quantization_type_out: QuantizationType, weight_dequant_scale_out: Optional[nl.ndarray], input_dequant_scale_out: Optional[nl.ndarray], transposed_out: bool, out_in_sb: bool, sbm: Optional[SbufManager] = None, skip_attention: bool = False)#
Fused Attention Block for Token Generation (TKG).
Performs end-to-end attention block computation optimized for autoregressive decoding: X → [RMSNorm] → QKV Projection → [RMSNorm Q/K] → [RoPE] → [RMSNorm Q/K] → Attention → KV Cache Update → [Output Projection] → Output
All intermediate tensors remain in SBUF to minimize HBM traffic.
- Parameters:
  - X (nl.ndarray) – Input hidden states [B, S_tkg, H] @ HBM or [pmax, B*S_tkg, H//pmax] @ SBUF
  - X_hidden_dim_actual (int, optional) – Actual hidden dim if X is padded
  - rmsnorm_X_enabled (bool) – Apply RMSNorm to X before QKV projection
  - rmsnorm_X_eps (float, optional) – RMSNorm epsilon (default 1e-3)
  - rmsnorm_X_gamma (nl.ndarray, optional) – RMSNorm weights [1, H] @ HBM
  - W_qkv (nl.ndarray) – QKV projection weights [H, d_head*(q_heads+2)] @ HBM
  - bias_qkv (nl.ndarray, optional) – QKV bias [1, d_head*(q_heads+2)] @ HBM
  - quantization_type_qkv (QuantizationType) – Quantization type for QKV projection
  - weight_dequant_scale_qkv (nl.ndarray, optional) – Weight dequantization scale for QKV projection
  - input_dequant_scale_qkv (nl.ndarray, optional) – Input dequantization scale for QKV projection
  - rmsnorm_QK_pre_rope_enabled (bool) – Apply RMSNorm to Q/K before RoPE
  - rmsnorm_QK_pre_rope_eps (float) – Pre-RoPE RMSNorm epsilon
  - cos (nl.ndarray, optional) – RoPE cosine embeddings [d_head//2, B, S_tkg] @ HBM (None = skip RoPE)
  - sin (nl.ndarray, optional) – RoPE sine embeddings [d_head//2, B, S_tkg] @ HBM (None = skip RoPE)
  - rope_contiguous_layout (bool) – True for contiguous halves, False for interleaved
  - rmsnorm_QK_post_rope_enabled (bool) – Apply RMSNorm to Q/K after RoPE
  - rmsnorm_QK_post_rope_eps (float) – Post-RoPE RMSNorm epsilon
  - rmsnorm_QK_post_rope_W_Q (nl.ndarray, optional) – Post-RoPE Q weights [1, d_head] @ HBM
  - rmsnorm_QK_post_rope_W_K (nl.ndarray, optional) – Post-RoPE K weights [1, d_head] @ HBM
  - K_cache_transposed (bool) – K cache layout flag
  - active_blocks_table (nl.ndarray, optional) – Block indices for block KV cache [B, num_blocks] @ HBM
  - K_cache (nl.ndarray) – Key cache @ HBM
  - V_cache (nl.ndarray) – Value cache @ HBM
  - attention_mask (nl.ndarray) – Attention mask [S_ctx, B, q_heads, S_tkg] @ HBM
  - sink (nl.ndarray, optional) – Attention sink tokens [H, 1] @ HBM
  - softmax_scale (float, optional) – Scaling factor for attention scores (Q @ K^T * softmax_scale). If None, defaults to 1.0 / sqrt(d_head).
  - update_cache (bool) – Update KV cache with new tokens
  - kv_cache_update_idx (nl.ndarray, optional) – Cache write positions [B, 1] (uint32_max = skip)
  - k_scale (nl.ndarray, optional) – Key quantization scale for FP8 KV cache. Enables FP8 quantization of K values written to cache.
  - v_scale (nl.ndarray, optional) – Value quantization scale for FP8 KV cache. Enables FP8 quantization of V values written to cache.
  - W_out (nl.ndarray, optional) – Output projection weights [q_heads*d_head, H] @ HBM
  - bias_out (nl.ndarray, optional) – Output projection bias [1, H] @ HBM
  - quantization_type_out (QuantizationType) – Quantization type for output projection
  - weight_dequant_scale_out (nl.ndarray, optional) – Weight dequantization scale for output projection
  - input_dequant_scale_out (nl.ndarray, optional) – Input dequantization scale for output projection
  - transposed_out (bool) – Transpose output layout (requires W_out)
  - out_in_sb (bool) – Return output in SBUF instead of HBM
  - sbm (SbufManager, optional) – SBUF memory manager (otherwise auto-allocated)
  - skip_attention (bool) – Skip attention computation (for testing). Default: False.
- Returns:
Tuple of (out, K_out, V_out) - Output tensor, updated K cache or new K tokens, updated V cache or new V tokens
- Return type:
tuple
Dimensions:
B: batch size
S_tkg: number of new tokens to generate
S_ctx: KV cache sequence length in current bucket
S_max_ctx: maximum KV cache capacity of current bucket
H: hidden dimension
d_head: head dimension (must be even)
q_heads: number of query heads
kv_heads: 1 (GQA with single KV head)
Supported Data Types:
Supports nl.float16 and nl.bfloat16
Constraints:
Requires NeuronCore v3+
d_head must be even
H must be multiple of 128
Requires B * S_tkg * q_heads <= pmax (= 128)
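The constraints above can be collected into a small validity check. This helper and its name are illustrative, not part of the kernel API:

```python
# Illustrative validity check for the documented constraints.
# check_tkg_constraints and PMAX are hypothetical names, not part of the API.
PMAX = 128  # partition-dimension limit stated above

def check_tkg_constraints(B, S_tkg, q_heads, d_head, H):
    """Return True if the documented shape constraints hold."""
    if d_head % 2 != 0:            # d_head must be even
        return False
    if H % 128 != 0:               # H must be a multiple of 128
        return False
    if B * S_tkg * q_heads > PMAX: # work must fit in one partition dimension
        return False
    return True
```

For example, a batch of 4 with 32 query heads and one new token per sequence exactly fills the 128-partition budget, while a batch of 8 would exceed it.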
Implementation Details#
Computation Flow:
The kernel executes the following stages in sequence:
Input Pre-normalization (optional):
- Apply RMSNorm to input hidden states: X_norm = RMSNorm(X, rmsnorm_X_gamma, rmsnorm_X_eps)
- Computed in FP32; result cast back to input dtype
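The normalization math in this stage can be sketched in NumPy (rmsnorm_ref is an illustrative name; the eps default comes from the parameter list, and the kernel's actual implementation may differ):

```python
import numpy as np

def rmsnorm_ref(x, gamma, eps=1e-3):
    """Illustrative NumPy sketch of the pre-projection RMSNorm (FP32 math)."""
    x32 = x.astype(np.float32)
    # Root-mean-square over the hidden dimension, stabilized by eps
    rms = np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    return (x32 / rms) * gamma
```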
QKV Projection:
- Compute QKV = X_norm @ W_qkv using matrix multiplication (W_qkv is [H, (q_heads + 2) * d_head])
- Result shape: [B, S_tkg, (q_heads + 2) * d_head]
- Supports FP8 quantization with dequantization scales
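A toy NumPy sketch of the projection shapes (sizes are illustrative and far below production dimensions):

```python
import numpy as np

# Toy sizes for illustration only
B, S_tkg, H, d_head, q_heads = 2, 1, 256, 64, 2
rng = np.random.default_rng(0)

X_norm = rng.standard_normal((B, S_tkg, H)).astype(np.float32)
W_qkv = rng.standard_normal((H, d_head * (q_heads + 2))).astype(np.float32)

# Fused projection: q_heads query heads plus one K and one V head
QKV = X_norm @ W_qkv
assert QKV.shape == (B, S_tkg, (q_heads + 2) * d_head)
```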
Q/K Processing (per head group):
- Extract Q heads: Q = QKV[:, :, :q_heads * d_head]
- Extract K head: K = QKV[:, :, q_heads * d_head : (q_heads + 1) * d_head]
- Apply RoPE if enabled: Q, K = RoPE(Q, K, cos, sin, position_ids)
- Apply per-head RMSNorm if enabled: Q = RMSNorm(Q, rmsnorm_QK_post_rope_W_Q), K = RMSNorm(K, rmsnorm_QK_post_rope_W_K)
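The two RoPE layouts selected by rope_contiguous_layout can be sketched as follows (rope_ref is an illustrative reference, not the kernel's implementation):

```python
import numpy as np

def rope_ref(x, cos, sin, contiguous=True):
    """Illustrative RoPE sketch on a [..., d_head] tensor.

    cos/sin have shape [..., d_head//2].
    contiguous=True pairs element i with element i + d_head//2;
    contiguous=False pairs adjacent elements (2i, 2i+1).
    """
    d = x.shape[-1]
    if contiguous:
        x1, x2 = x[..., : d // 2], x[..., d // 2 :]
        return np.concatenate([x1 * cos - x2 * sin,
                               x2 * cos + x1 * sin], axis=-1)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x2 * cos + x1 * sin
    return out
```

At rotation angle 0 (cos = 1, sin = 0) both layouts are the identity; at angle pi/2 each pair is rotated a quarter turn.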
V Processing:
- Extract V head: V = QKV[:, :, (q_heads + 1) * d_head :]
KV Cache Update:
- Write new K/V tokens to cache at positions specified by kv_cache_update_idx
- Supports multiple cache layouts (flat, transposed, block-based)
- Uses indirect addressing for efficient batch processing
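For the flat, non-transposed layout with S_tkg = 1, the update semantics, including the uint32_max skip sentinel from the parameter list, can be sketched with a hypothetical helper (the kernel itself uses DMA, not a Python loop):

```python
import numpy as np

UINT32_MAX = np.uint32(0xFFFFFFFF)  # documented "skip this batch" sentinel

def update_flat_cache_ref(K_cache, K_new, kv_cache_update_idx):
    """Illustrative sketch: write one new K token per batch into a flat
    [B, S_max_ctx, d_head] cache; batches whose index is uint32_max are
    skipped (update_flat_cache_ref is a hypothetical name)."""
    for b in range(K_cache.shape[0]):
        pos = kv_cache_update_idx[b, 0]
        if pos == UINT32_MAX:
            continue                      # sentinel: no write for this batch
        K_cache[b, int(pos)] = K_new[b]
    return K_cache
```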
Attention Computation:
- Compute scaled dot-product attention: Attn = softmax((Q @ K_cache.T) * softmax_scale) @ V_cache
- Apply causal masking based on S_ctx (context lengths)
- Use FP32 accumulation if mixed_precision=True
- Supports Grouped-Query Attention by replicating KV heads
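The attention math for one batch element, with a single KV head shared by all Q heads as in the GQA configuration above, can be sketched as (attention_ref is an illustrative helper, not the kernel's implementation):

```python
import numpy as np

def attention_ref(Q, K_cache, V_cache, mask, softmax_scale=None):
    """Illustrative single-token attention sketch.

    Q: [q_heads, d_head] (one new token), K_cache/V_cache: [S_ctx, d_head],
    mask: [q_heads, S_ctx] additive mask (0 for visible, -inf for masked).
    """
    d_head = Q.shape[-1]
    if softmax_scale is None:
        softmax_scale = 1.0 / np.sqrt(d_head)   # documented default
    scores = (Q @ K_cache.T) * softmax_scale + mask   # [q_heads, S_ctx]
    scores -= scores.max(axis=-1, keepdims=True)      # numerical stability
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ V_cache                                # [q_heads, d_head]
```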
Output Projection:
- Reshape attention output: Attn_flat = Attn.reshape([B, S_tkg, q_heads * d_head])
- Compute out = Attn_flat @ W_out (W_out is [q_heads*d_head, H])
- Supports FP8 quantization with dequantization scales
Memory Management:
The kernel uses a custom SBUF memory manager (SbufManager) to efficiently allocate and reuse on-chip memory:
Stack-based allocation for temporary tensors
Automatic memory reuse after tensor lifetime ends
Minimizes SBUF fragmentation
Parallelization:
The kernel supports data parallelism across multiple Neuron Cores:
Batch dimension (B) can be sharded across cores
Each core processes a subset of batch elements independently
KV cache updates use per-core indexing
Cache Layout Support:
Flat Cache (is_block_kv=False):
- K cache: [B, S_max_ctx, d_head] or [B, d_head, S_max_ctx] (transposed)
- V cache: [B, S_max_ctx, d_head]
- Direct indexing by batch and sequence position
Block Cache (is_block_kv=True):
- K/V cache: [num_blocks, block_len, d_head]
- Indirect indexing via block slot mapping
- Efficient for variable-length sequences
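Block-cache indirection can be sketched as a gather over active_blocks_table (gather_block_cache_ref is a hypothetical helper; the kernel performs this with indirect addressing rather than a Python loop):

```python
import numpy as np

def gather_block_cache_ref(K_cache, active_blocks_table, b, S_ctx):
    """Illustrative sketch: assemble batch b's logical K sequence from a
    block cache.

    K_cache: [num_blocks, block_len, d_head];
    active_blocks_table: [B, num_blocks_per_seq] of physical block indices.
    """
    block_len = K_cache.shape[1]
    n_blocks = (S_ctx + block_len - 1) // block_len   # blocks covering S_ctx
    blocks = [K_cache[active_blocks_table[b, i]] for i in range(n_blocks)]
    return np.concatenate(blocks, axis=0)[:S_ctx]      # trim partial block
```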
Quantization Support:
FP8 weights: provide weight_dequant_scale_qkv and weight_dequant_scale_out for dequantization
Mixed precision: FP32 accumulation with FP16/BF16 inputs
Automatic dtype handling throughout the pipeline
Key Implementation Notes:
Grouped-Query Attention: The kernel processes Q heads in groups, where each group shares a single K/V head. This reduces KV cache memory by a factor of q_heads / kv_heads.
RoPE Application: Rotary embeddings are applied using position indices derived from S_ctx (current context length). Supports both contiguous and interleaved layouts.
Causal Masking: Attention scores are masked so that the token at position i can only attend to positions 0 to i in the context. Implemented by adding -inf to masked positions before softmax.
Cache Update Optimization:
- For S_tkg=1: uses batched vector DMA with vector_offset for all batches in one operation
- For S_tkg>1: uses per-batch scalar DMA with scalar_offset
- Block cache uses indirect addressing via block slot indices
Memory Efficiency: All intermediate tensors (QKV, Q, K, V, attention scores, attention output) remain in SBUF. Only input X, weights, caches, and final output out reside in HBM.