This document is relevant for: Trn2, Trn3

KV Parallel Segmented Prefill Kernel API Reference

KV-parallel segmented prefill attention.

Distributes attention computation across ranks, where each rank holds a shard of the KV cache. Uses online softmax to merge partial results.

Background

The kv_parallel_segmented_prefill kernel implements KV-parallel segmented prefill attention. The attention computation is distributed across the ranks of a replica group, and each rank holds a shard of the KV cache. Because no single rank sees every key, each rank produces a partial result, and the partial results are merged with online softmax.
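The online softmax merge mentioned above can be illustrated with a minimal pure-Python sketch for a single query vector. This is not the kernel's implementation: the helper names (`partial_attention`, `merge_partials`, `full_attention`) are hypothetical, the collective communication between ranks is replaced by a plain Python list, and causal masking is ignored. It only shows why per-shard results carrying a running max and a running sum can be combined into the same output as attention over the full KV cache.

```python
import math

def partial_attention(q, keys, values, scale=1.0):
    """Attention of one query over one KV shard, kept unnormalized.

    Returns (m, l, acc): the shard's max score, the sum of exp(score - m),
    and the exp-weighted sum of the shard's value vectors.
    """
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    l = sum(weights)
    dim = len(values[0])
    acc = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return m, l, acc

def merge_partials(partials):
    """Combine per-shard (m, l, acc) triples into the final attention output."""
    m_global = max(m for m, _, _ in partials)
    l_total = 0.0
    acc_total = [0.0] * len(partials[0][2])
    for m, l, acc in partials:
        correction = math.exp(m - m_global)  # rescale the shard to the shared max
        l_total += correction * l
        acc_total = [a + correction * b for a, b in zip(acc_total, acc)]
    return [a / l_total for a in acc_total]

def full_attention(q, keys, values, scale=1.0):
    """Reference: ordinary softmax attention over the whole KV sequence."""
    _, l, acc = partial_attention(q, keys, values, scale)
    return [a / l for a in acc]

# Toy data: 4 KV tokens with head dimension 2, split across two hypothetical ranks.
q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [-1.0, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
shards = [(keys[:2], values[:2]), (keys[2:], values[2:])]

merged = merge_partials([partial_attention(q, k, v) for k, v in shards])
reference = full_attention(q, keys, values)
print(merged, reference)
```

The merged output matches the reference up to floating-point rounding, because rescaling each shard by `exp(m - m_global)` puts every partial sum on the same exponent base before normalizing.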

API Reference

Source code for this kernel API can be found at: kv_parallel_segmented_prefill.py

kv_parallel_segmented_prefill

nkilib.core.attention.kv_parallel_segmented_prefill(q: nl.ndarray, k_cache: nl.ndarray, v_cache: nl.ndarray, block_tables: nl.ndarray, kvp_offset: nl.ndarray, replica_groups: ReplicaGroup, group_size: int, block_size: int, seg_size: int, scale: float = 1.0, global_q_offset: int = 0, tp_out: bool = False) -> nl.ndarray

KV-parallel segmented prefill attention.

Parameters:
  • q (nl.ndarray) – [BS, S, D], This rank’s Q heads (BS = lnc_degree).

  • k_cache (nl.ndarray) – [num_blocks, num_kv_heads, block_size, D], Local KV cache (K).

  • v_cache (nl.ndarray) – [num_blocks, num_kv_heads, block_size, D], Local KV cache (V).

  • block_tables (nl.ndarray) – [1, max_blocks] int32, Block indices for paged KV.

  • kvp_offset (nl.ndarray) – [1, 1] int32, Causal mask offset = -rank_id * local_kv_len + global_q_offset.

  • replica_groups (ReplicaGroup) – ReplicaGroup for collective operations.

  • group_size (int) – Number of ranks in the replica group.

  • block_size (int) – KV cache block size.

  • seg_size (int) – Segment size for attention iteration.

  • scale (float) – Attention scale factor (default 1.0).

  • global_q_offset (int) – Global token position of Q token 0 (default 0). Used to compute how many prior KV tokens exist within this rank’s shard for each Q chunk.
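The `kvp_offset` formula in the parameter list can be worked through with a small sketch. The function `compute_kvp_offset` simply evaluates the documented expression. The helper `visible_local_keys` is hypothetical and rests on two assumptions that this reference does not confirm: each rank holds a contiguous shard of `local_kv_len` tokens, and the causal mask keeps local key `j` for query `i` when `j <= i + kvp_offset`.

```python
def compute_kvp_offset(rank_id, local_kv_len, global_q_offset=0):
    """Evaluate the documented formula: -rank_id * local_kv_len + global_q_offset."""
    return -rank_id * local_kv_len + global_q_offset

def visible_local_keys(q_index, offset, local_kv_len):
    """Number of local keys visible to query q_index.

    Assumed mask convention (not confirmed by the reference):
    local key j is kept when j <= q_index + offset.
    """
    return max(0, min(local_kv_len, q_index + offset + 1))

# Hypothetical setup: 2 ranks, 4 KV tokens per rank, Q token 0 at global position 4.
local_kv_len = 4
group_size = 2
global_q_offset = 4

offsets = [compute_kvp_offset(r, local_kv_len, global_q_offset) for r in range(group_size)]

for q_index in range(4):
    per_rank = [visible_local_keys(q_index, off, local_kv_len) for off in offsets]
    print(q_index, per_rank, sum(per_rank))
```

Under these assumptions, the per-rank counts for each query sum to `global_q_offset + q_index + 1`, which is the number of keys a causal mask exposes to a query at that global position.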

Returns:

[BS, S, D], Merged attention output for this rank’s Q heads.

Return type:

nl.ndarray

Dimensions:

  • BS: Batch size (lnc_degree = Q heads per physical rank)

  • S: Sequence length

  • D: Head dimension

  • G: Group size (number of ranks per replica group)
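To make the paged KV layout concrete, the following sketch shows how a logical token index could be resolved through `block_tables` into the `[num_blocks, num_kv_heads, block_size, D]` cache. It uses nested Python lists in place of `nl.ndarray`, the helpers `locate_token` and `gather_keys` are hypothetical, and it assumes the common paged-attention convention that logical block `t // block_size` maps to the physical block stored at that position in the table. The kernel's actual indexing may differ.

```python
def locate_token(token_index, block_table, block_size):
    """Map a logical KV token index to (physical_block, offset_in_block)."""
    logical_block = token_index // block_size
    return block_table[logical_block], token_index % block_size

def gather_keys(k_cache, block_table, block_size, kv_head, num_tokens):
    """Collect the key vectors of the first num_tokens logical tokens."""
    rows = []
    for t in range(num_tokens):
        block, offset = locate_token(t, block_table, block_size)
        rows.append(k_cache[block][kv_head][offset])
    return rows

# Toy cache: num_blocks=3, num_kv_heads=1, block_size=2, D=1.
# Each entry encodes its own location as 10 * block + offset.
num_blocks, block_size = 3, 2
k_cache = [[[[10 * b + o] for o in range(block_size)]] for b in range(num_blocks)]

# block_tables has shape [1, max_blocks]; row 0 lists the physical blocks in logical order.
block_tables = [[2, 0]]

gathered = gather_keys(k_cache, block_tables[0], block_size, kv_head=0, num_tokens=4)
print(gathered)
```

Logical tokens 0 and 1 resolve to physical block 2, and tokens 2 and 3 resolve to physical block 0, so the gathered order follows the block table rather than the physical layout of the cache.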
