This document is relevant for: Inf2, Trn1, Trn2
nki.kernels.flash_attn_bwd
- nki.kernels.flash_attn_bwd(q_ref, k_ref, v_ref, o_ref, dy_ref, lse_ref, seed_ref, logit_bias_ref=None, use_causal_mask=False, mixed_precision=False, dropout_p=0.0, softmax_scale=None)
Flash attention backward kernel. Computes the gradients of the loss with respect to Q, K, and V (out_dq_ref, out_dk_ref, out_dv_ref).
- IO tensor layouts:
q_ref: shape (bs, nheads, head_size, seq)
k_ref: shape (bs, nheads, head_size, seq)
v_ref: shape (bs, nheads, head_size, seq)
o_ref: shape (bs, nheads, head_size, seq)
dy_ref: shape (bs, nheads, head_size, seq)
lse_ref: shape (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax)
seed_ref: shape (1,)
logit_bias_ref: shape (bs, nheads, seq_q, seq_k)
out_dq_ref: shape (bs, nheads, head_size, seq)
out_dk_ref: shape (bs, nheads, head_size, seq)
out_dv_ref: shape (bs, nheads, head_size, seq)
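The layouts above can be sanity-checked with a small shape helper. This is a plain-Python sketch, not part of the NKI API: `flash_attn_bwd_shapes` and `lse_tile_index` are hypothetical names, `PMAX = 128` is an assumed value for `nl.tile_size.pmax`, and the mapping from a query position to a (partition, tile) pair in `lse_ref` is an assumption about how the log-sum-exp is tiled. The sketch also assumes `seq` is a multiple of `pmax`, since `lse_ref` is shaped with `seq // nl.tile_size.pmax`.

```python
PMAX = 128  # assumption: the value of nl.tile_size.pmax


def flash_attn_bwd_shapes(bs, nheads, head_size, seq, pmax=PMAX):
    """Hypothetical helper: documented shape of each flash_attn_bwd IO tensor."""
    if seq % pmax != 0:
        raise ValueError(f"seq={seq} must be a multiple of pmax={pmax}")
    qkv = (bs, nheads, head_size, seq)  # note: head_size precedes seq
    return {
        "q_ref": qkv, "k_ref": qkv, "v_ref": qkv, "o_ref": qkv, "dy_ref": qkv,
        "lse_ref": (bs, nheads, pmax, seq // pmax),
        "seed_ref": (1,),
        "logit_bias_ref": (bs, nheads, seq, seq),  # seq_q == seq_k == seq here
        "out_dq_ref": qkv, "out_dk_ref": qkv, "out_dv_ref": qkv,
    }


def lse_tile_index(query_pos, pmax=PMAX):
    """Assumed mapping of a query position to (partition, tile) in lse_ref."""
    return query_pos % pmax, query_pos // pmax
```

For example, with `bs=2, nheads=4, head_size=64, seq=512`, `lse_ref` would be `(2, 4, 128, 4)`: one column of 128 partitions per query tile.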
- Detailed steps:
1. D = rowsum(dO ◦ O) (pointwise multiply)
2. Recompute softmax(softmax_scale * Q^T@K + logit_bias)
   2.1 Q^T@K
   2.2 Scale the QK score
   2.3 Apply causal mask and add logit_bias
   2.4 softmax
3. Backpropagate through y = score @ V to get dV and the gradient of the loss with respect to score
4. Backpropagate through y = softmax(x): dx = y ◦ (dy - D), where D from step 1 equals rowsum(dy ◦ y)
5. Backpropagate through Q^T@K
   5.1 Compute dQ
   5.2 Compute dK
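To make the detailed steps concrete, here is a minimal pure-Python reference for a single (batch, head) pair, with comments keyed to the step numbers. It is a sketch under several assumptions, not the kernel's implementation: tensors are row-major `(seq, head_size)` rather than the kernel's `(head_size, seq)` layout, dropout is not modelled (as with `dropout_p=0.0`), and the softmax is recomputed from scratch instead of from `lse_ref`. The helper names `matmul`, `transpose`, `attn_forward`, and `attn_backward` are hypothetical.

```python
import math


def matmul(a, b):
    """Multiply nested-list matrices: (n x k) @ (k x m) -> (n x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def attn_forward(q, k, v, scale, bias=None, causal=False):
    """Row-major forward pass; q, k, v are (seq, head_size). Returns (O, P)."""
    s = matmul(q, transpose(k))                      # 2.1 Q @ K^T
    p = []
    for i, s_row in enumerate(s):
        row = []
        for j, x in enumerate(s_row):
            x *= scale                               # 2.2 scale the score
            if bias is not None:
                x += bias[i][j]                      # 2.3 add logit bias
            if causal and j > i:
                x = -math.inf                        # 2.3 causal mask
            row.append(x)
        mx = max(row)                                # 2.4 stable softmax
        e = [math.exp(x - mx) for x in row]
        z = sum(e)
        p.append([x / z for x in e])
    return matmul(p, v), p


def attn_backward(q, k, v, do, scale, bias=None, causal=False):
    """Return (dQ, dK, dV) given dO; dropout is not modelled."""
    o, p = attn_forward(q, k, v, scale, bias, causal)  # step 2: recompute P
    # Step 1: D_i = sum_c dO[i][c] * O[i][c], which equals rowsum(dP * P).
    d = [sum(x * y for x, y in zip(do_row, o_row))
         for do_row, o_row in zip(do, o)]
    # Step 3: backprop through O = P @ V.
    dv = matmul(transpose(p), do)                    # dV = P^T @ dO
    dp = matmul(do, transpose(v))                    # dP = dO @ V^T
    # Step 4: backprop through softmax: dS = P * (dP - D).
    ds = [[p_ij * (dp_ij - d_i) for p_ij, dp_ij in zip(p_row, dp_row)]
          for p_row, dp_row, d_i in zip(p, dp, d)]
    # Step 5: backprop through scale * Q @ K^T (the bias does not depend on Q or K).
    dq = [[scale * x for x in row] for row in matmul(ds, k)]             # 5.1
    dk = [[scale * x for x in row] for row in matmul(transpose(ds), q)]  # 5.2
    return dq, dk, dv
```

Computing D from O and dO (step 1) is what lets the softmax backward in step 4 avoid a second pass over dP: because O = P @ V and dP = dO @ V^T, rowsum(dP ◦ P) and rowsum(dO ◦ O) are the same quantity. A reference like this can be validated against central finite differences of loss = sum(dO ◦ O) before comparing it with kernel outputs.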