QKV Kernel API Reference#
Performs Query-Key-Value projection with optional normalization and RoPE fusion.
The kernel supports:
Optional RMSNorm/LayerNorm fusion
Multiple output tensor layouts
Residual connections from previous MLP and attention outputs
Automatic selection between TKG and CTE implementations based on batch_size * seqlen threshold
Optional RoPE (Rotary Position Embedding) fusion
Fused FP8 KV cache quantization
Block-based KV cache layout support
MX quantization support (CTE mode only)
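Ignoring the optional fusions, the core computation is a single matmul against the fused weight matrix, with the result split into Q, K, and V along the fused dimension. A minimal NumPy sketch of the shapes and splitting logic (the actual kernel runs on-device; the small dimensions below are illustrative only):

```python
import numpy as np

# Hypothetical small config for illustration.
B, S, H = 2, 4, 64                              # batch, sequence length, hidden dim
num_q_heads, num_kv_heads, d_head = 4, 2, 8
I = (num_q_heads + 2 * num_kv_heads) * d_head   # fused_qkv_dim

x = np.random.randn(B, S, H).astype(np.float32)
w = np.random.randn(H, I).astype(np.float32)

# Fused projection: one matmul produces Q, K, and V together.
qkv = x @ w                                     # [B, S, I]

# Split along the fused dimension.
q = qkv[..., : num_q_heads * d_head]
k = qkv[..., num_q_heads * d_head : (num_q_heads + num_kv_heads) * d_head]
v = qkv[..., (num_q_heads + num_kv_heads) * d_head :]
```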
Background#
The QKV kernel is a critical component in transformer architectures, responsible for projecting the input hidden states into query, key, and value representations. This kernel optimizes the projection operation by fusing it with optional normalization and supporting various output layouts to accommodate different transformer implementations.
Note
This kernel automatically selects between the TKG (Token Generation) and CTE (Context Encoding) implementations based on the batch_size * seqlen product, ensuring good performance across different use cases. CTE is used for longer sequences, while TKG is optimized for shorter ones.
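The dispatch described above can be thought of as a simple threshold check. A sketch of the idea (the actual threshold value is internal to the kernel; 256 below is a placeholder, not the real cutoff):

```python
def select_impl(batch_size: int, seqlen: int, threshold: int = 256) -> str:
    """Illustrative dispatch: CTE for large batch_size * seqlen (context
    encoding), TKG for small (token generation). The threshold value here
    is a placeholder, not the kernel's actual internal constant."""
    return "CTE" if batch_size * seqlen > threshold else "TKG"
```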
API Reference#
Source code for this kernel API can be found at: qkv.py
qkv#
- nkilib.core.qkv.qkv(input: nl.ndarray, fused_qkv_weights: nl.ndarray, output_layout: QKVOutputLayout = QKVOutputLayout.BSD, bias: Optional[nl.ndarray] = None, quantization_type: QuantizationType = QuantizationType.NONE, qkv_w_scale: Optional[nl.ndarray] = None, qkv_in_scale: Optional[nl.ndarray] = None, fused_residual_add: Optional[bool] = False, mlp_prev: Optional[nl.ndarray] = None, attention_prev: Optional[nl.ndarray] = None, fused_norm_type: NormType = NormType.NO_NORM, gamma_norm_weights: Optional[nl.ndarray] = None, layer_norm_bias: Optional[nl.ndarray] = None, norm_eps: float = 1e-6, hidden_actual: Optional[int] = None, fused_rope: Optional[bool] = False, cos_cache: Optional[nl.ndarray] = None, sin_cache: Optional[nl.ndarray] = None, d_head: Optional[int] = None, num_q_heads: Optional[int] = None, num_kv_heads: Optional[int] = None, k_cache: Optional[nl.ndarray] = None, v_cache: Optional[nl.ndarray] = None, k_scale: Optional[nl.ndarray] = None, v_scale: Optional[nl.ndarray] = None, fp8_max: Optional[float] = None, fp8_min: Optional[float] = None, kv_dtype: Optional[type] = None, use_block_kv: bool = False, block_size: Optional[int] = None, slot_mapping: Optional[nl.ndarray] = None, store_output_in_sbuf: bool = False, sbm: Optional[SbufManager] = None, use_auto_allocation: bool = False, load_input_with_DMA_transpose: bool = True, is_input_swizzled: bool = False) → nl.ndarray#
QKV (Query, Key, Value) projection kernel with multiple optional fused operations.
Performs matrix multiplication between the input hidden states and the fused QKV weight matrix, with optional fused operations including residual addition, normalization, bias addition, and RoPE rotation. Automatically selects between the TKG and CTE implementations based on the batch_size * seqlen product.
- Parameters:
  - input (nl.ndarray) – Input hidden states tensor. Shape: [B, S, H], where B=batch, S=sequence_length, H=hidden_dim.
  - fused_qkv_weights (nl.ndarray) – Fused QKV weight matrix. Shape: [H, I], where I=fused_qkv_dim=(num_q_heads + 2*num_kv_heads)*d_head.
  - output_layout (QKVOutputLayout) – Output tensor layout: QKVOutputLayout.BSD=[B, S, I] or QKVOutputLayout.NBSd=[num_heads, B, S, d_head]. Default: QKVOutputLayout.BSD.
  - bias (nl.ndarray, optional) – Bias tensor added to the QKV projection output. Shape: [1, I].
  - quantization_type (QuantizationType) – Type of quantization to apply. Default: QuantizationType.NONE.
  - qkv_w_scale (nl.ndarray, optional) – Weight scale tensor for quantization.
  - qkv_in_scale (nl.ndarray, optional) – Input scale tensor for quantization.
  - fused_residual_add (bool, optional) – Whether to perform residual addition: input = input + mlp_prev + attention_prev. Default: False.
  - mlp_prev (nl.ndarray, optional) – Previous MLP output tensor for residual addition. Shape: [B, S, H].
  - attention_prev (nl.ndarray, optional) – Previous attention output tensor for residual addition. Shape: [B, S, H].
  - fused_norm_type (NormType) – Type of normalization (NO_NORM, RMS_NORM, RMS_NORM_SKIP_GAMMA, LAYER_NORM). Default: NormType.NO_NORM.
  - gamma_norm_weights (nl.ndarray, optional) – Normalization gamma/scale weights. Shape: [1, H]. Required for RMS_NORM and LAYER_NORM.
  - layer_norm_bias (nl.ndarray, optional) – Layer normalization beta/bias weights. Shape: [1, H]. Only for LAYER_NORM.
  - norm_eps (float, optional) – Epsilon value for numerical stability in normalization. Default: 1e-6.
  - hidden_actual (int, optional) – Actual hidden dimension for padded tensors (if H contains padding).
  - fused_rope (bool, optional) – Whether to apply RoPE rotation to Query and Key heads after the QKV projection. Default: False.
  - cos_cache (nl.ndarray, optional) – Cosine cache for RoPE. Shape: [B, S, d_head]. Required if fused_rope=True.
  - sin_cache (nl.ndarray, optional) – Sine cache for RoPE. Shape: [B, S, d_head]. Required if fused_rope=True.
  - d_head (int, optional) – Dimension per attention head. Required for QKVOutputLayout.NBSd and RoPE.
  - num_q_heads (int, optional) – Number of query heads. Required for RoPE.
  - num_kv_heads (int, optional) – Number of key/value heads. Required for RoPE.
  - k_cache (nl.ndarray, optional) – Key cache tensor for fused FP8 KV cache quantization. Shape: [B, max_seq_len, kv_dim]. Required when k_scale and v_scale are provided.
  - v_cache (nl.ndarray, optional) – Value cache tensor for fused FP8 KV cache quantization. Shape: [B, max_seq_len, kv_dim]. Required when k_scale and v_scale are provided.
  - k_scale (nl.ndarray, optional) – Key quantization scale for FP8 KV cache quantization. KV output quantization is enabled when both k_scale and v_scale are provided.
  - v_scale (nl.ndarray, optional) – Value quantization scale for FP8 KV cache quantization. KV output quantization is enabled when both k_scale and v_scale are provided.
  - fp8_max (float, optional) – Maximum FP8 value for clamping during KV cache quantization. Defaults to the maximum positive value of kv_dtype.
  - fp8_min (float, optional) – Minimum FP8 value for clamping during KV cache quantization. Defaults to the negative of fp8_max.
  - kv_dtype (type, optional) – Data type for the quantized KV cache output. Defaults to the input tensor dtype if not specified.
  - use_block_kv (bool) – Whether to use the block-based KV cache layout. When True, requires block_size and slot_mapping. Default: False.
  - block_size (int, optional) – Number of tokens per block in the block KV cache. Required when use_block_kv=True.
  - slot_mapping (nl.ndarray, optional) – Mapping from token positions to block slots for the block KV cache. Required when use_block_kv=True.
  - store_output_in_sbuf (bool) – Whether to store the output in SBUF (currently unsupported; must be False). Default: False.
  - sbm (SbufManager, optional) – Optional SBUF manager for memory allocation control with pre-specified bounds on SBUF usage.
  - use_auto_allocation (bool) – Whether to use automatic SBUF allocation. Default: False.
  - load_input_with_DMA_transpose (bool) – Whether to use the DMA transpose optimization. Default: True.
  - is_input_swizzled (bool) – Whether the input tensor is swizzled (only applicable with MX quantization). Default: False.
- Returns:
QKV projection output tensor with shape determined by output_layout.
- Return type:
nl.ndarray
- Raises:
  - ValueError – Raised when a contract dimension mismatch occurs between input and fused_qkv_weights.
  - AssertionError – Raised when required parameters for fused operations are missing or have incorrect shapes.
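The relationship between output_layout and the returned shape can be checked with a small helper (a hypothetical function for shape bookkeeping only; it is not part of the kernel API):

```python
def qkv_output_shape(batch, seqlen, num_q_heads, num_kv_heads, d_head, layout="BSD"):
    """Compute the expected QKV output shape for each supported layout.

    BSD:  [B, S, I] with I = (num_q_heads + 2*num_kv_heads) * d_head
    NBSd: [num_heads, B, S, d_head] with num_heads = num_q_heads + 2*num_kv_heads
    """
    num_heads = num_q_heads + 2 * num_kv_heads
    if layout == "BSD":
        return (batch, seqlen, num_heads * d_head)
    if layout == "NBSd":
        return (num_heads, batch, seqlen, d_head)
    raise ValueError(f"unknown layout: {layout}")
```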
Implementation Details#
The kernel implementation includes several key optimizations:
Automatic Implementation Selection: The kernel automatically selects between the TKG (Token Generation) and CTE (Context Encoding) implementations based on the batch_size * seqlen product. Some features, such as RoPE fusion and loading the input with DMA transpose, are available only in CTE mode. TKG mode currently supports only automatic SBUF allocation.
Fused Operations Support:
  - Residual Addition: Fuses input + mlp_prev + attention_prev.
  - Normalization: Supports RMSNorm, LayerNorm, and RMS_NORM_SKIP_GAMMA.
  - Bias Addition: Adds bias to the QKV projection output.
  - RoPE Fusion: Applies Rotary Position Embedding to Query and Key heads.
  - FP8 KV Cache Quantization: Quantizes K and V outputs directly into the KV cache, avoiding a separate quantization step. Enabled when k_scale and v_scale are provided. Only supported with the BSD output layout.
  - Block KV Cache: Supports a block-based KV cache layout with indirect addressing via slot_mapping for variable-length sequences.
Flexible Output Layouts: Supports BSD ([B, S, I]) and NBSd ([num_heads, B, S, d_head]) output tensor layouts.
Memory Management:
Optional SBUF manager for controlled memory allocation
DMA transpose optimization for weight loading
Hardware Compatibility: Supports bf16, fp16, and fp32 data types (fp32 inputs are internally converted to bf16).
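The fused path described above (residual add, then normalization, then projection with bias, then RoPE on the Q and K heads) can be modeled with a NumPy reference. This is a behavioral sketch of the math under the documented shapes, not the on-device implementation; the rotate-half RoPE convention below is an assumption:

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    # Normalize over the hidden dimension, then scale by gamma.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * gamma

def rotate_half(x):
    # Common RoPE helper: (x1, x2) -> (-x2, x1) over the two halves of d_head.
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

def qkv_reference(x, w, bias=None, mlp_prev=None, attention_prev=None,
                  gamma=None, eps=1e-6, cos=None, sin=None,
                  num_q_heads=None, num_kv_heads=None, d_head=None):
    """Reference for the fused path: residual add -> RMSNorm ->
    projection (+bias) -> RoPE on Q and K heads. BSD layout only."""
    if mlp_prev is not None:
        x = x + mlp_prev
    if attention_prev is not None:
        x = x + attention_prev
    if gamma is not None:
        x = rms_norm(x, gamma, eps)
    out = x @ w                            # [B, S, I]
    if bias is not None:
        out = out + bias
    if cos is not None:                    # RoPE applies to Q and K heads only
        qk_dim = (num_q_heads + num_kv_heads) * d_head
        B, S, _ = out.shape
        qk = out[..., :qk_dim].reshape(B, S, num_q_heads + num_kv_heads, d_head)
        c = cos[:, :, None, :]             # broadcast [B, S, d_head] over heads
        s = sin[:, :, None, :]
        qk = qk * c + rotate_half(qk) * s
        out = np.concatenate([qk.reshape(B, S, qk_dim), out[..., qk_dim:]], axis=-1)
    return out
```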
Constraints:
H must be ≤ 24576 and divisible by 128
I must be ≤ 4096
For NBSd output: d_head must equal 128
FP8 KV cache quantization requires BSD output layout
Block KV cache requires block_size and slot_mapping
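The FP8 KV cache write can be modeled as scale-then-clamp, and the block layout as an indirect scatter through slot_mapping. A NumPy sketch of both (the e4m3-style clamp bound and the flat slot-index convention are assumptions for illustration; the kernel's actual cache layout may differ):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # assumed clamp bound for illustration

def quantize_kv(kv, scale, fp8_max=FP8_E4M3_MAX, fp8_min=None):
    """Scale then clamp K/V values into the FP8 representable range.
    (NumPy has no fp8 dtype; float32 stands in for the quantized values.)"""
    if fp8_min is None:
        fp8_min = -fp8_max          # matches the documented fp8_min default
    return np.clip(kv / scale, fp8_min, fp8_max)

def scatter_to_block_cache(k_new, cache, slot_mapping, block_size):
    """Write per-token K (or V) vectors into a block-based cache.
    cache: [num_blocks, block_size, kv_dim]; slot_mapping: flat slot per token."""
    tokens, kv_dim = k_new.shape
    for t in range(tokens):
        slot = slot_mapping[t]
        cache[slot // block_size, slot % block_size] = k_new[t]
    return cache
```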