Output Projection TKG Kernel API Reference#
This topic provides the API reference for the Output Projection TKG kernel. The kernel computes the output projection operation typically used after an attention block in transformer models and is optimized for Token Generation (Decode) use cases.
The kernel supports:
- Efficient projection of attention outputs
- Optional bias addition
- LNC sharding for distributed computation
- Optimized memory access patterns
- Head dimension packing for improved performance
- Flexible output tensor layouts
- SBUF output option for kernel fusion
Background#
The Output Projection TKG kernel computes the operation out = attention @ weight + bias, which is commonly used to project the output scores after an attention block in transformer models. This kernel is specifically optimized for Token Generation (Decode) use cases, where the sequence length S is small (often 1 or a small number for speculative decoding).
The kernel employs efficient tiling strategies and memory access patterns to maximize performance on Neuron hardware, with support for sharding across multiple Logical Neuron Cores (LNCs) to handle large hidden dimensions. When LNC > 1, the H dimension is sharded across the cores; since each core produces its own slice of the output tensor, no inter-core collective operations are needed.
The input layouts expected for this kernel are different from those for the CTE kernel. In TKG workloads, the S dimension is small, so placing the N dimension next to it allows more efficient grouped-query attention (GQA) implementations by loading multiple heads at once.
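To make the operation and layouts concrete, here is a minimal NumPy reference of the same computation, using the tensor shapes documented in the API reference below ([D, B, N, S] attention, [N*D, H] weight, [1, H] bias, [B*S, H] output). The dimension values are arbitrary examples, and the sketch is only a functional illustration of the math, not how the kernel is implemented on Neuron hardware.

```python
import numpy as np

# Illustrative TKG-sized dimensions (S is small during decode).
D, B, N, S, H = 128, 2, 8, 1, 1024

attention = np.random.rand(D, B, N, S).astype(np.float32)  # [D, B, N, S], indexed [d, b, n, s]
weight = np.random.rand(N * D, H).astype(np.float32)       # [N*D, H], indexed [n*D + d, h]
bias = np.random.rand(1, H).astype(np.float32)             # [1, H]

# out[b*S + s, h] = sum over n, d of attention[d, b, n, s] * weight[n*D + d, h] + bias[0, h]
out = np.zeros((B * S, H), dtype=np.float32)
for b in range(B):
    for s in range(S):
        for n in range(N):
            # attention[:, b, n, s] has shape [D]; weight[n*D:(n+1)*D] has shape [D, H]
            out[b * S + s] += attention[:, b, n, s] @ weight[n * D:(n + 1) * D]
        out[b * S + s] += bias[0]

# Equivalent single matmul over the packed N*D contraction dimension.
packed = attention.transpose(1, 3, 2, 0).reshape(B * S, N * D)  # [B*S, N*D], contraction index n*D + d
assert np.allclose(out, packed @ weight + bias, rtol=1e-3)
```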
API Reference#
Source code for this kernel API can be found at: aws-neuron/nki-library
output_projection_tkg#
- nkilib.core.output_projection.output_projection_tkg.output_projection_tkg(attention, weight, bias, TRANSPOSE_OUT=False, OUT_IN_SB=False)#
Output Projection Kernel optimized for Token Generation (Decode) use cases.
This kernel computes out = attention @ weight + bias, typically used to project the output scores after an attention block in transformer models. This kernel is optimized for Token Generation (aka Decode) use cases where sequence length S is small.

- Parameters:
  - attention (nl.ndarray) – Input tensor in HBM or SBUF, typically the scores output from an attention block. Shape: [D, B, N, S], where D is head dimension, B is batch size, N is number of heads, and S is sequence length. Indexing: [d, b, n, s].
  - weight (nl.ndarray) – Weight tensor in HBM. Shape: [N*D, H], where H is hidden dimension size. Indexing: [n * D + d, h].
  - bias (nl.ndarray) – Optional bias tensor in HBM. Shape: [1, H]. Indexing: [1, h].
  - TRANSPOSE_OUT (bool) – Whether to store the output in transposed shape. If False, output shape is [B*S, H] with indexing [b*S + s, h]. If True, output shape is [H_1, H_0, H_2, B*S] with indexing [h_1, h_0, h_2, b*S + s], where H_0 = logical core size (LNC), H_1 = 128, and H_2 = H/(H_0*H_1), such that h = h_0*H_1*H_2 + h_1*H_2 + h_2 (see the layout sketch after the restrictions below).
  - OUT_IN_SB (bool) – If True, output is in SBUF. Else, it is written out to HBM.
- Returns: Output tensor in HBM or SBUF. Shape depends on the TRANSPOSE_OUT parameter.
- Return type: nl.ndarray
- Data Types: This kernel supports nl.float32, nl.float16, and nl.bfloat16 data types. However, for nl.float32, large inputs may not fit in SBUF.
- Dimensions:
  - B: Batch size
  - N: Number of heads
  - S: Sequence length
  - H: Hidden dimension size
  - D: Head dimension size
Restrictions:
- The contract dimension of the input and weight tensors must match (N*D == weight.shape[0])
- Hidden dimension (H) needs to be divisible by the LNC size, since LNC sharding is on the weight hidden dimension
- B*S must be <= 128
- Head dimension (D) must be <= 128
- When TRANSPOSE_OUT is False, H must be a multiple of 512*LNC
- When TRANSPOSE_OUT is True, H must be a multiple of 128*LNC
- When TRANSPOSE_OUT is True and using 32-bit floats, N*H must be <= 81920
- When TRANSPOSE_OUT is True and using 16-bit floats, N*H must be <= 163840
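The TRANSPOSE_OUT=True layout is easiest to see as an index mapping. The NumPy sketch below builds the transposed tensor from a standard [B*S, H] output using the documented formula h = h_0*H_1*H_2 + h_1*H_2 + h_2. The LNC value and shapes are assumptions chosen to satisfy the restrictions above; this is an illustration of the indexing only, not kernel code.

```python
import numpy as np

LNC = 2                    # assumed logical core count
B, S, H = 1, 1, 1024       # example shapes
H_0 = LNC                  # logical core size
H_1 = 128
H_2 = H // (H_0 * H_1)

assert H % (128 * LNC) == 0   # TRANSPOSE_OUT=True requires H to be a multiple of 128*LNC
assert B * S <= 128           # B*S restriction

# Standard layout: out[b*S + s, h]
out = np.arange(B * S * H, dtype=np.float32).reshape(B * S, H)

# Transposed layout: out_t[h_1, h_0, h_2, b*S + s] with h = h_0*H_1*H_2 + h_1*H_2 + h_2
out_t = np.empty((H_1, H_0, H_2, B * S), dtype=out.dtype)
for h_0 in range(H_0):
    for h_1 in range(H_1):
        for h_2 in range(H_2):
            h = h_0 * H_1 * H_2 + h_1 * H_2 + h_2
            out_t[h_1, h_0, h_2, :] = out[:, h]

# The same mapping expressed with reshape/transpose.
ref = out.T.reshape(H_0, H_1, H_2, B * S).transpose(1, 0, 2, 3)
assert np.array_equal(out_t, ref)
```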
Implementation Details#
The kernel implementation includes several key optimizations:
- Dimension Packing: Optimizes the contraction dimension by folding N (number of heads) into D (head dimension) when beneficial, improving computational efficiency.
- Efficient Tiling Strategy: Uses carefully chosen tile sizes for processing batches and sequences to maximize hardware utilization.
- LNC Sharding: Supports sharding across multiple Logical Neuron Cores (LNCs) by dividing the hidden dimension, enabling processing of larger models (see the sketch after this list).
- Memory Access Optimization: Employs optimized memory access patterns to maximize bandwidth utilization and minimize data movement.
- PSUM Bank Utilization: Efficiently utilizes PSUM banks for accumulating partial results during matrix multiplication operations.
- Stream Shuffle Broadcast: Uses stream shuffle broadcast for bias tensors to efficiently distribute them across processing elements.
- Flexible Output Layouts: Supports both standard and transposed output layouts to accommodate different downstream kernel requirements.
- SBUF Output Option: Provides the option to keep output in SBUF for fusion with subsequent operations.
- Block-based Weight Loading: Uses block-based loading of weights to encourage prefetching and improve memory access patterns.
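As a functional picture of the LNC sharding described above, the NumPy sketch below splits the hidden dimension H into LNC contiguous slices, computes each slice independently against the corresponding slice of the weight and bias, and shows that concatenating the per-core slices reproduces the full output with no inter-core reduction. Shapes and the LNC value are assumptions for illustration; the kernel's actual tiling and scheduling are not represented here.

```python
import numpy as np

LNC = 2                          # assumed logical core count
D, B, N, S, H = 64, 1, 4, 1, 2048
assert H % LNC == 0              # H must be divisible by the LNC size

attention = np.random.rand(D, B, N, S).astype(np.float32)
weight = np.random.rand(N * D, H).astype(np.float32)
bias = np.random.rand(1, H).astype(np.float32)

# Full (unsharded) result in the standard [B*S, H] layout.
packed = attention.transpose(1, 3, 2, 0).reshape(B * S, N * D)
full = packed @ weight + bias

# Each "core" owns a contiguous H/LNC slice of the weight, bias, and output.
H_shard = H // LNC
shards = []
for core in range(LNC):
    w_slice = weight[:, core * H_shard:(core + 1) * H_shard]
    b_slice = bias[:, core * H_shard:(core + 1) * H_shard]
    shards.append(packed @ w_slice + b_slice)   # no cross-core communication needed

assert np.allclose(full, np.concatenate(shards, axis=1), rtol=1e-3)
```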