This document is relevant for: Trn2, Trn3

NKI Library Supported Kernel Reference#

The NKI Library provides pre-built reference kernels that you can use directly in model development with the AWS Neuron SDK and NKI. Each kernel exposes default classes, functions, and parameters for integrating it into your models.

Source code for these kernel APIs is available in the aws-neuron/nki-library repository.

Core Kernels#

Normalization and Quantization Kernels#

RMSNorm-Quant

Performs optional RMS normalization followed by quantization to fp8.

QKV Projection Kernels#

QKV

Performs Query-Key-Value projection with optional normalization and RoPE fusion.

Attention Kernels#

Attention CTE

Implements attention optimized for Context Encoding (prefill) use cases.

Attention Segmented CTE

Segmented attention with block-based KV cache and prefix caching for decode.

Attention TKG

Implements attention optimized for Token Generation (decode) use cases with small active sequence lengths.

KV-Parallel Segmented Prefill

KV-parallel segmented prefill attention kernel.

Rotary Position Embedding (RoPE) Kernels#

RoPE

Applies Rotary Position Embedding to input embeddings with flexible layout support.
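
For orientation, the rotation that RoPE applies can be sketched in plain Python. This is a reference illustration of the math only, not the NKI Library API: the function name `rope_interleaved`, the interleaved pair layout, and the default `base` of 10000 are assumptions made for this example, and the kernel itself supports other layouts.

```python
import math

def rope_interleaved(x, position, base=10000.0):
    """Rotate each pair (x[2i], x[2i+1]) by the angle position * base**(-2i/d).

    Hypothetical reference helper; the interleaved layout is an assumption.
    """
    d = len(x)
    if d % 2 != 0:
        raise ValueError("embedding dimension must be even")
    out = [0.0] * d
    for i in range(d // 2):
        theta = position * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        # Standard 2D rotation of the pair (a, b) by theta.
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out

rotated = rope_interleaved([1.0, 2.0, 3.0, 4.0], position=5)
```

Because each pair is only rotated, the vector norm is unchanged, and position 0 returns the input as-is.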

Multi-Layer Perceptron (MLP) Kernels#

MLP

Implements Multi-Layer Perceptron with optional normalization fusion and quantization support.

Output Projection Kernels#

Output Projection CTE

Computes output projection optimized for Context Encoding use cases.

Output Projection TKG

Computes output projection optimized for Token Generation use cases.

Mixture of Experts (MoE) Kernels#

Router Top-K

Computes router logits, applies activation functions, and performs top-K selection for MoE models.
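
The three steps named above (logits, activation, top-K selection) can be sketched for a single token in plain Python. This is not the NKI Library API: `router_top_k` is a hypothetical name, and softmax is assumed as the activation purely for illustration.

```python
import math

def router_top_k(hidden, router_weights, k):
    """Return (expert_indices, expert_probs) for one token.

    hidden: list of H floats; router_weights: H rows of E floats.
    Softmax as the activation is an assumption of this sketch.
    """
    num_experts = len(router_weights[0])
    # Router logits: hidden (1 x H) times router_weights (H x E).
    logits = [
        sum(h * row[e] for h, row in zip(hidden, router_weights))
        for e in range(num_experts)
    ]
    # Numerically stable softmax over experts.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [v / total for v in exps]
    # Top-K selection: experts with the highest probability first.
    top = sorted(range(num_experts), key=lambda e: probs[e], reverse=True)[:k]
    return top, [probs[e] for e in top]

experts, weights = router_top_k(
    [1.0, 0.0], [[0.1, 2.0, 1.0], [5.0, 5.0, 5.0]], k=2
)
```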

MoE CTE

Implements Mixture of Experts MLP operations optimized for Context Encoding use cases.

MoE TKG

Implements Mixture of Experts MLP operations optimized for Token Generation use cases.

Quantization Kernels#

FP8 Quantize

Static and row-wise dynamic FP8 quantization with pre-combined dequantization scale support.

Cumulative Sum Kernels#

Cumsum

Computes cumulative sum along the last dimension with optimized tiling.
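
One common way to tile a cumulative sum is to scan each tile independently and carry the running total into the next tile. The sketch below shows that idea in plain Python; `tiled_cumsum` and `tile_size` are hypothetical names, and the kernel's actual tiling scheme may differ.

```python
from itertools import accumulate

def tiled_cumsum(row, tile_size=4):
    """Cumulative sum of one row, processed tile by tile with a carried total."""
    out = []
    carry = 0
    for start in range(0, len(row), tile_size):
        tile = row[start:start + tile_size]
        # Local scan inside the tile, then shift by the total of all prior tiles.
        out.extend(v + carry for v in accumulate(tile))
        carry = out[-1]
    return out

result = tiled_cumsum([1, 2, 3, 4, 5, 6, 7], tile_size=3)
```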

Core Subkernels#

Find Nonzero Indices

Finds indices of nonzero elements along the T dimension using GpSimd nonzero_with_count ISA.

Experimental Kernels#

Note

Experimental kernels are under active development and their APIs may change in future releases.

Attention Kernels#

Attention Block TKG

Fused attention block for Token Generation that keeps all intermediate tensors in SBUF to minimize HBM traffic.

Ring Attention Forward

Ring attention forward pass for context parallelism across multiple workers.

Ring Attention Backward

Ring attention backward pass SPMD kernel for context parallelism.

Transformer Kernels#

Transformer TKG

Multi-layer transformer forward pass megakernel for token generation.

Convolution Kernels#

Conv1D

1D convolution using tensor engine with replication strategy.

Conv3D

3D convolution using tensor engine with K-replication strategy and W-contiguous tiling.

Depthwise Conv1D

Implements depthwise 1D convolution using an implicit GEMM algorithm.

Collective Communication Kernels#

Fine-Grained All-Gather

Ring-based all-gather for Trn2 with double-buffered collective permute.

FGCC (All-Gather + Matmul)

Fused all-gather and matrix multiplication for Trn2.

SBUF-to-SBUF All-Gather

SBUF-to-SBUF all-gather with variants for small and large tensors.

Foreach Kernels#

Foreach Elementwise

Suite of fused elementwise operations (add, sub, mul, div, addcdiv, addcmul, lerp, sqrt) with SPMD tiling.

Foreach Norm

L1, L2, and Linf norm computation kernels with SPMD parallelization.

Matmul and MLP MXFP8 Kernels#

Matmul MXFP8

Generic matrix multiplication with MXFP8 quantization, supporting BF16 and pre-quantized inputs.

MLP Forward MXFP8

MXFP8 SwiGLU MLP forward pass with optional activation checkpointing.

MLP Backward MXFP8

MXFP8 SwiGLU MLP backward pass with 4-phase gradient computation.

MoE Kernels#

MX MoE Block TKG Wrapper

Wrapper that bitcasts unsigned integer weights to MX x4 dtype for MoE block.

Optimizer Kernels#

Fused Adam/AdamW

Fused Adam (L2 regularization) and AdamW (decoupled weight decay) optimizer kernels.
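
The difference between the two variants is where weight decay enters the update. A minimal scalar sketch, assuming the standard Adam and AdamW update rules; `adam_step` and its parameters are hypothetical names for this example, not the kernel's signature:

```python
import math

def adam_step(p, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
              eps=1e-8, weight_decay=0.0, decoupled=False):
    """One update of a scalar parameter p at step t (t starts at 1).

    decoupled=False: Adam, weight decay folded into the gradient (L2).
    decoupled=True:  AdamW, weight decay applied directly to the parameter.
    """
    if weight_decay and not decoupled:
        g = g + weight_decay * p          # L2 regularization term
    m = beta1 * m + (1 - beta1) * g       # first-moment estimate
    v = beta2 * v + (1 - beta2) * g * g   # second-moment estimate
    m_hat = m / (1 - beta1 ** t)          # bias correction
    v_hat = v / (1 - beta2 ** t)
    if weight_decay and decoupled:
        p = p - lr * weight_decay * p     # decoupled weight decay
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v

p_adam, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, weight_decay=0.1)
p_adamw, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, weight_decay=0.1, decoupled=True)
```

With a zero gradient, the Adam variant still takes a full adaptive step driven by the decay term, while the AdamW variant only shrinks the parameter by `lr * weight_decay * p`.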

Padding Kernels#

Pad

Multi-mode tensor padding (constant, replicate, reflect, circular) following PyTorch semantics.
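
The four modes can be illustrated on a 1D list. This plain-Python sketch mirrors the behavior of the corresponding PyTorch padding modes as understood here; `pad_1d` is a hypothetical helper, not the kernel's interface.

```python
def pad_1d(x, left, right, mode="constant", value=0):
    """Pad a 1D list on both sides using one of four modes."""
    n = len(x)
    if mode not in ("constant", "replicate", "reflect", "circular"):
        raise ValueError(f"unknown mode: {mode}")
    if mode == "reflect" and max(left, right) >= n:
        raise ValueError("reflect padding must be smaller than the input length")

    def source(i):
        # i indexes the unpadded input and may fall outside [0, n).
        if 0 <= i < n:
            return x[i]
        if mode == "constant":
            return value
        if mode == "replicate":
            return x[min(max(i, 0), n - 1)]   # clamp to the nearest edge
        if mode == "reflect":
            # Mirror around the edge element without repeating it.
            return x[-i] if i < 0 else x[2 * (n - 1) - i]
        return x[i % n]                        # circular: wrap around

    return [source(i) for i in range(-left, n + right)]

data = [1, 2, 3, 4]
constant = pad_1d(data, 2, 2, "constant")    # [0, 0, 1, 2, 3, 4, 0, 0]
replicate = pad_1d(data, 2, 2, "replicate")  # [1, 1, 1, 2, 3, 4, 4, 4]
reflect = pad_1d(data, 2, 2, "reflect")      # [3, 2, 1, 2, 3, 4, 3, 2]
circular = pad_1d(data, 2, 2, "circular")    # [3, 4, 1, 2, 3, 4, 1, 2]
```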

Quantization Kernels#

Quantize MXFP8

Block-wise BF16-to-MXFP8 quantization kernel with packed scale support.

RNG Kernels#

RNG

Random number generation kernels using the GpSimd engine with state management.

Scan Kernels (State Space Models)#

Linear Scan

First-order linear recurrence computation along the last dimension.
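
A first-order linear recurrence is commonly written as h[t] = a[t] * h[t-1] + b[t]. The sequential reference below assumes that form; `linear_scan` and the initial state `h0` are hypothetical names, and the kernel's exact inputs may differ.

```python
def linear_scan(a, b, h0=0.0):
    """Compute h[t] = a[t] * h[t-1] + b[t] for every t, starting from h0."""
    out = []
    h = h0
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out

decayed = linear_scan([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])  # [1.0, 1.5, 1.75]
```

When every coefficient in `a` equals 1, the recurrence reduces to a plain cumulative sum of `b`.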

Selective Scan

Selective scan (SSM) as in Mamba models.

SSD

State Space Duality scan for Mamba-2 models.

MoE Subkernels#

Top-K Reduce

MoE Top-K reduction across sparse all-to-all collective output.

Argsort Unstable

Unstable argsort on 1D input buffer for MoE routing.

Build All-to-All-V Metadata

Builds metadata buffer for all_to_all_v collective from MoE routing decisions.

Permute Routed Tokens

Sorts tokens by expert and packs hidden states for MoE dispatch.

Dynamic Shape Kernels#

Dynamic Elementwise Add

Elementwise addition with runtime-variable M-dimension tiling.

Loss Kernels#

Cross Entropy

Memory-efficient cross entropy loss forward and backward passes using an online log-sum-exp algorithm.
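
The online log-sum-exp idea is to stream over the vocabulary in chunks while keeping only a running maximum and a rescaled running sum, so the full softmax never has to be held at once. A plain-Python sketch for a single row of logits, with hypothetical function names and a hypothetical `chunk_size` parameter:

```python
import math

def online_logsumexp(logits, chunk_size=4):
    """Streaming log(sum(exp(logits))) using a running max and rescaled sum."""
    if not logits:
        raise ValueError("logits must be non-empty")
    running_max = -math.inf
    running_sum = 0.0
    for start in range(0, len(logits), chunk_size):
        chunk = logits[start:start + chunk_size]
        new_max = max(running_max, max(chunk))
        # Rescale the old sum to the new max, then add this chunk's terms.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(v - new_max) for v in chunk
        )
        running_max = new_max
    return running_max + math.log(running_sum)

def cross_entropy(logits, target, chunk_size=4):
    """Forward pass: loss = logsumexp(logits) - logits[target]."""
    return online_logsumexp(logits, chunk_size) - logits[target]

def cross_entropy_grad(logits, target, chunk_size=4):
    """Backward pass: d(loss)/d(logits) = softmax(logits) - one_hot(target)."""
    lse = online_logsumexp(logits, chunk_size)
    return [
        math.exp(v - lse) - (1.0 if i == target else 0.0)
        for i, v in enumerate(logits)
    ]

loss = cross_entropy([1.0, 2.0, 3.0, -1.0, 0.5], target=2, chunk_size=2)
```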

MoE Backward Kernels#

Blockwise MM Backward

Computes backward pass for blockwise matrix multiplication in Mixture of Experts layers.
