This document is relevant for: Inf2, Trn1, Trn2
nki.kernels.allocated_fused_rms_norm_qkv#
- nki.kernels.allocated_fused_rms_norm_qkv(hidden, weights, norm_dtype=<class 'numpy.float32'>, eps=1e-06)[source]#
Allocated kernel that computes RMSNorm(hidden) @ wQKV. This kernel is designed to handle only fp16/bf16 tensor types. Internally, normalization is computed in fp32 to avoid NaN errors.
- Parameters:
hidden (tensor) – Input tensor of the attention block in BSH layout
weights (tensor) – Fused QKV linear weights, assumed to be eltwise-multiplied with the RMS norm weight vector (gamma)
out_tensor (tensor) – Output tensor
norm_dtype (dtype, optional) – Data type for RMS norm; should be fp32 to avoid NaN. Defaults to nl.float32.
eps (float, optional) – RMS norm epsilon term. Defaults to 1e-6.
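As a sketch of the computation this kernel fuses, the following NumPy reference implements RMSNorm followed by the QKV projection, with gamma assumed to be pre-multiplied into the weights as the docstring describes. The function name and shapes here are illustrative assumptions, not part of the NKI API; the actual kernel runs on-device with explicit SBUF allocation.

```python
import numpy as np

def fused_rms_norm_qkv_reference(hidden, weights, eps=1e-6):
    """Reference for RMSNorm(hidden) @ wQKV (illustrative, not the NKI kernel).

    hidden:  [batch, seq, hidden_dim] in BSH layout, fp16/bf16 in practice
    weights: [hidden_dim, fused_qkv_dim], assumed eltwise-multiplied by gamma
    """
    # Cast to fp32 internally, mirroring the kernel's NaN-avoidance strategy.
    x = hidden.astype(np.float32)
    # RMS over the hidden dimension; gamma is already folded into `weights`.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    out = (x / rms) @ weights.astype(np.float32)
    # Cast back to the input dtype, as the kernel emits fp16/bf16 output.
    return out.astype(hidden.dtype)
```

This makes the fusion explicit: a single pass produces the normalized activations and their projection, avoiding a round trip to memory between the two steps.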