This document is relevant for: Inf2, Trn1, Trn1n
nki.isa.bn_stats
- nki.isa.bn_stats(data, mask=None, dtype=None, **kwargs)
Compute mean- and variance-related statistics for each partition of an input tile data in parallel using Vector Engine.

The output tile of the instruction has 6 elements per partition (computed over the input tile elements from the same partition):
- the count of the even elements
- the mean of the even elements
- variance * count of the even elements
- the count of the odd elements
- the mean of the odd elements
- variance * count of the odd elements
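To make the six-element layout concrete, here is a minimal pure-Python sketch of what one partition of the bn_stats output contains. It is not the NKI API: bn_stats_reference is a hypothetical helper, it assumes "even" and "odd" refer to element positions along the free dimension, it computes in Python floats rather than float32, and, to keep the sketch simple, it requires at least two elements per partition.

```python
from statistics import fmean, pvariance


def bn_stats_reference(row):
    """Hypothetical pure-Python model of one partition of nisa.bn_stats.

    Returns [count_even, mean_even, var_even * count_even,
             count_odd,  mean_odd,  var_odd * count_odd],
    assuming even/odd means even/odd positions along the free dimension.
    """
    # 512 mirrors nl.tile_size.bn_stats_fmax; the lower bound of 2 is a
    # simplification of this sketch, not a documented hardware limit.
    if not 2 <= len(row) <= 512:
        raise ValueError("this sketch accepts 2..512 elements per partition")
    stats = []
    for group in (row[0::2], row[1::2]):  # even positions, then odd positions
        n = len(group)
        # pvariance is the population variance, so pvariance * n is the
        # sum of squared deviations from the group mean.
        stats += [n, fmean(group), pvariance(group) * n]
    return stats


print(bn_stats_reference([1.0, 2.0, 3.0, 4.0]))  # → [2, 2.0, 2.0, 2, 3.0, 2.0]
```

Storing variance * count (rather than the variance itself) keeps each group's sum of squared deviations, which is the quantity that can be merged directly when several groups are aggregated.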
To get the final mean and variance of the input tile, we need to pass the above bn_stats instruction output into the bn_aggr instruction, which outputs two elements per partition:
- the mean (of the original input tile elements from the same partition)
- the variance
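The aggregation step can be modeled the same way. The hypothetical bn_aggr_reference below merges any number of (count, mean, variance * count) groups with the standard parallel-variance merge formula. It assumes bn_aggr reports the population (biased) variance, and it is a mathematical model of the result, not a description of how the hardware implements the instruction.

```python
def bn_aggr_reference(stats):
    """Hypothetical model of one partition of nisa.bn_aggr.

    stats: flat list of one or more 6-element bn_stats records, e.g. 12 values
    when two bn_stats outputs are concatenated. Each record holds two
    (count, mean, variance * count) groups. Returns [mean, variance], where
    variance is assumed to be the population (biased) variance.
    """
    if not stats or len(stats) % 6 != 0:
        raise ValueError("expected a multiple of 6 statistics per partition")
    groups = [stats[i:i + 3] for i in range(0, len(stats), 3)]
    total = sum(n for n, _, _ in groups)
    # Overall mean: count-weighted average of the group means.
    mean = sum(n * m for n, m, _ in groups) / total
    # Sum of squared deviations: each group's own (variance * count) plus the
    # contribution from its mean being offset from the overall mean.
    m2 = sum(m2_g + n * (m - mean) ** 2 for n, m, m2_g in groups)
    return [mean, m2 / total]


# Statistics for the row [1, 2, 3, 4]: even elements [1, 3], odd elements [2, 4].
print(bn_aggr_reference([2, 2.0, 2.0, 2, 3.0, 2.0]))  # → [2.5, 1.25]
```

Because every group carries its own count, the same merge works whether the records come from one bn_stats call or from several calls over tiles of different sizes.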
Due to a hardware limitation, the number of elements per partition (i.e., the free dimension size) of the input data must not exceed 512 (nl.tile_size.bn_stats_fmax). To calculate the per-partition mean/variance of a tensor with more than 512 elements in the free dimension, we can invoke bn_stats instructions on each 512-element tile and use a single bn_aggr instruction to aggregate the bn_stats outputs from all the tiles. Refer to Example 2 for an example implementation.

Vector Engine performs the above statistics calculation in float32 precision. Therefore, the engine automatically casts the input data tile to float32 before performing the computation, and it can cast the float32 results into another data type specified by the dtype field at no additional performance cost. If the dtype field is not specified, the instruction casts the float32 results back to the same data type as the input data tile.

Estimated instruction cost:
N Vector Engine cycles, where N is the number of elements per partition in data.
- Parameters:
data – the input tile (up to 512 elements per partition)
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tile.
- Returns:
an output tile with 6-element statistics per partition
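Since each bn_stats call accepts at most nl.tile_size.bn_stats_fmax (512) elements per partition, a free dimension of arbitrary size has to be split into tiles, with one 6-element output per tile. The hypothetical plan_bn_stats_tiles helper below sketches that bookkeeping in plain Python. It assumes a final partial tile (fewer than 512 elements) is valid bn_stats input, which follows from the "up to 512 elements" limit above, and its cycle total simply sums the documented N-cycle estimate per call, ignoring any fixed per-instruction overhead.

```python
BN_STATS_FMAX = 512  # value of nl.tile_size.bn_stats_fmax stated above


def plan_bn_stats_tiles(free_size, fmax=BN_STATS_FMAX):
    """Hypothetical helper: split a free dimension into bn_stats-sized tiles.

    Returns (tiles, stats_width, est_cycles):
      tiles       - list of (start, size) pairs, each size <= fmax
      stats_width - free size of a tile holding every bn_stats output (6 per call)
      est_cycles  - sum of the per-call N-cycle estimates (no fixed overhead modeled)
    """
    if free_size <= 0:
        raise ValueError("free_size must be positive")
    tiles = [(start, min(fmax, free_size - start))
             for start in range(0, free_size, fmax)]
    stats_width = 6 * len(tiles)
    est_cycles = sum(size for _, size in tiles)
    return tiles, stats_width, est_cycles


print(plan_bn_stats_tiles(1024))  # → ([(0, 512), (512, 512)], 12, 1024)
print(plan_bn_stats_tiles(1000))  # → ([(0, 512), (512, 488)], 12, 1000)
```

For 1024 elements this reproduces the layout used in Example 2: two bn_stats calls whose outputs fill a 12-element stats tile per partition.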
Example:
```python
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np

...

##################################################################
# Example 1: Calculate the mean and variance for each partition
# of tile a with shape (128, 128)
##################################################################
i_p_a = nl.arange(128)[:, None]
i_f_a = nl.arange(128)[None, :]
a = nl.load(a_tensor[i_p_a, i_f_a])

stats_a = nisa.bn_stats(a[i_p_a, i_f_a])
assert stats_a.shape == (128, 6)

mean_var_a = nisa.bn_aggr(stats_a)
assert mean_var_a.shape == (128, 2)

# Extract mean and variance
mean_a = mean_var_a[:, 0]
var_a = mean_var_a[:, 1]
nl.store(mean_a_tensor[i_p_a, 0], mean_a)
nl.store(var_a_tensor[i_p_a, 0], var_a)

##################################################################
# Example 2: Calculate the mean and variance for each partition of
# tile b with shape (128, 1024)
##################################################################
i_p_b = nl.arange(128)[:, None]
i_f_b = nl.arange(nl.tile_size.bn_stats_fmax)[None, :]
i_f_b2 = nl.arange(1024)[None, :]
b = nl.load(b_tensor[i_p_b, i_f_b2])

# Run bn_stats in two tiles because b has 1024 elements per partition,
# but bn_stats has a limitation of nl.tile_size.bn_stats_fmax.
# Initialize a bn_stats output tile with shape (128, 6*2) to
# hold the outputs of two bn_stats instructions.
stats_b = nl.ndarray((128, 6*2), dtype=np.float32)
i_p_stats_b = nl.arange(128)[:, None]
i_f_stats_b = nl.arange(6)[None, :]
stats_b[i_p_stats_b, i_f_stats_b] = nisa.bn_stats(b[i_p_b, i_f_b], dtype=np.float32)
stats_b[i_p_stats_b, 6+i_f_stats_b] = nisa.bn_stats(
    b[i_p_b, nl.tile_size.bn_stats_fmax+i_f_b], dtype=np.float32)

mean_var_b = nisa.bn_aggr(stats_b)

# Extract mean and variance
mean_b = mean_var_b[:, 0]
var_b = mean_var_b[:, 1]
nl.store(mean_b_tensor[i_p_b, 0], mean_b)
nl.store(var_b_tensor[i_p_b, 0], var_b)
```
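As a sanity check on the tiling strategy in Example 2, the plain-Python sketch below models one partition of tile b: it splits 1024 values into two 512-element tiles, computes six statistics per tile, concatenates them into 12 values as stats_b does, and merges them. stats6 and aggr are hypothetical stand-ins for bn_stats and bn_aggr (even/odd taken as element positions, population variance assumed, double precision rather than float32), so this illustrates the arithmetic only, not the NKI API.

```python
import random
from statistics import fmean, pvariance

FMAX = 512  # nl.tile_size.bn_stats_fmax


def stats6(tile):
    # Hypothetical bn_stats model: (count, mean, variance * count) for the
    # even-position elements, then for the odd-position elements.
    out = []
    for g in (tile[0::2], tile[1::2]):
        out += [len(g), fmean(g), pvariance(g) * len(g)]
    return out


def aggr(stats):
    # Hypothetical bn_aggr model: merge all (count, mean, variance * count) triples.
    groups = [stats[i:i + 3] for i in range(0, len(stats), 3)]
    total = sum(n for n, _, _ in groups)
    mean = sum(n * m for n, m, _ in groups) / total
    m2 = sum(s + n * (m - mean) ** 2 for n, m, s in groups)
    return mean, m2 / total


random.seed(0)
row = [random.uniform(-1.0, 1.0) for _ in range(1024)]  # one partition of tile b

# Mirror Example 2: two bn_stats calls whose outputs are concatenated (12 values).
stats_b = stats6(row[:FMAX]) + stats6(row[FMAX:2 * FMAX])
mean_b, var_b = aggr(stats_b)

# Both differences should be at the level of floating-point rounding.
print(f"mean diff: {mean_b - fmean(row):.2e}, var diff: {var_b - pvariance(row):.2e}")
```

Merging per-tile statistics this way gives the same mean and variance as a single pass over all 1024 values, which is why one bn_aggr over the concatenated bn_stats outputs is sufficient.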