This document is relevant for: Trn2, Trn3
Ring Attention Bwd Kernel API Reference
Ring Attention Backward SPMD kernel.
Computes gradients dQ, dK, dV for ring attention using collective permute operations to circulate Q, dY, LSE, and dy_o_sum across workers while keeping K, V local. Supports causal masking and striped attention.
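For reference, these gradients follow the standard softmax-attention backward pass used by FlashAttention-style kernels. The equations below are written for a single head with row-major $Q, K, V, O, dY \in \mathbb{R}^{S \times D}$ (the kernel itself stores each head as [D, S]) and $\alpha$ = softmax_scale. Reading dy_o_sum as the row sum $\delta$ of $dY \circ O$ is an assumption based on its name.

```latex
\begin{aligned}
Z &= \alpha\, Q K^\top \quad (Z_{ij} = -\infty \text{ for } j > i \text{ when causal}), \\
P_{ij} &= \exp\!\left(Z_{ij} - \mathrm{LSE}_i\right), \qquad O = P V, \\
dV &= P^\top dY, \qquad dP = dY\, V^\top, \qquad \delta_i = \textstyle\sum_{d} dY_{id}\, O_{id}, \\
dS &= P \circ \left(dP - \delta\, \mathbf{1}^\top\right), \\
dQ &= \alpha\, dS\, K, \qquad dK = \alpha\, dS^\top Q.
\end{aligned}
```

Because $P$ is rebuilt from the global LSE, every (Q shard, K/V shard) pair contributes to dQ, dK, and dV independently. This is consistent with LSE and dy_o_sum travelling alongside Q and dY while K and V stay put.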
Background
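The ring_attention_spmd_bwd kernel computes the backward pass for ring attention, producing the gradients dQ, dK, and dV. Each worker holds one shard of the sequence. K and V never leave the worker that owns them, so the dK and dV contributions for that shard are computed locally; Q, dY, LSE, and dy_o_sum instead circulate around the ring of num_workers workers through collective permute operations, so each local K/V shard can be combined with the Q-side shards of the other workers.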
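The ring schedule can be checked numerically with a small NumPy sketch. It is not the NKI kernel and makes several assumptions: it works on one head in row-major [S, D] layout, runs the workers one after another in a Python loop, uses contiguous sequence shards, has worker r hold the Q shard of worker (r - step) mod num_workers at each step (the real permute direction is not stated on this page), and simply adds each dQ contribution back to the owner of that Q shard, since how the kernel routes dQ is not described here. dy_o_sum is again taken to be the row sum of dY * O. All function names are hypothetical.

```python
import numpy as np


def forward(q, k, v, scale, causal):
    """Single-head forward pass; returns O and the per-row log-sum-exp."""
    z = scale * q @ k.T
    if causal:
        z = np.where(np.tril(np.ones(z.shape, dtype=bool)), z, -np.inf)
    m = z.max(axis=1, keepdims=True)
    lse = (m + np.log(np.exp(z - m).sum(axis=1, keepdims=True)))[:, 0]
    return np.exp(z - lse[:, None]) @ v, lse


def block_grads(q, dy, lse, delta, k, v, scale, mask):
    """dQ, dK, dV contributions of one (Q block, K/V block) pair.

    mask[i, j] is True where query i may attend to key j.
    """
    z = np.where(mask, scale * q @ k.T, -np.inf)
    p = np.exp(z - lse[:, None])        # rebuilt with the *global* LSE
    dp = dy @ v.T
    ds = p * (dp - delta[:, None])      # delta plays the role of dy_o_sum
    return scale * ds @ k, scale * ds.T @ q, p.T @ dy


def full_bwd(q, k, v, dy, scale, causal):
    """Non-distributed reference gradients (dQ, dK, dV)."""
    o, lse = forward(q, k, v, scale, causal)
    delta = np.sum(dy * o, axis=1)
    mask = np.ones((len(q), len(k)), dtype=bool)
    if causal:
        mask = np.tril(mask)
    return block_grads(q, dy, lse, delta, k, v, scale, mask)


def ring_bwd(q, k, v, dy, scale, causal, num_workers):
    """Simulated ring: K/V shards stay put, Q-side shards rotate each step."""
    seq = q.shape[0]
    blk = seq // num_workers
    o, lse = forward(q, k, v, scale, causal)   # stand-ins for o_ref and lse_ref
    delta = np.sum(dy * o, axis=1)
    pos = np.arange(seq)

    def shard(x, r):
        return x[r * blk:(r + 1) * blk]

    dq, dk, dv = np.zeros_like(q), np.zeros_like(k), np.zeros_like(v)
    for step in range(num_workers):
        for r in range(num_workers):           # on hardware these run in parallel
            src = (r - step) % num_workers     # owner of the Q shard worker r holds
            mask = np.ones((blk, blk), dtype=bool)
            if causal:
                mask = shard(pos, src)[:, None] >= shard(pos, r)[None, :]
            gq, gk, gv = block_grads(shard(q, src), shard(dy, src), shard(lse, src),
                                     shard(delta, src), shard(k, r), shard(v, r),
                                     scale, mask)
            dq[src * blk:(src + 1) * blk] += gq   # dQ returns to the Q owner
            dk[r * blk:(r + 1) * blk] += gk       # dK and dV accumulate locally
            dv[r * blk:(r + 1) * blk] += gv
    return dq, dk, dv
```

Replacing `src` with `(r + step) % num_workers` gives the same totals, because the sums over all steps do not depend on which way the shards rotate.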
API Reference
Source code for this kernel API can be found at: ring_attention_bwd.py
ring_attention_spmd_bwd
- nkilib.experimental.attention.ring_attention_spmd_bwd(q_ref: nl.ndarray, k_ref: nl.ndarray, v_ref: nl.ndarray, o_ref: nl.ndarray, dy_ref: nl.ndarray, lse_ref: nl.ndarray, use_causal_mask: bool = False, mixed_precision: bool = True, softmax_scale: float = None, num_workers: int = 1, lnc_size: int = 1, replica_groups: tuple = None, striped_attention: bool = False)
Ring Attention Backward SPMD kernel.
- Parameters:
  q_ref (nl.ndarray) – [B, N, D, S], Query tensor in HBM.
  k_ref (nl.ndarray) – [B, N, D, S], Key tensor in HBM.
  v_ref (nl.ndarray) – [B, N, D, S], Value tensor in HBM.
  o_ref (nl.ndarray) – [B, N, D, S], Forward output tensor in HBM.
  dy_ref (nl.ndarray) – [B, N, D, S], Upstream gradient tensor in HBM.
  lse_ref (nl.ndarray) – [B, N, 128, S//128], Log-sum-exp from the forward pass in HBM.
  use_causal_mask (bool) – Whether to apply causal masking. Default: False.
  mixed_precision (bool) – Whether to use mixed precision (fp32 accumulators). Default: True.
  softmax_scale (float) – Softmax scale factor. Default: None, which uses 1/sqrt(D).
  num_workers (int) – Number of workers in the ring. Default: 1.
  lnc_size (int) – LNC size (number of logical cores). Default: 1.
  replica_groups (tuple) – Replica groups for collective communication. Default: None.
  striped_attention (bool) – Whether to use the striped attention layout. Default: False.
- Returns:
  [B, N, D, S], Query gradient (dQ) in HBM (float32).
  [B, N, D, S], Key gradient (dK) in HBM (float32).
  [B, N, D, S], Value gradient (dV) in HBM (float32).
- Return type:
  nl.ndarray (one for each gradient listed above)
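Before calling the kernel from host code, it can help to check inputs against the documented shapes and flags. The helpers below are hypothetical (they are not part of nkilib) and use NumPy arrays as stand-ins for the HBM tensors. `to_kernel_layout` assumes the caller starts from a [B, S, N, D] layout, which is common in frameworks but is not something this page requires. Only the shape of the LSE tensor is checked, because the mapping of sequence positions onto its [128, S//128] layout is not described here.

```python
import math

import numpy as np


def to_kernel_layout(x_bsnd):
    """Hypothetical: [B, S, N, D] -> the documented [B, N, D, S] layout."""
    return np.ascontiguousarray(np.transpose(x_bsnd, (0, 2, 3, 1)))


def check_bwd_args(q, k, v, o, dy, lse, use_causal_mask=False,
                   softmax_scale=None, striped_attention=False):
    """Hypothetical check of the documented rules; returns the resolved scale."""
    b, n, d, s = q.shape
    for name, t in (("k_ref", k), ("v_ref", v), ("o_ref", o), ("dy_ref", dy)):
        if t.shape != q.shape:
            raise ValueError(f"{name} must match q_ref {q.shape}, got {t.shape}")
    if s % 128 != 0:
        raise ValueError(f"S={s} must be divisible by 128")
    if lse.shape != (b, n, 128, s // 128):
        raise ValueError(f"lse_ref must be {(b, n, 128, s // 128)}, got {lse.shape}")
    if striped_attention and not use_causal_mask:
        raise ValueError("striped_attention requires use_causal_mask=True")
    # softmax_scale=None falls back to 1/sqrt(D), as documented above.
    return 1.0 / math.sqrt(d) if softmax_scale is None else softmax_scale
```

For example, inputs with D = 64 and no explicit softmax_scale resolve to a scale of 0.125.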
Notes:
Sequence length S must be divisible by 128.
striped_attention requires use_causal_mask=True.
When B is not divisible by lnc_size, the last batch is handled with duplicate work.
Dimensions:
B: Batch size
N: Number of attention heads
D: Head dimension
S: Sequence length
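The striped layout comes from the Striped Attention scheme, which assigns tokens to workers round-robin instead of in contiguous chunks so that causal masking leaves every worker with a similar amount of work at each ring step. How this kernel expects striped inputs to be permuted is not specified on this page; the sketch below assumes token t lives on worker t % num_workers and counts the unmasked (query, key) pairs for every (Q shard, K/V shard) combination under both layouts. The function names are hypothetical.

```python
import numpy as np


def global_positions(worker, num_workers, seq_len, striped):
    """Global token positions held by `worker` under the assumed layouts."""
    if striped:
        return np.arange(worker, seq_len, num_workers)    # t % W == worker
    blk = seq_len // num_workers
    return np.arange(worker * blk, (worker + 1) * blk)    # contiguous chunk


def causal_work(num_workers, seq_len, striped):
    """work[q_owner, kv_owner] = number of causally allowed (query, key) pairs."""
    work = np.zeros((num_workers, num_workers), dtype=int)
    for qw in range(num_workers):
        qpos = global_positions(qw, num_workers, seq_len, striped)
        for kw in range(num_workers):
            kpos = global_positions(kw, num_workers, seq_len, striped)
            work[qw, kw] = np.sum(qpos[:, None] >= kpos[None, :])
    return work
```

With 4 workers and 16 tokens, the contiguous layout yields blocks of 0, 10, or 16 allowed pairs, so some workers idle while others do full blocks, whereas the striped layout yields only 6 or 10. Without a causal mask every block does the same amount of work under either layout, which is consistent with striped_attention requiring use_causal_mask=True.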