Source code for nki.kernels

import numpy as np
import ml_dtypes

def flash_attn_bwd(q_ref, k_ref, v_ref, o_ref, dy_ref, lse_ref, seed_ref,
                   out_dq_ref, out_dk_ref, out_dv_ref,
                   use_causal_mask=False, mixed_precision=False,
                   dropout_p=0.0, softmax_scale=None):
    r"""
    Flash attention backward kernel. Compute the backward gradients.

    IO tensor layouts:
      - q_ref: shape      (bs, nheads, head_size, seq)
      - k_ref: shape      (bs, nheads, head_size, seq)
      - v_ref: shape      (bs, nheads, head_size, seq)
      - o_ref: shape      (bs, nheads, head_size, seq)
      - dy_ref: shape     (bs, nheads, head_size, seq)
      - lse_ref: shape    (bs, nheads, nl.tile_size.pmax, seq // nl.tile_size.pmax)
      - seed_ref: shape   (1,)
      - out_dq_ref: shape (bs, nheads, head_size, seq)
      - out_dk_ref: shape (bs, nheads, head_size, seq)
      - out_dv_ref: shape (bs, nheads, head_size, seq)

    Detailed steps:
      1. D = rowsum(dO ◦ O) (pointwise multiply)

      2. Recompute softmax(Q^T@K)
        2.1 Q^T@K
        2.2 Scale the QK score
        2.3 Apply causal mask
        2.4 Softmax

      3. Compute the gradients of y = score @ V with respect to the loss

      4. Compute the gradients of y = softmax(x)

      5. Compute the gradients of Q^T@K
        5.1 Compute dQ
        5.2 Compute dK
    """
    ...
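
# Editor-added illustrative sketch: a NumPy reference for the "Detailed steps" documented
# above, for a single (head_size, seq) head slice with no dropout and no causal mask. This
# is not the kernel's tiled implementation; the helper name and the (seq, head_size) working
# orientation are assumptions made for readability.
def _reference_attn_bwd_numpy(q, k, v, o, dy, softmax_scale=None):
    """q, k, v, o, dy: (head_size, seq) slices of the (bs, nheads, head_size, seq) tensors."""
    d, s = q.shape
    scale = softmax_scale if softmax_scale is not None else 1.0 / (d ** 0.5)
    Q, K, V, O, dO = q.T, k.T, v.T, o.T, dy.T         # work in (seq, head_size) orientation

    # Step 2: recompute softmax of the scaled scores (no mask in this sketch)
    S = scale * (Q @ K.T)                             # (seq, seq)
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)

    # Step 1: D = rowsum(dO ◦ O)
    D = np.sum(dO * O, axis=-1, keepdims=True)        # (seq, 1)

    # Step 3: gradients of y = P @ V
    dV = P.T @ dO                                     # (seq, head_size)
    dP = dO @ V.T                                     # (seq, seq)

    # Step 4: gradients through the softmax
    dS = P * (dP - D)                                 # (seq, seq)

    # Step 5: gradients of the scaled Q^T@K score
    dQ = scale * (dS @ K)                             # 5.1
    dK = scale * (dS.T @ Q)                           # 5.2
    return dQ.T, dK.T, dV.T                           # back to the (head_size, seq) layout
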
def flash_fwd(q, k, v, seed, o, lse=None, softmax_scale=None,
              use_causal_mask=True, mixed_precision=True,
              dropout_p=0.0, config=None):
    r"""
    Flash Attention Forward kernel

    IO tensor layouts:
      - q: shape    (bs, n_heads, d, seq_q)
      - k: shape    (bs, nk_heads, d, seq_k)
      - v: shape    (bs, nv_heads, d, seq_v) if config.should_transpose_v else (bs, nv_heads, seq_v, d)
      - seed: shape (1,)
      - o: shape    (bs, n_heads, seq_q, d)
      - lse: shape  (bs, n_heads, nl.tile_size.pmax, seq // nl.tile_size.pmax) if training else None
      - We use seq_q and seq_k just for clarity; this kernel requires seq_q == seq_k

    IO tensor dtypes:
      - This kernel assumes all IO tensors have the same dtype
      - If mixed_precision is True, then all Tensor Engine operations will be performed in
        bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
        will be in the same type as the inputs.

    Compile-time constants:
      - softmax_scale: scaling for the softmax; if None, the default is `1.0/(d**0.5)`
      - mixed_precision: flag to run non-matmul ops in fp32 precision, default is `True`;
        if False, the same precision as the input types is used
      - use_causal_mask: flag to enable causal masking
      - config: instance of dataclass :class:`nki.kernels.attention.FlashConfig` with
        performance config parameters for flash attention, with default values
          seq_tile_size: `default=2048`, size of the kv tile for the attention computation reduction
          training: bool to indicate training vs inference, `default=True`

    Performance notes:
      For better performance, the kernel is tiled in units of size `LARGE_TILE_SZ`, and the
      flash attention math techniques are applied per `LARGE_TILE_SZ` unit. Sequence lengths
      that are not divisible by `LARGE_TILE_SZ` are not supported at the moment.

    GQA support notes:
      The SPMD launch grid for this kernel should be over kv_heads instead of n_heads.

    Example usage:
      MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
        usage: `flash_fwd[b, h](q, k, v, ...)`
      GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
        usage: `flash_fwd[b, kv_h](q, k, v, ...)`
    """
    ...
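
# Editor-added illustrative sketch: a NumPy reference for the forward math and the MHA
# layouts documented above (v laid out as (bs, n_heads, seq, d), no dropout, no causal mask).
# It does not reproduce the kernel's tiling, and it returns a flat per-row lse rather than
# the tiled (pmax, seq // pmax) layout; the helper name is an assumption.
def _reference_flash_fwd_numpy(q, k, v, softmax_scale=None):
    """q, k: (bs, n_heads, d, seq); v: (bs, n_heads, seq, d).
    Returns o: (bs, n_heads, seq, d) and lse: (bs, n_heads, seq)."""
    bs, n_heads, d, seq = q.shape
    scale = softmax_scale if softmax_scale is not None else 1.0 / (d ** 0.5)
    o = np.empty((bs, n_heads, seq, d), dtype=np.float32)
    lse = np.empty((bs, n_heads, seq), dtype=np.float32)
    for b in range(bs):
        for h in range(n_heads):
            S = scale * (q[b, h].T @ k[b, h])              # (seq_q, seq_k) scaled scores
            m = S.max(axis=-1, keepdims=True)
            P = np.exp(S - m)
            l = P.sum(axis=-1, keepdims=True)
            o[b, h] = (P / l) @ v[b, h]                    # softmax(Q^T K) @ V -> (seq_q, d)
            lse[b, h] = (m + np.log(l)).squeeze(-1)        # log-sum-exp per query row
    return o, lse
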
def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref,
                                           use_causal_mask=False, mixed_percision=True):
    r"""
    Fused self attention kernel for the small-head-size Stable Diffusion workload.

    Computes softmax(QK^T)V. Decoder models can optionally apply a causal mask.
    Does not include the QKV projection, output projection, dropout, residual
    connection, etc.

    This kernel is designed to be used for Stable Diffusion models where `n_heads`
    is smaller than or equal to 128. An assertion is thrown if `n_heads` does not
    satisfy the requirement.

    IO tensor layouts:
      - q_ref: shape   (bs, n_heads, seq_q)
      - k_ref: shape   (bs, seq_k, n_heads)
      - v_ref: shape   (bs, seq_v, n_heads)
      - out_ref: shape (bs, seq_q, n_heads)
      - We use seq_q and seq_k just for clarity; this kernel requires seq_q == seq_k

    IO tensor dtypes:
      - This kernel assumes all IO tensors have the same dtype
      - If mixed_percision is True, then all Tensor Engine operations will be performed in
        bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
        will be in the same type as the inputs.
    """
    ...
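
# Editor-added illustrative sketch: a NumPy reference for the computation documented above,
# softmax(QK^T)V with an optional causal mask, using the documented layouts. It ignores the
# mixed-precision option and is not the fused kernel implementation; the helper name is an
# assumption.
def _reference_sd_self_attn_numpy(q, k, v, use_causal_mask=False):
    """q: (bs, n_heads, seq_q); k: (bs, seq_k, n_heads); v: (bs, seq_v, n_heads).
    Returns out: (bs, seq_q, n_heads)."""
    bs, n_heads, seq_q = q.shape
    out = np.empty((bs, seq_q, v.shape[-1]), dtype=np.float32)
    for b in range(bs):
        S = q[b].T @ k[b].T                                # (seq_q, seq_k) scores
        if use_causal_mask:
            S = np.where(np.tril(np.ones_like(S, dtype=bool)), S, -np.inf)
        P = np.exp(S - S.max(axis=-1, keepdims=True))      # row-wise softmax
        P /= P.sum(axis=-1, keepdims=True)
        out[b] = P @ v[b]                                  # (seq_q, n_heads)
    return out
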
def resize_nearest_fixed_dma_kernel(data_tensor, out_tensor):
    r"""
    Resize the input image to the given size using the nearest interpolation mode.
    This kernel is designed to be used when the scaling factor is not an integer.

    Example:
      - Input height:  30, Input width:  20
      - Output height: 59, Output width: 38

    IO tensor layouts:
      - data_tensor: shape (in_b, in_h, in_w, in_c)
      - out_tensor:  shape (out_b, out_h, out_w, out_c)
      - b: batch, c: channel, h: height, w: width
      - This kernel requires in_b == out_b, as the input batch and output batch must be identical
      - This kernel requires in_c == out_c, as the input channel and output channel must be identical
    """
    ...
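
# Editor-added illustrative sketch: a NumPy reference for nearest-neighbor resizing of a
# (b, h, w, c) tensor with a non-integer scale, matching the 30x20 -> 59x38 example above.
# The index-rounding convention used here, floor(dst_index * in_size / out_size), is an
# assumption and may differ from the kernel's exact mapping; the helper name is also an
# assumption.
def _reference_resize_nearest_numpy(data_tensor, out_h, out_w):
    in_b, in_h, in_w, in_c = data_tensor.shape
    # Map every output coordinate to a source coordinate, then gather
    h_idx = np.floor(np.arange(out_h) * (in_h / out_h)).astype(np.int64)
    w_idx = np.floor(np.arange(out_w) * (in_w / out_w)).astype(np.int64)
    return data_tensor[:, h_idx[:, None], w_idx[None, :], :]   # (in_b, out_h, out_w, in_c)

# Example (shapes only): _reference_resize_nearest_numpy(np.zeros((1, 30, 20, 3)), 59, 38)
# returns an array of shape (1, 59, 38, 3).
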
def select_and_scatter_kernel(operand_tensor, source_tensor, out_tensor):
    r"""
    Implementation of a select-and-scatter kernel.

    It selects an element from each window of operand_tensor, and then scatters
    source_tensor to the indices of the selected positions to construct out_tensor
    with the same shape as operand_tensor.

    This kernel assumes that:
      - window dimensions:   (3, 3)
      - window strides:      (2, 2)
      - padding:             (1, 1)
      - init value:          0
      - select computation:  greater-than
      - scatter computation: add

    IO tensor layouts:
      - operand_tensor: shape (n, c, h, w)
      - source_tensor:  shape (n, c, src_h, src_w)
      - out_tensor:     shape (n, c, h, w)

    IO tensor dtypes:
      - This kernel assumes all IO tensors have the same dtype
    """
    ...
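
# Editor-added illustrative sketch: a NumPy reference for the select-and-scatter semantics
# documented above (3x3 windows, stride 2, padding 1, init value 0, greater-than select,
# add scatter). Padded positions are skipped during selection, which is an assumption and
# may differ from the kernel's exact handling of the border; the helper name is also an
# assumption.
def _reference_select_and_scatter_numpy(operand_tensor, source_tensor):
    n, c, h, w = operand_tensor.shape
    _, _, src_h, src_w = source_tensor.shape
    out = np.zeros_like(operand_tensor)                        # init value 0
    for b in range(n):
        for ch in range(c):
            for i in range(src_h):
                for j in range(src_w):
                    # 3x3 window with stride 2 and padding 1, clipped to valid coordinates
                    r0, r1 = max(i * 2 - 1, 0), min(i * 2 + 2, h)
                    c0, c1 = max(j * 2 - 1, 0), min(j * 2 + 2, w)
                    window = operand_tensor[b, ch, r0:r1, c0:c1]
                    # greater-than select: the first occurrence of the window maximum wins
                    r_off, c_off = np.unravel_index(np.argmax(window), window.shape)
                    # add scatter: accumulate the source value at the selected position
                    out[b, ch, r0 + r_off, c0 + c_off] += source_tensor[b, ch, i, j]
    return out
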