This document is relevant for: Trn2, Trn3
KV Parallel Segmented Prefill Kernel API Reference
KV-parallel segmented prefill attention.
Distributes attention computation across ranks, where each rank holds a shard of the KV cache. Uses online softmax to merge partial results.
Background
The kv_parallel_segmented_prefill kernel implements KV-parallel segmented prefill attention. It distributes the attention computation across ranks, each of which holds a shard of the KV cache, and merges the per-rank partial results with online softmax.
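The merge step described above can be illustrated with a small NumPy reference sketch. This is not the NKI kernel. It assumes that each rank returns three things for every query row: its shard-normalized output `o`, the row maximum `m` of the scaled scores, and the row sum `l` of `exp(scores - m)`. The helper names `partial_attention` and `merge_partials` are hypothetical, and masking is omitted.

```python
# Hypothetical NumPy reference sketch (not the NKI kernel); names are illustrative.
import numpy as np


def partial_attention(q, k, v, scale):
    """Attention of q [S, D] against one KV shard k, v [L, D].

    Returns (o, m, l): the shard-normalized output [S, D], the row max of
    the scaled scores [S, 1], and the row sum of exp(scores - m) [S, 1].
    """
    s = (q @ k.T) * scale                   # [S, L] scaled scores
    m = s.max(axis=-1, keepdims=True)       # [S, 1] row max
    p = np.exp(s - m)                       # [S, L] stabilized weights
    l = p.sum(axis=-1, keepdims=True)       # [S, 1] row sum
    return (p @ v) / l, m, l


def merge_partials(parts):
    """Merge per-rank (o, m, l) triples with online-softmax rescaling."""
    m_all = np.max([m for _, m, _ in parts], axis=0)     # global row max
    w = [l * np.exp(m - m_all) for _, m, l in parts]      # rescaled row sums
    l_all = np.sum(w, axis=0)                             # global row sum
    return sum(w_r * o for w_r, (o, _, _) in zip(w, parts)) / l_all


# Usage: G ranks, each holding L keys/values; compare with unsharded attention.
rng = np.random.default_rng(0)
S, D, G, L = 4, 8, 2, 6
q = rng.standard_normal((S, D))
k = rng.standard_normal((G * L, D))
v = rng.standard_normal((G * L, D))
scale = 1.0 / np.sqrt(D)

parts = [partial_attention(q, k[r * L:(r + 1) * L], v[r * L:(r + 1) * L], scale)
         for r in range(G)]
merged = merge_partials(parts)

s = (q @ k.T) * scale
p = np.exp(s - s.max(axis=-1, keepdims=True))
ref = (p / p.sum(axis=-1, keepdims=True)) @ v
print(np.allclose(merged, ref))  # True
```

Rescaling each shard by `exp(m_r - m_all)` puts all shards on a common reference maximum, so the merged output matches softmax attention over the full, unsharded KV. Under causal masking, a query row may see no keys in some shard. A real implementation would then have to guard against an all-masked row (row max of -inf) before this rescaling, and the sketch does not do that.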
API Reference
Source code for this kernel API can be found at: kv_parallel_segmented_prefill.py
kv_parallel_segmented_prefill
- nkilib.core.attention.kv_parallel_segmented_prefill(q: nl.ndarray, k_cache: nl.ndarray, v_cache: nl.ndarray, block_tables: nl.ndarray, kvp_offset: nl.ndarray, replica_groups: ReplicaGroup, group_size: int, block_size: int, seg_size: int, scale: float = 1.0, global_q_offset: int = 0, tp_out: bool = False) → nl.ndarray
KV-parallel segmented prefill attention.
- Parameters:
  - q (nl.ndarray) – [BS, S, D], this rank’s Q heads (BS = lnc_degree).
  - k_cache (nl.ndarray) – [num_blocks, num_kv_heads, block_size, D], local KV cache (K).
  - v_cache (nl.ndarray) – [num_blocks, num_kv_heads, block_size, D], local KV cache (V).
  - block_tables (nl.ndarray) – [1, max_blocks] int32, block indices for the paged KV cache.
  - kvp_offset (nl.ndarray) – [1, 1] int32, causal mask offset = -rank_id * local_kv_len + global_q_offset.
  - replica_groups (ReplicaGroup) – ReplicaGroup for collective operations.
  - group_size (int) – Number of ranks in the replica group.
  - block_size (int) – KV cache block size.
  - seg_size (int) – Segment size for attention iteration.
  - scale (float) – Attention scale factor (default 1.0).
  - global_q_offset (int) – Global token position of Q token 0 (default 0). Used to compute how many prior KV tokens exist within this rank’s shard for each Q chunk.
- Returns:
[BS, S, D], Merged attention output for this rank’s Q heads.
- Return type:
nl.ndarray
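The `kvp_offset` formula converts the global causal condition into a check on local indices. The NumPy sketch below is illustrative only. It assumes that rank `r` holds global KV positions `[r * local_kv_len, (r + 1) * local_kv_len)` and that Q token `i` sits at global position `global_q_offset + i`. Under those assumptions, key `j` of the local shard is visible to query `i` exactly when `j <= i + kvp_offset`. The helper names are hypothetical.

```python
# Hypothetical NumPy sketch of the kvp_offset causal mask (not the NKI kernel).
import numpy as np


def kvp_offset_for_rank(rank_id, local_kv_len, global_q_offset):
    # Formula from the kvp_offset parameter description above.
    return -rank_id * local_kv_len + global_q_offset


def local_causal_mask(seq_len, local_kv_len, kvp_offset):
    """Boolean [S, local_kv_len] mask; True where local key j is visible to query i.

    Assumes rank r holds global KV positions [r*local_kv_len, (r+1)*local_kv_len)
    and query i sits at global position global_q_offset + i, so the global
    condition r*local_kv_len + j <= global_q_offset + i becomes j <= i + kvp_offset.
    """
    i = np.arange(seq_len)[:, None]
    j = np.arange(local_kv_len)[None, :]
    return j <= i + kvp_offset


# Usage: per-rank local masks, concatenated, reproduce the global causal mask.
S, L, G, q_off = 4, 3, 3, 5
global_mask = np.arange(G * L)[None, :] <= (q_off + np.arange(S))[:, None]
local_masks = [local_causal_mask(S, L, kvp_offset_for_rank(r, L, q_off))
               for r in range(G)]
print(np.array_equal(np.concatenate(local_masks, axis=1), global_mask))  # True
print(local_masks[2].sum(axis=1))  # [0 1 2 3]
```

Summing a row of a local mask gives the number of this rank’s KV tokens that the corresponding query can attend to. On rank 2 in the example, the offset is negative, so the first query sees none of the shard.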
Dimensions:
BS: Batch size (lnc_degree = Q heads per physical rank)
S: Sequence length
D: Head dimension
G: Group size (number of ranks per replica group)
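The following NumPy sketch shows how the paged layout and segmented iteration might look on a single rank. It handles one Q head (`q` of shape [S, D]) against one KV head. It first gathers that head's blocks in `block_tables[0]` order into a contiguous sequence, then runs online softmax over segments of `seg_size` keys. The sketch assumes no masking and that every listed block is full. `gather_paged_kv` and `segmented_attention` are hypothetical names, and the kernel's actual tiling is not modeled.

```python
# Hypothetical NumPy reference sketch (not the NKI kernel); names are illustrative.
import numpy as np


def gather_paged_kv(cache, block_tables, kv_head):
    """Flatten one KV head of a paged cache into a contiguous sequence.

    cache: [num_blocks, num_kv_heads, block_size, D]; block_tables: [1, max_blocks].
    Returns [max_blocks * block_size, D] in block_tables[0] order.
    """
    blocks = cache[block_tables[0], kv_head]      # [max_blocks, block_size, D]
    return blocks.reshape(-1, cache.shape[-1])


def segmented_attention(q, k, v, seg_size, scale):
    """Online softmax over KV segments of seg_size keys (no masking).

    q: [S, D]; k, v: [N, D]. Returns (o, m, l): normalized output [S, D],
    running row max [S, 1] and row sum of exp(scores - m) [S, 1].
    """
    S, D = q.shape
    m = np.full((S, 1), -np.inf)
    l = np.zeros((S, 1))
    acc = np.zeros((S, D))
    for start in range(0, k.shape[0], seg_size):
        s = (q @ k[start:start + seg_size].T) * scale    # [S, seg] scores
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        alpha = np.exp(m - m_new)                        # rescale earlier segments
        p = np.exp(s - m_new)
        l = alpha * l + p.sum(axis=-1, keepdims=True)
        acc = alpha * acc + p @ v[start:start + seg_size]
        m = m_new
    return acc / l, m, l


# Usage: 3 of 5 blocks for KV head 1, iterated in segments of 5 keys.
rng = np.random.default_rng(0)
num_blocks, num_kv_heads, block_size, D, S = 5, 2, 4, 8, 3
k_cache = rng.standard_normal((num_blocks, num_kv_heads, block_size, D))
v_cache = rng.standard_normal((num_blocks, num_kv_heads, block_size, D))
block_tables = np.array([[3, 0, 4]], dtype=np.int32)
q = rng.standard_normal((S, D))

k = gather_paged_kv(k_cache, block_tables, kv_head=1)   # [12, D]
v = gather_paged_kv(v_cache, block_tables, kv_head=1)
o, m, l = segmented_attention(q, k, v, seg_size=5, scale=1.0 / np.sqrt(D))

s = (q @ k.T) / np.sqrt(D)
p = np.exp(s - s.max(axis=-1, keepdims=True))
print(np.allclose(o, (p / p.sum(axis=-1, keepdims=True)) @ v))  # True
```

Rescaling by `alpha` within a rank uses the same online-softmax idea as the cross-rank merge. Returning `(o, m, l)` rather than only `o` keeps the per-rank result mergeable with other ranks' results.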