This document is relevant for: Inf2, Trn1, Trn1n

nki.kernels.flash_fwd#

nki.kernels.flash_fwd(q, k, v, seed, o, lse=None, softmax_scale=None, use_causal_mask=True, mixed_precision=True, dropout_p=0.0, config=None)[source]#

Flash Attention Forward kernel

IO tensor layouts:
  • q: shape (bs, n_heads, d, seq_q)

  • k: shape (bs, nk_heads, d, seq_k)

  • v: shape (bs, nv_heads, d, seq_v) if config.should_transpose_v else (bs, nv_heads, seq_v, d)

  • seed: shape (1,)

  • o: shape (bs, n_heads, seq_q, d)

  • lse: shape (bs, n_heads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None

  • seq_q and seq_k are named separately only for clarity; this kernel requires seq_q == seq_k (see the shape sketch below this list)
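
For reference, a minimal host-side sketch of buffers matching the layouts above, assuming NumPy arrays, the non-transposed v layout (config.should_transpose_v is False), and nl.tile_size.pmax == 128; the seed dtype chosen here is an assumption.

  import numpy as np

  bs, n_heads, d, seq = 1, 4, 128, 4096   # illustrative sizes; seq_q == seq_k == seq
  dtype = np.float16                      # all IO tensors must share one dtype

  q = np.random.rand(bs, n_heads, d, seq).astype(dtype)      # (bs, n_heads, d, seq_q)
  k = np.random.rand(bs, n_heads, d, seq).astype(dtype)      # (bs, nk_heads, d, seq_k)
  v = np.random.rand(bs, n_heads, seq, d).astype(dtype)      # (bs, nv_heads, seq_v, d)
  seed = np.array([0], dtype=np.int32)                       # (1,) dropout seed; dtype assumed
  o = np.zeros((bs, n_heads, seq, d), dtype=dtype)           # (bs, n_heads, seq_q, d) output
  lse = np.zeros((bs, n_heads, 128, seq // 128), np.float32) # only needed when training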

IO tensor dtypes:
  • This kernel assumes all IO tensors have the same dtype

  • If mixed_precision is True, all Tensor Engine operations are performed in bfloat16 and accumulation is performed in float32. Otherwise, the intermediates are kept in the same dtype as the inputs.

Compile-time Constants:
  • softmax_scale: scaling factor for softmax; if None, defaults to 1.0/(d**0.5)

  • mixed_precision: flag to run non-matmul ops in fp32 precision; default is True. If False, the same precision as the input types is used.

  • use_causal_mask: flag to enable causal masking

  • config: Instance of the dataclass nki.kernels.attention.FlashConfig, holding performance config parameters for flash attention with default values (a construction sketch follows this list):

    seq_tile_size: size of the kv tile used for the attention computation reduction; default=2048

    training: bool to indicate training vs. inference; default=True
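
A hedged sketch of constructing the config, using only the fields documented on this page (seq_tile_size, training) plus should_transpose_v from the layout section; the import path mirrors the dataclass path given above and may differ in your installation.

  # Import path assumed from the dataclass path above; adjust for your environment.
  from nki.kernels.attention import FlashConfig

  config = FlashConfig(
      seq_tile_size=2048,        # kv tile size for the attention reduction (default)
      training=True,             # produce lse for the backward pass (default)
      should_transpose_v=False,  # v laid out as (bs, nv_heads, seq_v, d); see layouts above
  )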

Performance Notes:

For better performance, the kernel is tiled into blocks of size LARGE_TILE_SZ, and the flash attention math is applied per LARGE_TILE_SZ block. Sequence lengths that are not divisible by LARGE_TILE_SZ are not supported at the moment (a pre-flight check is sketched below).
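
A small pre-flight check, continuing from the sketches above and assuming LARGE_TILE_SZ corresponds to config.seq_tile_size (a reasonable reading of this note, not stated explicitly here):

  LARGE_TILE_SZ = config.seq_tile_size  # assumption: the tiling unit equals seq_tile_size

  seq_q, seq_k = q.shape[-1], k.shape[-1]
  assert seq_q == seq_k, "flash_fwd requires seq_q == seq_k"
  assert seq_k % LARGE_TILE_SZ == 0, (
      f"seq_k={seq_k} must be divisible by LARGE_TILE_SZ={LARGE_TILE_SZ}")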

GQA support Notes:

The SPMD grid used to launch the kernel should be sized over kv_heads instead of n_heads.

Example usage:
MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]

usage: flash_fwd[b, h](q, k, v, …)

GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]

usage: flash_fwd[b, kv_h](q, k, v, …)
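
Putting the pieces together, a hedged launch sketch: the positional arguments follow the signature at the top of this page, while the import path and the surrounding compile/run machinery (e.g. framework integration) are assumptions that depend on your setup.

  # Import path assumed from the kernel name above; adjust for your environment.
  from nki.kernels import flash_fwd

  # MHA: SPMD grid over (batch, n_heads); shapes as in the layout sketch above.
  flash_fwd[bs, n_heads](q, k, v, seed, o,
                         lse=lse,              # pass None for inference
                         softmax_scale=None,   # None -> 1.0 / (d ** 0.5)
                         use_causal_mask=True,
                         mixed_precision=True,
                         dropout_p=0.0,
                         config=config)

  # GQA: k and v carry kv_heads; size the grid over kv_heads, not n_heads.
  # flash_fwd[bs, kv_heads](q, k_gqa, v_gqa, seed, o, config=config)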

This document is relevant for: Inf2, Trn1, Trn1n