This document is relevant for: Trn2, Trn3

Ring Attention Bwd Kernel API Reference#

Ring Attention Backward SPMD kernel.

Computes gradients dQ, dK, dV for ring attention using collective permute operations to circulate Q, dY, LSE, and dy_o_sum across workers while keeping K, V local. Supports causal masking and striped attention.

Background#

The ring_attention_spmd_bwd kernel implements the backward pass of ring attention. Each worker keeps its K and V shards resident, while Q, dY, LSE, and dy_o_sum are passed from worker to worker with collective permute operations. Once every query-side shard has visited every worker, the gradients dQ, dK, and dV are complete.
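To make the data movement concrete, here is a minimal NumPy sketch that simulates the ring on a single head and compares the result with a dense reference. It is not the kernel's implementation. It assumes a row-major [S, D] layout (the kernel itself takes [B, N, D, S]), no causal mask, dy_o_sum defined as the row-wise sum of dY * O, and a dQ accumulator that travels together with its Q shard. The helper names dense_attention_bwd and ring_attention_bwd_sim and the ring direction are hypothetical.

```python
import numpy as np

def dense_attention_bwd(q, k, v, dy, scale):
    """Reference backward pass for one head, all tensors shaped [S, D]."""
    s = scale * q @ k.T
    m = s.max(axis=1, keepdims=True)
    lse = m + np.log(np.exp(s - m).sum(axis=1, keepdims=True))  # [S, 1]
    p = np.exp(s - lse)                      # softmax probabilities
    o = p @ v                                # forward output
    dv = p.T @ dy
    dy_o_sum = (dy * o).sum(axis=1, keepdims=True)
    ds = p * (dy @ v.T - dy_o_sum)           # gradient w.r.t. scaled scores
    return o, lse, scale * ds @ k, scale * ds.T @ q, dv

def ring_attention_bwd_sim(q, k, v, dy, o, lse, scale, num_workers):
    """Simulate the ring: K/V blocks stay put, query-side blocks rotate."""
    blk = q.shape[0] // num_workers
    block = [slice(w * blk, (w + 1) * blk) for w in range(num_workers)]
    dy_o_sum = (dy * o).sum(axis=1, keepdims=True)
    dq, dk, dv = np.zeros_like(q), np.zeros_like(k), np.zeros_like(v)
    for step in range(num_workers):
        for w in range(num_workers):
            # Query-side block visiting worker w at this step (assumed direction).
            qs, ks = block[(w - step) % num_workers], block[w]
            # LSE from the forward pass lets each worker rebuild its P tile locally.
            p = np.exp(scale * q[qs] @ k[ks].T - lse[qs])
            ds = p * (dy[qs] @ v[ks].T - dy_o_sum[qs])
            dv[ks] += p.T @ dy[qs]           # accumulates on the K/V owner
            dk[ks] += scale * ds.T @ q[qs]   # accumulates on the K/V owner
            dq[qs] += scale * ds @ k[ks]     # accumulates with the travelling Q block
    return dq, dk, dv

rng = np.random.default_rng(0)
S, D, W = 8, 4, 4
q, k, v, dy = (rng.standard_normal((S, D)) for _ in range(4))
scale = 1.0 / np.sqrt(D)
o, lse, dq_ref, dk_ref, dv_ref = dense_attention_bwd(q, k, v, dy, scale)
dq, dk, dv = ring_attention_bwd_sim(q, k, v, dy, o, lse, scale, W)
print("max |dQ error|:", np.abs(dq - dq_ref).max())
```

Because LSE and dy_o_sum are per-query-row quantities, they have to travel with Q and dY. With them in hand, a worker can compute its tile of the softmax gradient without seeing any other worker's K or V.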

API Reference#

Source code for this kernel API can be found at: ring_attention_bwd.py

ring_attention_spmd_bwd#

nkilib.experimental.attention.ring_attention_spmd_bwd(q_ref: nl.ndarray, k_ref: nl.ndarray, v_ref: nl.ndarray, o_ref: nl.ndarray, dy_ref: nl.ndarray, lse_ref: nl.ndarray, use_causal_mask: bool = False, mixed_precision: bool = True, softmax_scale: float = None, num_workers: int = 1, lnc_size: int = 1, replica_groups: tuple = None, striped_attention: bool = False)#

Ring Attention Backward SPMD kernel.

Parameters:
  • q_ref (nl.ndarray) – [B, N, D, S], Query tensor in HBM.

  • k_ref (nl.ndarray) – [B, N, D, S], Key tensor in HBM.

  • v_ref (nl.ndarray) – [B, N, D, S], Value tensor in HBM.

  • o_ref (nl.ndarray) – [B, N, D, S], Forward output tensor in HBM.

  • dy_ref (nl.ndarray) – [B, N, D, S], Upstream gradient tensor in HBM.

  • lse_ref (nl.ndarray) – [B, N, 128, S//128], Log-sum-exp from forward pass in HBM.

  • use_causal_mask (bool) – Whether to apply causal masking. Default: False.

  • mixed_precision (bool) – Whether to use mixed precision (fp32 accumulators). Default: True.

  • softmax_scale (float) – Softmax scale factor. Default: None, in which case 1/sqrt(D) is used.

  • num_workers (int) – Number of workers in the ring. Default: 1.

  • lnc_size (int) – LNC size (number of logical cores). Default: 1.

  • replica_groups (tuple) – Replica groups for collective communication. Default: None.

  • striped_attention (bool) – Whether to use striped attention layout. Default: False.
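The shape conventions above can be collected into a small pre-flight check. The sketch below is a hypothetical helper, not part of nkilib. It assumes that leaving softmax_scale as None resolves to 1/sqrt(D), and it enforces the two documented constraints: S must be divisible by 128, and striped_attention requires use_causal_mask=True.

```python
import math

def check_ring_bwd_args(B, N, D, S, use_causal_mask=False,
                        striped_attention=False, softmax_scale=None):
    """Hypothetical helper: validate arguments and report expected shapes."""
    if S % 128 != 0:
        raise ValueError(f"S={S} must be divisible by 128")
    if striped_attention and not use_causal_mask:
        raise ValueError("striped_attention requires use_causal_mask=True")
    # Assumption: a None scale means the documented default 1/sqrt(D).
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(D)
    return {
        "q_k_v_o_dy_shape": (B, N, D, S),     # shared by q, k, v, o, dy
        "lse_shape": (B, N, 128, S // 128),   # tiled log-sum-exp layout
        "softmax_scale": scale,
    }

info = check_ring_bwd_args(B=2, N=8, D=128, S=4096, use_causal_mask=True)
print(info)
```

Running a check like this on the host before launching the kernel surfaces shape mistakes as readable Python errors.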

Returns:

  • dQ (nl.ndarray) – [B, N, D, S], Query gradient in HBM (float32).

  • dK (nl.ndarray) – [B, N, D, S], Key gradient in HBM (float32).

  • dV (nl.ndarray) – [B, N, D, S], Value gradient in HBM (float32).

Notes:

  • Sequence length S must be divisible by 128.

  • striped_attention requires use_causal_mask=True.

  • When B is not divisible by lnc_size, the last batch is handled with duplicate work.
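The link between striped attention and causal masking can be illustrated with a small counting exercise. The sketch below assumes the layout commonly described for striped attention, in which worker w owns tokens w, w + num_workers, w + 2 * num_workers, and so on. This page does not say whether the kernel uses exactly this assignment. The function name and the ring direction are hypothetical. For one ring step, it counts how many unmasked (query, key) pairs each worker has to process.

```python
def causal_pairs_per_worker(seq_len, num_workers, step, striped):
    """Count unmasked (query, key) pairs per worker for one ring step."""
    blk = seq_len // num_workers

    def tokens(w):
        if striped:
            # Assumed striped layout: round-robin token assignment.
            return [w + num_workers * m for m in range(blk)]
        # Contiguous layout: one consecutive chunk per worker.
        return list(range(w * blk, (w + 1) * blk))

    counts = []
    for w in range(num_workers):
        src = (w - step) % num_workers   # owner of the visiting Q block
        # Causal mask: query i may attend to key j only when j <= i.
        counts.append(sum(1 for i in tokens(src) for j in tokens(w) if j <= i))
    return counts

for step in range(4):
    print(step,
          causal_pairs_per_worker(8, 4, step, striped=False),
          causal_pairs_per_worker(8, 4, step, striped=True))
```

With contiguous blocks, some workers see a fully masked tile in a given step while others see a full one. With striping, every worker gets roughly half a tile in each step. Without a causal mask every tile is full in both layouts, so striping changes nothing, which is consistent with the requirement above.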

Dimensions:

  • B: Batch size

  • N: Number of attention heads

  • D: Head dimension

  • S: Sequence length
