This document is relevant for: Inf2, Trn1, Trn1n
nki.kernels.flash_fwd
- nki.kernels.flash_fwd(q, k, v, seed, o, lse=None, softmax_scale=None, use_causal_mask=True, mixed_precision=True, dropout_p=0.0, config=None)
Flash Attention Forward kernel
- IO tensor layouts:
q: shape (bs, n_heads, d, seq_q)
k: shape (bs, nk_heads, d, seq_k)
v: shape (bs, nv_heads, d, seq_v) if config.should_transpose_v else (bs, nv_heads, seq_v, d)
seed: shape (1,)
o: shape (bs, n_heads, seq_q, d)
lse: shape (bs, n_heads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None
We use seq_q and seq_k just for clarity; this kernel requires seq_q == seq_k (see the shape sketch below)
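As a concrete illustration of these layouts, here is a minimal host-side sketch that allocates NumPy arrays with matching shapes. The sizes are hypothetical, should_transpose_v is assumed to be False (so v is laid out as (bs, nv_heads, seq_v, d)), and the seed dtype is an assumption.

```python
import numpy as np

# Hypothetical sizes; the kernel requires seq_q == seq_k.
bs, n_heads, d, seq = 1, 4, 128, 4096
kv_heads = n_heads  # plain MHA; GQA would use fewer kv heads

q = np.random.rand(bs, n_heads, d, seq).astype(np.float32)   # (bs, n_heads, d, seq_q)
k = np.random.rand(bs, kv_heads, d, seq).astype(np.float32)  # (bs, nk_heads, d, seq_k)
v = np.random.rand(bs, kv_heads, seq, d).astype(np.float32)  # (bs, nv_heads, seq_v, d)
seed = np.zeros(1, dtype=np.int32)                           # dropout seed, shape (1,); dtype assumed
o = np.zeros((bs, n_heads, seq, d), dtype=np.float32)        # output buffer, (bs, n_heads, seq_q, d)
```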
- IO tensor dtypes:
This kernel assumes all IO tensors have the same dtype
If mixed_precision is True, then all Tensor Engine operations will be performed in bfloat16 and accumulation will be performed in float32. Otherwise the intermediates will be in the same dtype as the inputs.
- Compile-time Constants:
softmax_scale: scaling factor for the softmax; if None, it defaults to 1.0/(d**0.5) (see the example below)
mixed_precision: flag to run non-matmul ops in fp32 precision; default is True. If False, the same precision as the input dtypes is used.
use_causal_mask: flag to enable causal masking; default is True
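For reference, the default softmax scale works out as follows (d = 128 is a hypothetical head dimension):

```python
d = 128                           # hypothetical head dimension
softmax_scale = 1.0 / (d ** 0.5)  # default when softmax_scale is None: 1/sqrt(128) ≈ 0.0884
```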
- config: Instance of dataclass nki.kernels.attention.FlashConfig with performance config parameters for flash attention, with default values (see the sketch below):
seq_tile_size: size of the kv tile used for the attention computation reduction, default=2048
training: bool to indicate training vs. inference, default=True
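A minimal sketch of constructing the config; the import path is assumed from the module name shown on this page, and the field names follow the description above (should_transpose_v, referenced in the v layout, is assumed to keep its default here).

```python
from nki.kernels.attention import FlashConfig  # import path assumed from the module name above

cfg = FlashConfig(
    seq_tile_size=2048,  # kv tile size for the attention computation reduction
    training=True,       # training mode: the kernel also produces lse
)
```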
- Performance Notes:
For better performance, the kernel is tiled into tiles of size LARGE_TILE_SZ, and the Flash Attention math techniques are applied in units of LARGE_TILE_SZ. Sequence lengths that are not divisible by LARGE_TILE_SZ are not supported at the moment (see the check below).
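A small host-side sanity check, under the assumption that LARGE_TILE_SZ corresponds to config.seq_tile_size:

```python
seq_tile_size = 2048  # FlashConfig default; assumed here to equal LARGE_TILE_SZ
seq = 4096            # hypothetical sequence length (seq_q == seq_k)
assert seq % seq_tile_size == 0, "seq_q/seq_k must be divisible by the large tile size"
```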
- GQA support Notes:
the SPMD grid used to launch the kernel should be over kv_heads instead of n_heads
- Example usage:
- MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
usage: flash_fwd[b, h](q, k, v, …)
- GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
usage: flash_fwd[b, kv_h](q, k, v, …)
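Putting the layouts and the GQA usage line together, a hedged end-to-end sketch of a GQA launch is shown below. The sizes are hypothetical, the import path is assumed from the module names on this page, and the call must run through the NKI toolchain (e.g., compiled for a Neuron device); the grid indexing and argument order follow the signature above.

```python
import numpy as np
from nki.kernels.attention import flash_fwd, FlashConfig  # import path assumed

b, h, kv_h, d, s = 1, 8, 2, 128, 4096             # hypothetical GQA sizes (h is a multiple of kv_h)

dtype = np.float32                                # all IO tensors share one dtype
q = np.random.rand(b, h, d, s).astype(dtype)      # (b, h, d, s)
k = np.random.rand(b, kv_h, d, s).astype(dtype)   # (b, kv_h, d, s)
v = np.random.rand(b, kv_h, s, d).astype(dtype)   # (b, kv_h, s, d), should_transpose_v False
seed = np.zeros(1, dtype=np.int32)                # dropout seed; unused when dropout_p=0.0
o = np.zeros((b, h, s, d), dtype=dtype)           # output buffer

cfg = FlashConfig(seq_tile_size=2048, training=False)

# GQA: launch the SPMD grid over kv heads, as noted above.
flash_fwd[b, kv_h](q, k, v, seed, o,
                   softmax_scale=None,
                   use_causal_mask=True,
                   mixed_precision=True,
                   dropout_p=0.0,
                   config=cfg)
```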