This document is relevant for: Inf2, Trn1, Trn2

nki.kernels.allocated_fused_rms_norm_qkv

nki.kernels.allocated_fused_rms_norm_qkv(hidden, weights, norm_dtype=numpy.float32, eps=1e-06)

Allocated kernel that computes RMSNorm(hidden) @ wQKV. This kernel is designed to handle only fp16/bf16 tensor types. Internally, the normalization is computed in fp32 to avoid NaN errors.

Parameters:
  • hidden (_type_) – Input tensor of the attention block in BSH (batch, sequence, hidden) layout

  • weights (_type_) – Fused QKV linear weights, assumed to be already multiplied element-wise with the RMS norm weight vector (gamma)

  • out_tensor (_type_) – Output tensor

  • norm_dtype (_type_, optional) – Data type for RMS norm, should be fp32 to avoid NaN. Defaults to nl.float32.

  • eps (_type_, optional) – RMS norm epsilon term. Defaults to 1e-6.
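
The computation described above can be sketched in plain NumPy as a reference for checking shapes and numerics. This is not the NKI kernel and does not call it. The helper names `rms_norm_qkv_reference` and `fold_gamma`, the [H, N] weight layout, and returning the result in the input dtype are all assumptions made for illustration.

```python
import numpy as np

def rms_norm_qkv_reference(hidden, weights, norm_dtype=np.float32, eps=1e-6):
    """Hypothetical NumPy reference for RMSNorm(hidden) @ weights."""
    # Upcast first so the squares are accumulated in norm_dtype, not fp16/bf16.
    h = hidden.astype(norm_dtype)
    # Mean of squares over the hidden (last) axis of the BSH input.
    mean_sq = np.mean(h * h, axis=-1, keepdims=True)
    normed = h / np.sqrt(mean_sq + eps)
    # Project with the fused QKV weights (assumed shape [H, N]).
    out = normed @ weights.astype(norm_dtype)
    # Assumption: the result is returned in the input dtype.
    return out.astype(hidden.dtype)

def fold_gamma(gamma, w_qkv):
    """Scale row i of w_qkv by gamma[i], so (x * gamma) @ w_qkv == x @ folded."""
    return gamma[:, None] * w_qkv

rng = np.random.default_rng(0)
B, S, H, N = 2, 4, 8, 12  # small illustrative sizes
hidden = rng.standard_normal((B, S, H)).astype(np.float16)
gamma = rng.standard_normal(H).astype(np.float32)
w_qkv = rng.standard_normal((H, N)).astype(np.float32)

# Fold gamma into the weights once, ahead of time, then cast to the kernel dtype.
weights = fold_gamma(gamma, w_qkv).astype(np.float16)
out = rms_norm_qkv_reference(hidden, weights)
print(out.shape, out.dtype)
```

Folding gamma into the weights ahead of time removes a per-token element-wise multiply from the normalization step, which is why the `weights` argument is expected to arrive pre-scaled.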
