This document is relevant for: Trn2, Trn3
NKI Library Supported Kernel Reference
The NKI Library provides pre-built reference kernels that you can use directly in model development with the AWS Neuron SDK and NKI. Each kernel exposes the classes, functions, and default parameters you need to integrate it into your models.
The source code for these kernel APIs is available in the aws-neuron/nki-library repository.
Core Kernels
Normalization and Quantization Kernels
Performs optional RMS normalization followed by quantization to
QKV Projection Kernels
Performs Query-Key-Value projection with optional normalization and RoPE fusion.
Attention Kernels
Implements attention optimized for Context Encoding (prefill) use cases.
Segmented attention with block-based KV cache and prefix caching for decode.
Implements attention optimized for Token Generation (decode) use cases with small active sequence lengths.
KV-parallel segmented prefill attention kernel.
Rotary Position Embedding (RoPE) Kernels
Applies Rotary Position Embedding to input embeddings with flexible layout support.
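As a rough illustration of what RoPE computes, the plain-Python sketch below rotates one head vector for a given token position. It assumes the common "half-split" layout (dimension i is paired with dimension i + D/2) and a base of 10000; the kernel's supported layouts and parameters may differ, and `apply_rope` is a hypothetical name, not the NKI Library API.

```python
import math

def apply_rope(x, pos, base=10000.0):
    """Rotate head vector x (even length D) by position pos, half-split layout (assumed)."""
    d = len(x)
    if d % 2:
        raise ValueError("head dimension must be even")
    half = d // 2
    out = list(x)
    for i in range(half):
        # Pair (x[i], x[i + half]) is rotated by the angle pos * base**(-2i/D).
        theta = pos * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        x1, x2 = x[i], x[i + half]
        out[i] = x1 * c - x2 * s
        out[i + half] = x2 * c + x1 * s
    return out

q = [1.0, 0.0, 0.5, 2.0]
print(apply_rope(q, pos=3))
```

Because each pair is rotated rather than scaled, the vector norm is preserved, and the dot product of a rotated query and a rotated key depends only on the difference between their positions.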
Multi-Layer Perceptron (MLP) Kernels
Implements Multi-Layer Perceptron with optional normalization fusion and quantization support.
Output Projection Kernels
Computes output projection optimized for Context Encoding use cases.
Computes output projection optimized for Token Generation use cases.
Mixture of Experts (MoE) Kernels
Computes router logits, applies activation functions, and performs top-K selection for MoE models.
Implements Mixture of Experts MLP operations optimized for Context Encoding use cases.
Implements Mixture of Experts MLP operations optimized for Token Generation use cases.
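The router step described above (logits, activation, top-K) can be sketched for a single token in plain Python. This is an illustration of the semantics only: `router_topk` is a hypothetical name, and the choice of softmax or sigmoid as the activation is an assumption, since the supported activations are not listed here.

```python
import math

def router_topk(x, w, k, activation="softmax"):
    """Route one token: logits = x @ w, then an activation, then top-K experts.

    x: hidden vector of length H; w: H x E weights given as a list of H rows.
    Returns (expert_indices, scores), highest score first.
    """
    num_experts = len(w[0])
    logits = [sum(x[h] * w[h][e] for h in range(len(x))) for e in range(num_experts)]
    if activation == "softmax":
        m = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        scores = [v / total for v in exps]
    elif activation == "sigmoid":
        scores = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    else:
        raise ValueError(f"unsupported activation: {activation}")
    # Top-K selection: indices of the K largest scores.
    top = sorted(range(num_experts), key=lambda e: scores[e], reverse=True)[:k]
    return top, [scores[e] for e in top]

experts, weights = router_topk([1.0, 0.0], [[0.1, 2.0, -1.0], [0.3, 0.3, 0.3]], k=2)
print(experts, weights)
```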
Quantization Kernels
Static and row-wise dynamic FP8 quantization with pre-combined dequantization scale support.
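To make the difference between static and row-wise dynamic scaling concrete, here is a plain-Python sketch that targets FP8 E4M3 (largest finite value 448). The choice of E4M3, the saturating round-to-nearest-even behavior, and the function names are assumptions for illustration; how the kernel accepts or combines pre-combined dequantization scales is not modeled.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def round_to_e4m3(v):
    """Round a float to the nearest E4M3 value (ties to even), saturating at +/-448."""
    a = min(abs(v), E4M3_MAX)
    if a == 0.0:
        return 0.0
    # Exponent of a, clamped at the minimum normal exponent (-6) so that
    # subnormals use the fixed spacing 2**-9.
    exp = max(math.frexp(a)[1] - 1, -6)
    step = 2.0 ** (exp - 3)  # 3 mantissa bits
    return math.copysign(min(round(a / step) * step, E4M3_MAX), v)

def quantize_static(rows, scale):
    """Static quantization: one precomputed scale shared by every row."""
    return [[round_to_e4m3(v / scale) for v in row] for row in rows]

def quantize_rowwise_dynamic(rows):
    """Row-wise dynamic quantization: each row gets scale = amax / 448.

    Returns the quantized rows and the per-row dequantization scales.
    """
    q_rows, dequant_scales = [], []
    for row in rows:
        amax = max(abs(v) for v in row)
        scale = amax / E4M3_MAX if amax > 0 else 1.0
        q_rows.append([round_to_e4m3(v / scale) for v in row])
        dequant_scales.append(scale)
    return q_rows, dequant_scales

q, s = quantize_rowwise_dynamic([[1.0, -2.0, 0.5]])
print(q, s)
```

Multiplying each quantized row by its dequantization scale recovers an approximation of the input; the dynamic variant uses the full FP8 range for every row at the cost of computing `amax` at run time.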
Cumulative Sum Kernels
Computes cumulative sum along the last dimension with optimized tiling.
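One common way to tile a cumulative sum is sketched below in plain Python: each tile computes a local prefix sum independently, and a scan over the tile totals supplies the offset that each tile adds. This illustrates the general technique under that assumption; it is not necessarily the tiling scheme the kernel uses, and the function names are hypothetical.

```python
from itertools import accumulate

def tiled_cumsum(row, tile_size):
    """Cumulative sum of one row, computed tile by tile."""
    tiles = [row[i:i + tile_size] for i in range(0, len(row), tile_size)]
    # Pass 1: local prefix sums inside each tile (tiles are independent).
    local = [list(accumulate(t)) for t in tiles]
    # Pass 2: an exclusive scan over tile totals gives each tile's offset.
    offsets = [0] + list(accumulate(t[-1] for t in local))[:-1]
    return [v + off for t, off in zip(local, offsets) for v in t]

def cumsum_last_dim(x, tile_size=4):
    """Apply the tiled cumulative sum to every row of a 2D list."""
    return [tiled_cumsum(row, tile_size) for row in x]

print(cumsum_last_dim([[1, 2, 3, 4, 5, 6, 7]], tile_size=3))
# [[1, 3, 6, 10, 15, 21, 28]]
```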
Core Subkernels
Finds indices of nonzero elements along the T dimension using GpSimd
Experimental Kernels
Note
Experimental kernels are under active development and their APIs may change in future releases.
Attention Kernels
Fused attention block for Token Generation that keeps all intermediate tensors in SBUF to minimize HBM traffic.
Ring attention forward pass for context parallelism across multiple workers.
Ring attention backward pass SPMD kernel for context parallelism.
Transformer Kernels
Multi-layer transformer forward pass megakernel for token generation.
Convolution Kernels
1D convolution using the Tensor Engine with a replication strategy.
3D convolution using the Tensor Engine with a K-replication strategy and W-contiguous tiling.
Implements depthwise 1D convolution using the implicit GEMM algorithm.
Collective Communication Kernels
Ring-based all-gather for TRN2 with double-buffered collective permute.
Fused all-gather and matrix multiplication for TRN2.
SBUF-to-SBUF all-gather with variants for small and large tensors.
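The data movement in a ring-based all-gather can be simulated in plain Python: with N ranks, each of the N - 1 steps is one collective permute in which every rank forwards the chunk it received in the previous step to its right-hand neighbor. This sketch models only the schedule, not buffers in SBUF or HBM; `ring_all_gather` is a hypothetical name. Reading all sends before writing any receives stands in for keeping separate send and receive buffers, which is the property double buffering provides.

```python
def ring_all_gather(chunks):
    """Simulate a ring all-gather: rank r starts with chunks[r] and ends with all of them."""
    n = len(chunks)
    gathered = [[None] * n for _ in range(n)]  # gathered[rank][chunk_index]
    for r in range(n):
        gathered[r][r] = chunks[r]
    for step in range(n - 1):
        # One collective permute: rank r forwards chunk (r - step) % n to rank (r + 1) % n.
        # All sends are read before any receive is written (separate send/recv buffers).
        sends = []
        for r in range(n):
            idx = (r - step) % n
            sends.append(((r + 1) % n, idx, gathered[r][idx]))
        for dst, idx, payload in sends:
            gathered[dst][idx] = payload
    return gathered

for rank, buf in enumerate(ring_all_gather(["a", "b", "c", "d"])):
    print(rank, buf)
```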
Foreach Kernels
Suite of fused elementwise operations (add, sub, mul, div, addcdiv, addcmul, lerp, sqrt) with SPMD tiling.
L1, L2, and Linf norm computation kernels with SPMD parallelization.
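The foreach kernels apply one operation across a list of tensors. The plain-Python sketch below shows the per-tensor semantics of the norm computation and of one elementwise operation (lerp). The function names are hypothetical, tensors are modeled as flat lists, and the SPMD tiling is not modeled.

```python
import math

def foreach_norm(tensors, p=2):
    """Return one norm per tensor (each a flat list); p is 1, 2, or math.inf."""
    norms = []
    for t in tensors:
        if p == 1:
            norms.append(sum(abs(v) for v in t))
        elif p == 2:
            norms.append(math.sqrt(sum(v * v for v in t)))
        elif p == math.inf:
            norms.append(max((abs(v) for v in t), default=0.0))
        else:
            raise ValueError(f"unsupported norm order: {p}")
    return norms

def foreach_lerp(starts, ends, weight):
    """Elementwise start + weight * (end - start), applied to each tensor pair."""
    return [[s + weight * (e - s) for s, e in zip(st, en)] for st, en in zip(starts, ends)]

grads = [[3.0, -4.0], [1.0, 1.0, 1.0]]
print(foreach_norm(grads, 1), foreach_norm(grads, 2), foreach_norm(grads, math.inf))
```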
Matmul and MLP MXFP8 Kernels
Generic matrix multiplication with MXFP8 quantization, supporting BF16 and pre-quantized inputs.
MXFP8 SwiGLU MLP forward pass with optional activation checkpointing.
MXFP8 SwiGLU MLP backward pass with 4-phase gradient computation.
MoE Kernels
Wrapper that bitcasts unsigned integer weights to the MX x4 dtype for the MoE block.
Optimizer Kernels
Fused Adam (L2 regularization) and AdamW (decoupled weight decay) optimizer kernels.
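The only difference between the two variants is where weight decay enters the update. The plain-Python sketch below follows the widely used PyTorch-style formulation: Adam adds `weight_decay * param` to the gradient, while AdamW shrinks the parameter directly before the moment-based step. `adam_step` is a hypothetical name, and the kernel's exact argument list and fusion strategy are not reflected here.

```python
import math

def adam_step(params, grads, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999,
              eps=1e-8, weight_decay=0.0, decoupled=False):
    """Update params, m, and v in place. step counts from 1.

    decoupled=False: Adam with L2 regularization (decay folded into the gradient).
    decoupled=True:  AdamW (decay applied directly to the weights).
    """
    bias1 = 1.0 - beta1 ** step
    bias2 = 1.0 - beta2 ** step
    for i, (p, g) in enumerate(zip(params, grads)):
        if decoupled:
            p = p * (1.0 - lr * weight_decay)
        else:
            g = g + weight_decay * p
        m[i] = beta1 * m[i] + (1.0 - beta1) * g       # first moment
        v[i] = beta2 * v[i] + (1.0 - beta2) * g * g   # second moment
        m_hat = m[i] / bias1                           # bias corrections
        v_hat = v[i] / bias2
        params[i] = p - lr * m_hat / (math.sqrt(v_hat) + eps)

adam_params, adamw_params = [1.0], [1.0]
adam_step(adam_params, [0.5], [0.0], [0.0], step=1, lr=0.1, weight_decay=0.1)
adam_step(adamw_params, [0.5], [0.0], [0.0], step=1, lr=0.1, weight_decay=0.1, decoupled=True)
print(adam_params, adamw_params)
```

With L2 regularization the decay term is normalized by the adaptive denominator like any other gradient component, whereas AdamW's decay stays proportional to the learning rate, which is the motivation for decoupling it.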
Padding Kernels
Multi-mode tensor padding (constant, replicate, reflect, circular) following PyTorch semantics.
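The four padding modes can be illustrated on a single 1D row in plain Python. The sketch follows PyTorch's conventions as commonly documented: reflect mirrors without repeating the edge element, and circular wraps around. `pad_last_dim` is a hypothetical name, and only the last dimension is padded here.

```python
def pad_last_dim(x, left, right, mode="constant", value=0.0):
    """Pad a 1D list on both ends following PyTorch-style padding modes."""
    n = len(x)
    if mode not in ("constant", "replicate", "reflect", "circular"):
        raise ValueError(f"unknown padding mode: {mode}")
    if mode == "reflect" and (left >= n or right >= n):
        raise ValueError("reflect padding must be smaller than the input length")
    if mode == "circular" and (left > n or right > n):
        raise ValueError("circular padding must not exceed the input length")
    out = []
    for i in range(-left, n + right):  # i is the source position relative to x[0]
        if 0 <= i < n:
            out.append(x[i])
        elif mode == "constant":
            out.append(value)
        elif mode == "replicate":
            out.append(x[0] if i < 0 else x[n - 1])              # repeat the edge element
        elif mode == "reflect":
            out.append(x[-i] if i < 0 else x[2 * (n - 1) - i])  # mirror, edge not repeated
        else:  # circular
            out.append(x[i % n])                                  # wrap around
    return out

for mode in ("constant", "replicate", "reflect", "circular"):
    print(mode, pad_last_dim([1, 2, 3, 4], 2, 2, mode))
```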
Quantization Kernels
Block-wise BF16-to-MXFP8 quantization kernel with packed scale support.
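Block-wise MXFP8 quantization, as defined in the OCP Microscaling (MX) specification, gives each block of 32 elements one shared power-of-two scale, stored as an 8-bit E8M0 exponent, and quantizes the scaled elements to FP8. The plain-Python sketch below assumes E4M3 elements and the spec's shared-exponent rule, floor(log2(amax)) minus the element format's maximum exponent. The kernel's packed scale layout is not modeled, and the function names are hypothetical.

```python
import math

BLOCK = 32         # elements sharing one scale (OCP MX block size)
E4M3_MAX = 448.0   # largest finite E4M3 value
E4M3_EMAX = 8      # exponent of that value: 448 = 1.75 * 2**8

def round_to_e4m3(v):
    """Nearest E4M3 value (Python's round is ties-to-even), saturating at +/-448."""
    a = min(abs(v), E4M3_MAX)
    if a == 0.0:
        return 0.0
    exp = max(math.frexp(a)[1] - 1, -6)  # below 2**-6, use the subnormal spacing
    step = 2.0 ** (exp - 3)              # 3 mantissa bits
    return math.copysign(min(round(a / step) * step, E4M3_MAX), v)

def quantize_mxfp8(x):
    """Return (elements, scale_bytes): one E8M0 scale byte per 32-element block."""
    elems, scales = [], []
    for start in range(0, len(x), BLOCK):
        block = x[start:start + BLOCK]
        amax = max(abs(v) for v in block)
        # Shared exponent: floor(log2(amax)) minus the element format's max exponent.
        shared = math.frexp(amax)[1] - 1 - E4M3_EMAX if amax > 0 else -127
        shared = max(-127, min(127, shared))
        scales.append(shared + 127)  # E8M0 stores a biased exponent only
        elems.extend(round_to_e4m3(v / 2.0 ** shared) for v in block)
    return elems, scales

def dequantize_mxfp8(elems, scales):
    """Multiply each element by its block's power-of-two scale."""
    return [e * 2.0 ** (scales[i // BLOCK] - 127) for i, e in enumerate(elems)]

x = [1.0] * 32 + [0.0] * 31 + [96.0]
elems, scales = quantize_mxfp8(x)
print(scales, elems[0], elems[-1])
```

Because the scale is a power of two, applying it only shifts exponents; all rounding error comes from the FP8 element conversion.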
RNG Kernels
Random number generation kernels using the GpSimd engine, with state management.
Scan Kernels (State Space Models)
First-order linear recurrence computation along the last dimension.
Selective scan (SSM) as in Mamba models.
State Space Duality scan for Mamba-2 models.
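The first-order linear recurrence is h_t = a_t * h_{t-1} + b_t; roughly speaking, Mamba's selective scan applies this same recurrence per channel and state element, with a_t and b_t computed from the input. The plain-Python sketch below shows the sequential definition and the associative combine operator that allows such scans to be split into chunks and parallelized. The function names are hypothetical, and the sketch does not reflect how the kernels partition work.

```python
from functools import reduce

def linear_recurrence(a, b, h0=0.0):
    """Sequential scan of h_t = a_t * h_{t-1} + b_t along the last dimension."""
    h, out = h0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out

def combine(left, right):
    """Compose two steps h -> a1*h + b1 and h -> a2*h + b2 (left is applied first).

    The operator is associative, so partial results for chunks can be
    computed independently and merged afterward.
    """
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)

a, b = [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]
print(linear_recurrence(a, b))     # [1.0, 1.5, 1.75]
print(reduce(combine, zip(a, b)))  # (0.125, 1.75)
```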
MoE Subkernels
MoE Top-K reduction across the output of a sparse all-to-all collective.
Unstable argsort on a 1D input buffer for MoE routing.
Builds the metadata buffer for the all_to_all_v collective from MoE routing decisions.
Sorts tokens by expert and packs hidden states for MoE dispatch.
Dynamic Shape Kernels
Elementwise addition with runtime-variable M-dimension tiling.
Loss Kernels
Memory-efficient cross entropy loss forward and backward passes using online log-sum-exp algorithm.
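The online log-sum-exp idea can be shown for one row of logits in plain Python: the vocabulary is streamed in chunks while only a running maximum and a rescaled running sum are kept, so the full softmax never has to be stored. The backward pass then recomputes softmax(logits) minus the one-hot target from the saved log-sum-exp. The function names and the chunk size are hypothetical, and the kernel's tiling is not modeled.

```python
import math

def online_logsumexp(logits, chunk=4):
    """log(sum(exp(logits))) in one streaming pass with a running max and sum."""
    m, s = -math.inf, 0.0
    for i in range(0, len(logits), chunk):
        block = logits[i:i + chunk]
        m_new = max(m, max(block))
        # Rescale the running sum to the new max before adding this chunk.
        s = s * math.exp(m - m_new) + sum(math.exp(z - m_new) for z in block)
        m = m_new
    return m + math.log(s)

def cross_entropy_fwd_bwd(logits, target, chunk=4):
    """Loss for one row and its gradient with respect to the logits."""
    lse = online_logsumexp(logits, chunk)
    loss = lse - logits[target]
    # d(loss)/d(logits) = softmax(logits) - one_hot(target), recomputed from the saved LSE.
    grad = [math.exp(z - lse) for z in logits]
    grad[target] -= 1.0
    return loss, grad

loss, grad = cross_entropy_fwd_bwd([2.0, 1.0, 0.1, -1.0, 3.0], target=4, chunk=2)
print(loss, grad)
```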
MoE Backward Kernels
Computes the backward pass for blockwise matrix multiplication in Mixture of Experts layers.