This document is relevant for: Inf2, Trn1, Trn1n

nki.kernels.flash_attn_bwd

nki.kernels.flash_attn_bwd(q_ref, k_ref, v_ref, o_ref, dy_ref, lse_ref, seed_ref, out_dq_ref, out_dk_ref, out_dv_ref, use_causal_mask=False, mixed_precision=False, dropout_p=0.0, softmax_scale=None)

Flash attention backward kernel. Computes the gradients dQ, dK, and dV with respect to the loss, given the forward inputs, the forward output O, its upstream gradient dY, and the log-sum-exp statistics saved by the forward pass.

IO tensor layouts (a host-side buffer sketch follows this list):
  • q_ref: shape (bs, nheads, head_size, seq)

  • k_ref: shape (bs, nheads, head_size, seq)

  • v_ref: shape (bs, nheads, head_size, seq)

  • o_ref: shape (bs, nheads, head_size, seq)

  • dy_ref: shape (bs, nheads, head_size, seq)

  • lse_ref: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax)

  • seed_ref: shape (1,)

  • out_dq_ref: shape (bs, nheads, head_size, seq)

  • out_dk_ref: shape (bs, nheads, head_size, seq)

  • out_dv_ref: shape (bs, nheads, head_size, seq)
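
These layouts place the head dimension before the sequence dimension, i.e. they are transposed relative to the common (bs, nheads, seq, head_size) convention, and lse is stored tiled: the sequence is split into seq // nl.tile_size.pmax tiles of nl.tile_size.pmax rows each (nl.tile_size.pmax is 128 at the time of writing). Below is a minimal NumPy sketch of matching host-side buffers; the sizes are hypothetical, and how the kernel is actually launched (e.g. via nki.baremetal or nki.jit) depends on your Neuron release:

    import numpy as np

    bs, nheads, head_size, seq = 1, 4, 128, 4096   # hypothetical sizes
    pmax = 128                                     # nl.tile_size.pmax

    # Inputs: forward activations and the upstream gradient,
    # all laid out as (bs, nheads, head_size, seq).
    q  = np.random.rand(bs, nheads, head_size, seq).astype(np.float32)
    k  = np.random.rand(bs, nheads, head_size, seq).astype(np.float32)
    v  = np.random.rand(bs, nheads, head_size, seq).astype(np.float32)
    o  = np.random.rand(bs, nheads, head_size, seq).astype(np.float32)  # forward output
    dy = np.random.rand(bs, nheads, head_size, seq).astype(np.float32)  # upstream gradient

    # Log-sum-exp statistics saved by the forward pass, tiled along seq.
    lse = np.zeros((bs, nheads, pmax, seq // pmax), dtype=np.float32)

    # Dropout seed; the dtype here is an assumption, check the kernel source.
    seed = np.zeros((1,), dtype=np.int32)

    # Output gradients share the input layout.
    out_dq = np.empty_like(q)
    out_dk = np.empty_like(k)
    out_dv = np.empty_like(v)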

Detailed steps (a NumPy reference sketch follows this list):
  1. D = rowsum(dO ◦ O) (pointwise multiply)

  2. Recompute softmax(Q^T@K):

     2.1 Q^T@K
     2.2 Scale the QK score
     2.3 Apply the causal mask
     2.4 Softmax

  3. Compute the gradients of y = score @ V with respect to the loss

  4. Compute the gradients of y = softmax(x)

  5. Compute the gradients of Q^T@K:

     5.1 Compute dQ
     5.2 Compute dK
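
These steps are the standard flash-attention backward recurrence. The sketch below spells them out in plain NumPy for a single (batch, head) slice; flash_attn_bwd_reference is a hypothetical name used for illustration, dropout (dropout_p, seed_ref) is ignored, and the softmax is recomputed from scratch where the kernel would instead rebuild it from the saved lse_ref statistics:

    import numpy as np

    def flash_attn_bwd_reference(q, k, v, o, dy, softmax_scale=None, causal=False):
        # q, k, v, o, dy: (head_size, seq), matching the kernel layout;
        # transposed to (seq, head_size) internally for readability.
        d, seq = q.shape
        scale = softmax_scale if softmax_scale is not None else 1.0 / np.sqrt(d)
        qt, kt, vt, ot, dyt = q.T, k.T, v.T, o.T, dy.T        # each (seq, d)

        # Step 2: recompute the scores and softmax probabilities.
        s = scale * (qt @ kt.T)                               # 2.1 + 2.2: scaled Q^T@K
        if causal:                                            # 2.3: mask future positions
            s = np.where(np.tril(np.ones((seq, seq), dtype=bool)), s, -np.inf)
        p = np.exp(s - s.max(axis=-1, keepdims=True))
        p /= p.sum(axis=-1, keepdims=True)                    # 2.4: row-wise softmax

        # Step 1: D = rowsum(dO ◦ O), one scalar per query position.
        D = np.sum(dyt * ot, axis=-1, keepdims=True)          # (seq, 1)

        # Step 3: gradients of y = score @ V.
        dv = p.T @ dyt                                        # dV = P^T @ dO
        dp = dyt @ vt.T                                       # dP = dO @ V^T

        # Step 4: gradient of the softmax, dS = P ◦ (dP - D).
        ds = p * (dp - D)

        # Step 5: gradients of the scaled Q^T@K.
        dq = scale * (ds @ kt)                                # 5.1
        dk = scale * (ds.T @ qt)                              # 5.2
        return dq.T, dk.T, dv.T                               # back to (head_size, seq)

The sketch also shows why the forward pass only needs to save O and the per-row log-sum-exp rather than the full probability matrix: in the softmax gradient, the term sum_k(dP ◦ P) per row collapses to D = rowsum(dO ◦ O).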
