This document is relevant for: Inf2, Trn1, Trn2

nki.kernels.flash_attn_bwd

nki.kernels.flash_attn_bwd(q_ref, k_ref, v_ref, o_ref, dy_ref, lse_ref, seed_ref, logit_bias_ref=None, use_causal_mask=False, mixed_precision=False, dropout_p=0.0, softmax_scale=None)

Flash attention backward kernel. Computes the gradients dQ, dK, and dV of the attention output with respect to the query, key, and value inputs.

IO tensor layouts:
  • q_ref: shape (bs, nheads, head_size, seq)

  • k_ref: shape (bs, nheads, head_size, seq)

  • v_ref: shape (bs, nheads, head_size, seq)

  • o_ref: shape (bs, nheads, head_size, seq)

  • dy_ref: shape (bs, nheads, head_size, seq)

  • lse_ref: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax)

  • seed_ref: shape (1,)

  • logit_bias_ref: shape (bs, nheads, seq_q, seq_k)

  • out_dq_ref: shape (bs, nheads, head_size, seq)

  • out_dk_ref: shape (bs, nheads, head_size, seq)

  • out_dv_ref: shape (bs, nheads, head_size, seq)
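To make the layout constraints concrete, the host-side buffers can be allocated as in the sketch below (NumPy is used for illustration; the dtype choices and `nl.tile_size.pmax == 128` are assumptions, not part of the kernel's signature):

```python
import numpy as np

# Illustrative shapes only; nl.tile_size.pmax is assumed to be 128 here.
bs, nheads, head_size, seq = 1, 4, 128, 4096
pmax = 128

q_ref  = np.zeros((bs, nheads, head_size, seq), dtype=np.float32)
k_ref  = np.zeros_like(q_ref)
v_ref  = np.zeros_like(q_ref)
o_ref  = np.zeros_like(q_ref)
dy_ref = np.zeros_like(q_ref)

# lse is stored tiled along the sequence dimension: pmax rows per tile,
# seq // pmax tiles, so seq must be a multiple of pmax.
lse_ref  = np.zeros((bs, nheads, pmax, seq // pmax), dtype=np.float32)
seed_ref = np.zeros((1,), dtype=np.int32)

assert lse_ref.shape[2] * lse_ref.shape[3] == seq
```

Note that q, k, v, and the outputs all place head_size on the partition dimension and seq on the free dimension, which is why the scores are formed as Q^T@K rather than Q@K^T.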

Detailed steps:
  1. D = rowsum(dO ◦ O) (pointwise multiply)

  2. Recompute score = softmax(scale * (Q^T@K) + logit_bias)

    2.1. Compute Q^T@K

    2.2. Scale the QK score

    2.3. Apply the causal mask and add logit_bias

    2.4. Compute the softmax

  3. Compute the gradients of y = score @ V with respect to the loss

  4. Compute the gradients of y = softmax(x)

  5. Compute the gradients of Q^T@K

    5.1. Compute dQ

    5.2. Compute dK
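The steps above can be sketched as a dense (non-tiled) NumPy reference for a single (batch, head) pair. This is illustrative only, not the kernel's tiled implementation: the function name is hypothetical, dropout, causal mask, and logit_bias are omitted, and the default scale of 1/sqrt(head_size) is an assumption.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attn_bwd_reference(q, k, v, o, do, softmax_scale=None):
    """Dense backward reference for one (batch, head).

    All inputs use the documented layout (head_size, seq).
    """
    d, n = q.shape
    scale = softmax_scale if softmax_scale is not None else 1.0 / np.sqrt(d)

    # Step 1: D = rowsum(dO ∘ O), one scalar per query position.
    D = np.sum(do * o, axis=0)                    # (seq,)

    # Step 2: recompute score = softmax(scale * Q^T @ K).
    s = scale * (q.T @ k)                         # (seq_q, seq_k)
    p = softmax(s, axis=-1)

    # Step 3: gradients of y = score @ V.
    dv = do @ p                                   # (head_size, seq)
    dp = do.T @ v                                 # (seq_q, seq_k)

    # Step 4: gradient through the softmax; D_i equals sum_j p[i,j]*dp[i,j].
    ds = p * (dp - D[:, None])                    # (seq_q, seq_k)

    # Step 5: gradients of scale * Q^T @ K.
    dq = scale * (k @ ds.T)                       # 5.1: (head_size, seq)
    dk = scale * (q @ ds)                         # 5.2: (head_size, seq)
    return dq, dk, dv
```

The identity used in step 4, rowsum(dO ∘ O) = rowsum(P ∘ dP), is what lets the kernel compute D once in step 1 instead of materializing dP row sums during the softmax backward pass.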
