This document is relevant for: Trn2, Trn3
Attention Segmented CTE Kernel API Reference#
Segmented attention computation with block-based KV cache and prefix caching support, optimized for Context Encoding (CTE).
Background#
The attention_segmented_cte kernel implements segmented attention that processes the KV cache in configurable segments, supporting block-based KV cache layout and prefix caching. It includes helper kernels for floor/ceil operations needed for position-based computations and a KV cache loader that reads blocks from page tables into SBUF.
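As a rough illustration of what processing the KV cache in segments can look like, here is a minimal NumPy sketch. It assumes a running-softmax accumulation across segments, which this reference does not confirm; the function names `attention_full` and `attention_segmented` are hypothetical, and the sketch ignores block tables, masking, and heads.

```python
import numpy as np

def attention_full(q, k, v, scale):
    """Reference: softmax(q @ k.T * scale) @ v computed in one pass."""
    s = (q @ k.T) * scale
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def attention_segmented(q, k, v, scale, seg_size):
    """Same result, visiting K/V one segment of seg_size tokens at a time.

    ASSUMPTION: partial results are merged with a running max and running
    sum. The real kernel's accumulation scheme is not described here.
    """
    seqlen_q = q.shape[0]
    running_max = np.full((seqlen_q, 1), -np.inf)
    running_sum = np.zeros((seqlen_q, 1))
    acc = np.zeros((seqlen_q, v.shape[1]))
    for start in range(0, k.shape[0], seg_size):
        k_seg = k[start:start + seg_size]
        v_seg = v[start:start + seg_size]
        s = (q @ k_seg.T) * scale
        new_max = np.maximum(running_max, s.max(axis=-1, keepdims=True))
        # Rescale everything accumulated so far to the new running max.
        correction = np.exp(running_max - new_max)
        p = np.exp(s - new_max)
        running_sum = running_sum * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ v_seg
        running_max = new_max
    return acc / running_sum
```

Only one segment of K and V needs to be resident at a time, which is the property that makes a segment size worth tuning against on-chip memory.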
API Reference#
Source code for this kernel API can be found at: attention_segmented_cte.py
floor_nisa_kernel#
- nkilib.core.attention.floor_nisa_kernel(src_t: nl.ndarray, dst_t: nl.ndarray, p_size: int, f_size: int, allocator: ModularAllocator)#
NISA implementation for floor operation using integer casting.
- Parameters:
src_t (nl.ndarray) – Source tensor to compute floor of (dtype: fp32)
dst_t (nl.ndarray) – Destination tensor for floor result (dtype: int32)
p_size (int) – First dimension size
f_size (int) – Second dimension size
allocator (ModularAllocator) – SBUF allocator for temporary tensors
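The idea of computing a floor through an integer cast can be sketched on the host with NumPy. This is not the NISA instruction sequence, and `floor_via_int_cast` is a hypothetical name. The correction step for negative inputs is an assumption added for completeness; the reference does not say whether the kernel needs or performs it.

```python
import numpy as np

def floor_via_int_cast(src: np.ndarray) -> np.ndarray:
    """Floor of an fp32 array, returned as int32, built from a truncating cast."""
    truncated = src.astype(np.int32)  # float -> int cast rounds toward zero
    # Truncation lands one too high for negative non-integers, so step down there.
    # ASSUMPTION: for non-negative inputs the cast alone already equals floor.
    needs_fix = truncated.astype(np.float32) > src
    return truncated - needs_fix.astype(np.int32)
```

For example, `floor_via_int_cast(np.array([[2.75, -0.25]], dtype=np.float32))` yields `[[2, -1]]`.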
ceil_nisa_kernel#
- nkilib.core.attention.ceil_nisa_kernel(src_t: nl.ndarray, dst_t: nl.ndarray, p_size: int, f_size: int, allocator: ModularAllocator)#
NISA implementation for ceil operation using floor.
- Parameters:
src_t (nl.ndarray) – Source tensor to compute ceil of (dtype: fp32)
dst_t (nl.ndarray) – Destination tensor for ceil result (dtype: int32)
p_size (int) – First dimension size
f_size (int) – Second dimension size
allocator (ModularAllocator) – SBUF allocator for temporary tensors
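One way to derive ceil from floor is to add one wherever the input has a fractional part. The NumPy sketch below shows that identity; whether the kernel uses this identity or another one (such as negating the floor of the negated input) is not stated in this reference. The names `floor_i32` and `ceil_via_floor` are hypothetical.

```python
import numpy as np

def floor_i32(src: np.ndarray) -> np.ndarray:
    """Stand-in for the floor helper: fp32 in, int32 out."""
    return np.floor(src).astype(np.int32)

def ceil_via_floor(src: np.ndarray) -> np.ndarray:
    """Ceil of an fp32 array, returned as int32, expressed through floor."""
    floored = floor_i32(src)
    # Any value strictly above its floor has a fractional part and rounds up.
    has_fraction = src > floored.astype(np.float32)
    return floored + has_fraction.astype(np.int32)
```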
load_kv_cache#
- nkilib.core.attention.load_kv_cache(k_cache, v_cache, block_tables, k_sbuf, v_sbuf, b_i, h_i, block_table_offset, num_blocks, allocator: ModularAllocator, k_pre_transposed: bool = False)#
Load KV cache from block tables to SBUF for a single KV head.
- Parameters:
k_cache – K cache in HBM. Shape depends on k_pre_transposed:
  - False: (num_blocks_total, num_kv_head, block_size, head_dim)
  - True: (num_blocks_total * num_kv_head, head_dim, block_size)
v_cache – V cache in HBM with shape (num_blocks_total, num_kv_head, block_size, head_dim)
block_tables – Block table tensor with shape (batch_size, max_blocks_per_seq)
k_sbuf – K SBUF tiles to load into
v_sbuf – V SBUF tiles to load into
b_i – Current sequence index in batch
h_i – Current KV head index
block_table_offset – SBUF tensor (1, 1) indicating the block offset for the current segment
num_blocks – Number of blocks to load
allocator (ModularAllocator) – SBUF allocator for temporary tensor allocation
k_pre_transposed (bool) – If True, K cache is already stored in transposed layout (head_dim, block_size) per block, so no transpose is needed during loading.
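The block-table lookup described by these parameters can be mimicked on the host with NumPy, as a sketch only. Several details are assumptions not confirmed by this reference: the helper name `gather_kv_for_head`, the plain integer `first_block` standing in for `block_table_offset`, the output layouts, the flattened index `block_id * num_kv_head + h_i` for the pre-transposed K cache, and the choice to leave zeros where a block table entry is negative.

```python
import numpy as np

def gather_kv_for_head(k_cache, v_cache, block_table_row, h_i,
                       first_block, num_blocks, k_pre_transposed=False):
    """Gather num_blocks blocks of K and V for one sequence and one KV head.

    Returns K as (head_dim, num_blocks * block_size) and
    V as (num_blocks * block_size, head_dim). Both layouts are assumptions.
    """
    _, num_kv_head, block_size, head_dim = v_cache.shape
    k_out = np.zeros((head_dim, num_blocks * block_size), dtype=k_cache.dtype)
    v_out = np.zeros((num_blocks * block_size, head_dim), dtype=v_cache.dtype)
    for i in range(num_blocks):
        block_id = int(block_table_row[first_block + i])
        if block_id < 0:
            continue  # padding entry: nothing is loaded, the tile stays zero
        cols = slice(i * block_size, (i + 1) * block_size)
        if k_pre_transposed:
            # ASSUMPTION: blocks and heads are flattened block-major.
            k_out[:, cols] = k_cache[block_id * num_kv_head + h_i]
        else:
            # Stored as (block_size, head_dim), so transpose while loading.
            k_out[:, cols] = k_cache[block_id, h_i].T
        v_out[cols, :] = v_cache[block_id, h_i]
    return k_out, v_out
```

With a pre-transposed cache the per-block transpose disappears from the loop, which is the saving the `k_pre_transposed` flag describes.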
attention_segmented_cte#
- nkilib.core.attention.attention_segmented_cte(q: nl.ndarray, k_cache: nl.ndarray, v_cache: nl.ndarray, block_tables: nl.ndarray, prior_tokens: nl.ndarray, block_size: int, prior_seg_size: int, scale: float = 1.0, tp_q: bool = True, tp_out: bool = False, sliding_window: Optional[int] = None, sink: Optional[nl.ndarray] = None, num_q_heads: int = 1, kvp_offset: Optional[nl.ndarray] = None, k_pre_transposed: bool = False, k_scale: Optional[nl.ndarray] = None, v_scale: Optional[nl.ndarray] = None)#
Segmented attention computation with block-based KV cache and prefix caching.
- Parameters:
q (nl.ndarray) – Query tensor with shape (batch_size, seqlen_q, d) when tp_q=True
k_cache (nl.ndarray) – K cache in HBM. Shape depends on k_pre_transposed:
  - False: (num_blocks, num_kv_head, block_size, head_dim)
  - True: (num_blocks * num_kv_head, head_dim, block_size)
v_cache (nl.ndarray) – V cache in HBM with shape (num_blocks, num_kv_head, block_size, head_dim)
block_tables (nl.ndarray) – Block table tensor with shape (batch_size, max_blocks_per_seq). May contain -1 values for padding (triggers DMA skipping). If prior_last_segment_tokens < prior_seg_size, the caller should prepend -1 padding.
prior_tokens (nl.ndarray) – Total number of prior (cached) tokens, shape (1, 1). Must be a multiple of block_size.
block_size (int) – Size of each block in the KV cache
prior_seg_size (int) – Size of each KV segment to process iteratively
scale (float) – Scaling factor for attention scores (default 1.0)
tp_q (bool) – Query tensor transpose flag (default True)
tp_out (bool) – Output tensor transpose flag (default False)
k_pre_transposed (bool) – If True, K cache is already stored in transposed layout (head_dim, block_size) per block, written by _quantize_and_store_k_transposed in qkv_cte.
k_scale (Optional[nl.ndarray]) – Optional per-head-dim dequantization scale for K cache, shape (128, 1). When provided, Q is scaled by k_scale before the QK^T matmul (delayed dequant).
v_scale (Optional[nl.ndarray]) – Optional per-head-dim dequantization scale for V cache, shape (128, 1). When provided, the output is scaled by v_scale after PV matmul normalization.
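The delayed dequantization described for k_scale and v_scale rests on a simple algebraic identity, which the NumPy sketch below checks for a single head. It is a host-side illustration, not kernel code. The random data, the tensor sizes other than the (128, 1) scale shape, and all variable names are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
seqlen_q, seqlen_k, head_dim = 4, 16, 128

q = rng.standard_normal((seqlen_q, head_dim))
k_quant = rng.integers(-127, 128, size=(seqlen_k, head_dim)).astype(np.float64)
v_quant = rng.integers(-127, 128, size=(seqlen_k, head_dim)).astype(np.float64)
k_scale = rng.uniform(0.01, 0.1, size=(head_dim, 1))  # one scale per head-dim element
v_scale = rng.uniform(0.01, 0.1, size=(head_dim, 1))

# K: dequantizing every cached element first ...
scores_eager = q @ (k_quant * k_scale.T).T
# ... equals folding the scale into Q once, before the QK^T matmul.
scores_delayed = (q * k_scale.T) @ k_quant.T

# Stand-in for normalized attention probabilities (each row sums to 1).
p = rng.random((seqlen_q, seqlen_k))
p /= p.sum(axis=1, keepdims=True)

# V: dequantizing V first equals scaling the output after the PV matmul.
out_eager = p @ (v_quant * v_scale.T)
out_delayed = (p @ v_quant) * v_scale.T
```

In both cases the scale is applied to a tensor of query or output size instead of to every cached token, so the amount of scaling work no longer grows with the cache length.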