This document is relevant for: Trn2, Trn3

Attention Segmented CTE Kernel API Reference

Segmented attention computation with a block-based KV cache and prefix-caching support, optimized for context encoding (CTE).

Background

The attention_segmented_cte kernel implements segmented attention: it processes the KV cache in configurable segments and supports a block-based KV cache layout and prefix caching. It also includes helper kernels for the floor and ceil operations needed by position-based computations, and a KV cache loader that gathers the blocks referenced by the block table into SBUF.
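The segment-by-segment structure is easiest to see in a host-side reference. The sketch below is plain Python, not NKI, and works on a single query vector and a single head. It assumes the kernel combines segments with an online softmax (a running max and running denominator that are rescaled whenever a new segment raises the max), which is the usual way to make segmented attention match full attention; the kernel's actual accumulation scheme is not documented here. `full_attention` and `segmented_attention` are hypothetical names.

```python
import math


def full_attention(q, keys, values, scale=1.0):
    """Reference: softmax(scale * q . K^T) @ V for one query vector."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    d = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(d)]


def segmented_attention(q, keys, values, seg_size, scale=1.0):
    """Process K/V seg_size tokens at a time with an (assumed) online softmax.

    Carries a running max, a running softmax denominator, and an
    unnormalized output accumulator across segments.
    """
    d = len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * d
    for start in range(0, len(keys), seg_size):
        k_seg = keys[start:start + seg_size]
        v_seg = values[start:start + seg_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_seg]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new max (exp(-inf) == 0.0 on the first segment).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_seg):
            w = math.exp(s - new_max)
            running_sum += w
            for j in range(d):
                acc[j] += w * v[j]
        running_max = new_max
    # Normalize once at the end.
    return [a / running_sum for a in acc]
```

For any segment size, `segmented_attention` returns the same result as `full_attention` up to floating-point rounding, which is what allows the KV cache to be streamed in segments instead of held in full.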

API Reference

Source code for this kernel API can be found at: attention_segmented_cte.py

floor_nisa_kernel

nkilib.core.attention.floor_nisa_kernel(src_t: nl.ndarray, dst_t: nl.ndarray, p_size: int, f_size: int, allocator: ModularAllocator)

NISA implementation of the floor operation, computed via integer casting.

Parameters:
  • src_t (nl.ndarray) – Source tensor to compute floor of (dtype: fp32)

  • dst_t (nl.ndarray) – Destination tensor for floor result (dtype: int32)

  • p_size (int) – Size of the first (partition) dimension

  • f_size (int) – Size of the second (free) dimension

  • allocator (ModularAllocator) – SBUF allocator for temporary tensors
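The floor-by-casting idea can be checked on the host. This is a plain-Python sketch, not the NISA kernel; `floor_via_int_cast` and `floor_tile` are hypothetical names. A cast to integer yields either floor(x) or floor(x) + 1, whether it truncates toward zero or rounds to nearest, so a single "subtract 1 if the cast overshot x" step recovers the floor. Which cast mode the hardware uses is left as an assumption here.

```python
def floor_via_int_cast(x: float) -> int:
    """Floor computed from an integer cast plus one correction step.

    Python's int() truncates toward zero (int(-1.5) == -1). The same
    correction also works if the cast rounds to nearest instead.
    """
    t = int(x)
    if t > x:  # cast overshot: only possible for non-integer x
        t -= 1
    return t


def floor_tile(src, p_size, f_size):
    """Elementwise floor over a p_size x f_size nested list (fp32 in, int out)."""
    return [[floor_via_int_cast(src[p][f]) for f in range(f_size)] for p in range(p_size)]
```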

ceil_nisa_kernel

nkilib.core.attention.ceil_nisa_kernel(src_t: nl.ndarray, dst_t: nl.ndarray, p_size: int, f_size: int, allocator: ModularAllocator)

NISA implementation of the ceil operation, built on floor.

Parameters:
  • src_t (nl.ndarray) – Source tensor to compute ceil of (dtype: fp32)

  • dst_t (nl.ndarray) – Destination tensor for ceil result (dtype: int32)

  • p_size (int) – Size of the first (partition) dimension

  • f_size (int) – Size of the second (free) dimension

  • allocator (ModularAllocator) – SBUF allocator for temporary tensors
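The description only says that ceil is built on floor. One common identity is ceil(x) = -floor(-x); the plain-Python sketch below assumes that identity and is not necessarily the one the kernel uses. It repeats a cast-based floor so the block stands alone; `floor_via_int_cast` and `ceil_via_floor` are hypothetical names.

```python
def floor_via_int_cast(x: float) -> int:
    """Floor from an integer cast, corrected when the cast overshoots x."""
    t = int(x)
    if t > x:
        t -= 1
    return t


def ceil_via_floor(x: float) -> int:
    """ceil(x) == -floor(-x) holds for every real x."""
    return -floor_via_int_cast(-x)
```

As a position-style example, the number of segments needed for 300 tokens with a segment size of 128 would be `ceil_via_floor(300 / 128)`, which is 3.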

load_kv_cache

nkilib.core.attention.load_kv_cache(k_cache, v_cache, block_tables, k_sbuf, v_sbuf, b_i, h_i, block_table_offset, num_blocks, allocator: ModularAllocator, k_pre_transposed: bool = False)

Load the KV cache blocks referenced by the block table into SBUF for a single KV head.

Parameters:
  • k_cache – K cache in HBM. Shape depends on k_pre_transposed:
      - False: (num_blocks_total, num_kv_head, block_size, head_dim)
      - True: (num_blocks_total * num_kv_head, head_dim, block_size)

  • v_cache – V cache in HBM with shape (num_blocks_total, num_kv_head, block_size, head_dim)

  • block_tables – Block table tensor with shape (batch_size, max_blocks_per_seq)

  • k_sbuf – K SBUF tiles to load into

  • v_sbuf – V SBUF tiles to load into

  • b_i – Index of the current sequence within the batch

  • h_i – Index of the current KV head

  • block_table_offset – SBUF tensor of shape (1, 1) holding the block offset for the current segment

  • num_blocks – Number of blocks to load

  • allocator (ModularAllocator) – SBUF allocator for temporary tensor allocation

  • k_pre_transposed (bool) – If True, K cache is already stored in transposed layout (head_dim, block_size) per block, so no transpose is needed during loading.
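A host-side sketch of what the loader gathers in the non-pre-transposed case (k_pre_transposed=False), using nested Python lists. Several details are assumptions, not taken from the source: blocks are read from block_tables[b_i][block_table_offset + j] for j in range(num_blocks); block_table_offset is a plain int rather than a (1, 1) SBUF tensor; K is returned transposed to (head_dim, tokens), mirroring the (head_dim, block_size) per-block layout that k_pre_transposed=True stores directly; and -1 entries are skipped and left as zeros. `load_kv_head` is a hypothetical name.

```python
def load_kv_head(k_cache, v_cache, block_tables, b_i, h_i,
                 block_table_offset, num_blocks, block_size, head_dim):
    """Gather num_blocks KV blocks for one sequence (b_i) and one KV head (h_i).

    Assumed cache layout: cache[block_id][kv_head][token][dim], i.e.
    (num_blocks_total, num_kv_head, block_size, head_dim).
    Returns (k_t, v): k_t is (head_dim, num_blocks * block_size),
    v is (num_blocks * block_size, head_dim).
    """
    n_tok = num_blocks * block_size
    k_t = [[0.0] * n_tok for _ in range(head_dim)]
    v = [[0.0] * head_dim for _ in range(n_tok)]
    for j in range(num_blocks):
        block_id = block_tables[b_i][block_table_offset + j]
        if block_id < 0:
            continue  # -1 padding: skip the load (zeros remain in this sketch)
        for t in range(block_size):
            tok = j * block_size + t
            for d in range(head_dim):
                k_t[d][tok] = k_cache[block_id][h_i][t][d]  # transpose K on load
                v[tok][d] = v_cache[block_id][h_i][t][d]    # V keeps token-major layout
    return k_t, v
```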

attention_segmented_cte

nkilib.core.attention.attention_segmented_cte(q: nl.ndarray, k_cache: nl.ndarray, v_cache: nl.ndarray, block_tables: nl.ndarray, prior_tokens: nl.ndarray, block_size: int, prior_seg_size: int, scale: float = 1.0, tp_q: bool = True, tp_out: bool = False, sliding_window: Optional[int] = None, sink: Optional[nl.ndarray] = None, num_q_heads: int = 1, kvp_offset: Optional[nl.ndarray] = None, k_pre_transposed: bool = False, k_scale: Optional[nl.ndarray] = None, v_scale: Optional[nl.ndarray] = None)

Segmented attention computation with block-based KV cache and prefix caching.

Parameters:
  • q (nl.ndarray) – Query tensor with shape (batch_size, seqlen_q, d) when tp_q=True

  • k_cache (nl.ndarray) – K cache in HBM. Shape depends on k_pre_transposed:
      - False: (num_blocks, num_kv_head, block_size, head_dim)
      - True: (num_blocks * num_kv_head, head_dim, block_size)

  • v_cache (nl.ndarray) – V cache in HBM with shape (num_blocks, num_kv_head, block_size, head_dim)

  • block_tables (nl.ndarray) – Block table tensor with shape (batch_size, max_blocks_per_seq). May contain -1 entries as padding; the kernel skips the DMA for those entries. If prior_last_segment_tokens < prior_seg_size (that is, the last segment of prior tokens is only partially filled), the caller should prepend -1 padding.

  • prior_tokens (nl.ndarray) – Total number of prior (cached) tokens, as a tensor of shape (1, 1). Must be a multiple of block_size.

  • block_size (int) – Size of each block in the KV cache

  • prior_seg_size (int) – Number of prior KV tokens in each segment; the prior KV cache is processed one segment at a time

  • scale (float) – Scaling factor for attention scores (default 1.0)

  • tp_q (bool) – Query tensor transpose flag (default True)

  • tp_out (bool) – Output tensor transpose flag (default False)

  • k_pre_transposed (bool) – If True, K cache is already stored in transposed layout (head_dim, block_size) per block, written by _quantize_and_store_k_transposed in qkv_cte.

  • k_scale (Optional[nl.ndarray]) – Optional per-head-dim dequantization scale for K cache, shape (128, 1). When provided, Q is scaled by k_scale before QK^T matmul (delayed dequant).

  • v_scale (Optional[nl.ndarray]) – Optional per-head-dim dequantization scale for V cache, shape (128, 1). When provided, the output is scaled by v_scale after PV matmul normalization.
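The block_tables and prior_tokens constraints imply some caller-side arithmetic. The hypothetical helper `plan_prior_segments` below sketches it under stated assumptions: prior_seg_size is a multiple of block_size, only the prior (cached) blocks are placed in the row, and prior_last_segment_tokens is read as the number of prior tokens that fall into the final segment. It illustrates the "prepend -1 padding" rule; it is not the library's own helper.

```python
def plan_prior_segments(prior_blocks, prior_tokens, block_size, prior_seg_size):
    """Hypothetical helper: build one block-table row for the prior context.

    Requires prior_tokens % block_size == 0 (stated by the API) and
    prior_seg_size % block_size == 0 (an assumption of this sketch).
    Returns (row, num_segments), with -1 padding prepended so the row
    covers a whole number of prior_seg_size segments.
    """
    assert prior_tokens % block_size == 0
    assert prior_seg_size % block_size == 0
    num_prior_blocks = prior_tokens // block_size
    num_segments = -(-prior_tokens // prior_seg_size)  # ceiling division
    # Prior tokens in the final segment (equals prior_seg_size when it is full).
    prior_last_segment_tokens = prior_tokens - (num_segments - 1) * prior_seg_size
    pad_blocks = 0
    if prior_last_segment_tokens < prior_seg_size:
        pad_blocks = (prior_seg_size - prior_last_segment_tokens) // block_size
    row = [-1] * pad_blocks + list(prior_blocks[:num_prior_blocks])
    return row, num_segments
```

With block_size=4, prior_seg_size=8 and prior_tokens=12, the three prior blocks span two segments, the last holding only 4 tokens, so one -1 entry is prepended.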
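Delayed dequantization works because k_scale and v_scale are per head-dim: scaling Q elementwise before QK^T gives the same scores as dequantizing every K row, and scaling the normalized PV output gives the same result as dequantizing V first. The plain-Python sketch below checks both identities with head_dim = 3 instead of 128. The function names are hypothetical, and the sketch says nothing about how the kernel lays out the scales in SBUF.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def scores_dequant_k(q, k_quant, k_scale):
    """Eager path: dequantize each K row (per head-dim), then compute q . k."""
    return [dot(q, [kq * s for kq, s in zip(k, k_scale)]) for k in k_quant]


def scores_delayed(q, k_quant, k_scale):
    """Delayed path: scale Q once by k_scale, then use the quantized K as-is."""
    q_scaled = [qi * s for qi, s in zip(q, k_scale)]
    return [dot(q_scaled, k) for k in k_quant]


def output_delayed_v(probs, v_quant, v_scale):
    """PV with quantized V using normalized probs, then scale each head-dim by v_scale."""
    d = len(v_scale)
    out = [sum(p * v[j] for p, v in zip(probs, v_quant)) for j in range(d)]
    return [o * s for o, s in zip(out, v_scale)]
```

Scaling Q touches one vector per query, whereas dequantizing K up front would touch every cached K element, which is the motivation for applying the scale on the Q side.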
