This document is relevant for: Trn1, Trn2, Trn3
nki.language.rms_norm
- nki.language.rms_norm(x, w, axis, n, epsilon=1e-06, dtype=None, compute_dtype=None)
Apply Root Mean Square Layer Normalization.
Warning
This API is experimental and may change in future releases.
- Parameters:
x – input tile.
w – weight tile.
axis – axis along which to compute the root mean square (rms) value.
n – total number of values used to calculate the rms value.
epsilon – epsilon value used by rms calculation to avoid divide-by-zero.
dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tile.
compute_dtype – (optional) dtype for the internal computation.
- Returns:
x / RMS(x) * w
Examples:
import nki.language as nl

# nki.language.rms_norm -- normalize with unit weights
x = nl.full((128, 512), 2.0, dtype=nl.float32, buffer=nl.sbuf)
w = nl.full((128, 512), 1.0, dtype=nl.float32, buffer=nl.sbuf)
result = nl.rms_norm(x, w, axis=1, n=512)
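For checking results on the host, the same computation can be sketched in NumPy. This is not the NKI implementation: the helper name `rms_norm_reference` is hypothetical, and it assumes the common RMSNorm formulation in which RMS(x) = sqrt(sum(x^2 along axis) / n + epsilon). Where epsilon enters the actual rms calculation is not stated in this reference, so treat that placement as an assumption.

```python
import numpy as np


def rms_norm_reference(x, w, axis, n, epsilon=1e-6):
    """Host-side sketch of x / RMS(x) * w (hypothetical helper, not part of NKI)."""
    x = np.asarray(x, dtype=np.float32)
    w = np.asarray(w, dtype=np.float32)
    # Assumption: the mean square is the sum of squares along `axis` divided by `n`.
    mean_square = np.sum(x * x, axis=axis, keepdims=True) / n
    # Assumption: epsilon is added before the square root to avoid divide-by-zero.
    rms = np.sqrt(mean_square + epsilon)
    return x / rms * w


# Same shapes and values as the example above, built with NumPy instead of NKI tiles.
x = np.full((128, 512), 2.0, dtype=np.float32)
w = np.full((128, 512), 1.0, dtype=np.float32)
result = rms_norm_reference(x, w, axis=1, n=512)
```

With a constant input of 2.0 and unit weights, each row has an RMS of about 2.0, so every entry of `result` is very close to 1.0.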