nki.isa.bn_stats

nki.isa.bn_stats(dst, data, name=None)

Compute mean- and variance-related statistics for each partition of the input tile data, processing all partitions in parallel on the Vector Engine.

The output tile of the instruction has 6 elements per partition:

  • the count of the even-indexed elements (among the input tile elements in the same partition)

  • the mean of the even-indexed elements

  • variance * count of the even-indexed elements

  • the count of the odd-indexed elements

  • the mean of the odd-indexed elements

  • variance * count of the odd-indexed elements
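The following pure-Python sketch models these six values for a single partition. It is not the NKI API: bn_stats_ref is a hypothetical helper, and it assumes that "even" and "odd" refer to element positions (0, 2, 4, ... and 1, 3, 5, ...) and that "variance" is the population variance, so variance * count is the sum of squared deviations from the group mean.

```python
from statistics import fmean, pvariance


def bn_stats_ref(row):
    """Hypothetical model of the six bn_stats values for one partition.

    Assumes even/odd refer to element positions and that variance is the
    population variance, so variance * count is the sum of squared deviations.
    """
    out = []
    for group in (row[0::2], row[1::2]):  # even-indexed, then odd-indexed elements
        count = len(group)
        mean = fmean(group) if count else 0.0  # 0.0 placeholder for an empty group (assumption)
        m2 = pvariance(group) * count if count else 0.0  # variance * count
        out += [float(count), mean, m2]
    return out


print(bn_stats_ref([1.0, 2.0, 3.0, 4.0]))  # [2.0, 2.0, 2.0, 2.0, 3.0, 2.0]
```

Storing variance * count rather than the variance itself keeps each group's contribution additive, which is what makes a later aggregation step straightforward.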

To get the final mean and variance of the input tile, we pass the bn_stats output to the bn_aggr instruction, which outputs two elements per partition:

  • mean (of the original input tile elements from the same partition)

  • variance
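Assuming the six-value layout listed above, the aggregation can be modeled by pooling the per-group counts, means and sums of squared deviations: the overall mean is the count-weighted mean of the group means, and the overall sum of squared deviations is the sum of the within-group terms plus a between-group term. bn_aggr_ref below is a hypothetical pure-Python helper applying this standard pooling formula; it sketches the math only, not how the hardware computes it, and it assumes bn_aggr reports the population variance.

```python
def bn_aggr_ref(stats_records):
    """Hypothetical model of bn_aggr for one partition.

    Each record holds six values: (count, mean, variance * count) for the
    even-indexed group followed by the same triple for the odd-indexed group.
    Returns the pooled (mean, population variance).
    """
    groups = []
    for rec in stats_records:
        groups.append((rec[0], rec[1], rec[2]))  # even-indexed group
        groups.append((rec[3], rec[4], rec[5]))  # odd-indexed group
    total = sum(n for n, _, _ in groups)
    mean = sum(n * m for n, m, _ in groups) / total
    # Pooled sum of squared deviations: within-group M2 plus between-group term.
    m2 = sum(g_m2 + n * (m - mean) ** 2 for n, m, g_m2 in groups)
    return mean, m2 / total


# Statistics of the partition [1.0, 2.0, 3.0, 4.0]: even [1, 3], odd [2, 4].
mean, var = bn_aggr_ref([[2.0, 2.0, 2.0, 2.0, 3.0, 2.0]])
print(mean, var)  # 2.5 1.25
```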

Due to a hardware limitation, the number of elements per partition (i.e., the free dimension size) of the input data must not exceed 512 (nl.tile_size.bn_stats_fmax). To calculate the per-partition mean and variance of a tensor with more than 512 elements in the free dimension, we can invoke bn_stats on each tile of at most 512 elements and use a single bn_aggr instruction to aggregate the bn_stats outputs from all the tiles.
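The tiling workflow can be sketched in pure Python as follows. The helpers are hypothetical stand-ins for the two instructions (assuming even/odd element positions and population variance), and BN_STATS_FMAX simply mirrors the documented 512-element limit. The sketch only checks that one aggregation over all per-tile statistics reproduces the mean and variance of the whole row.

```python
from statistics import fmean, pvariance

BN_STATS_FMAX = 512  # mirrors the documented nl.tile_size.bn_stats_fmax limit


def bn_stats_ref(chunk):
    # Hypothetical model: (count, mean, variance * count) for even-, then odd-indexed elements.
    assert len(chunk) <= BN_STATS_FMAX, "at most 512 elements per partition"
    out = []
    for group in (chunk[0::2], chunk[1::2]):
        n = len(group)
        out += [float(n), fmean(group) if n else 0.0, pvariance(group) * n if n else 0.0]
    return out


def bn_aggr_ref(records):
    # Hypothetical model: pool every (count, mean, M2) group into (mean, variance).
    groups = [tuple(r[i:i + 3]) for r in records for i in (0, 3)]
    total = sum(n for n, _, _ in groups)
    mean = sum(n * m for n, m, _ in groups) / total
    m2 = sum(g_m2 + n * (m - mean) ** 2 for n, m, g_m2 in groups)
    return mean, m2 / total


def row_mean_var(row):
    # One bn_stats call per slice of at most 512 elements, then a single bn_aggr.
    records = [bn_stats_ref(row[i:i + BN_STATS_FMAX])
               for i in range(0, len(row), BN_STATS_FMAX)]
    return bn_aggr_ref(records)


row = [((7 * k) % 13) * 0.25 for k in range(1300)]  # 1300 elements -> 3 slices
mean, var = row_mean_var(row)
print(mean, fmean(row))      # should agree up to floating-point rounding
print(var, pvariance(row))   # should agree up to floating-point rounding
```

Because every record carries its own counts, the final slice may be shorter than 512 elements without any special handling.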

The Vector Engine performs the statistics calculation in float32 precision. It automatically casts the input data tile to float32 before computing, and it can cast the float32 results into another data type specified by the dtype field at no additional performance cost. If the dtype field is not specified, the instruction casts the float32 results back to the data type of the input data tile.

Parameters:
  • dst – an output tile with 6-element statistics per partition

  • data – the input tile (up to 512 elements per partition)
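As an illustration of the dst/data layout, the sketch below fills a pre-allocated output with one 6-element row per partition. Nested Python lists stand in for tiles, and bn_stats_tile_ref is a hypothetical helper (again assuming even/odd element positions and population variance); in a real kernel, nki.isa.bn_stats operates on NKI tiles rather than Python lists.

```python
from statistics import fmean, pvariance


def bn_stats_tile_ref(dst, data):
    """Hypothetical stand-in mirroring the (dst, data) parameters.

    data: list of partitions (rows), each with at most 512 elements.
    dst:  pre-allocated list with one 6-element row per partition.
    """
    assert len(dst) == len(data), "dst and data must have the same partition count"
    for p, row in enumerate(data):
        assert len(row) <= 512, "free dimension of data exceeds 512"
        assert len(dst[p]) == 6, "each dst partition holds 6 statistics"
        stats = []
        for group in (row[0::2], row[1::2]):  # assumed: even-, then odd-indexed elements
            n = len(group)
            stats += [float(n), fmean(group) if n else 0.0, pvariance(group) * n if n else 0.0]
        dst[p][:] = stats  # write the results in place, like filling an output tile


data = [[1.0, 2.0, 3.0, 4.0], [5.0, 5.0, 5.0]]  # 2 partitions
dst = [[0.0] * 6 for _ in data]
bn_stats_tile_ref(dst, data)
print(dst)  # [[2.0, 2.0, 2.0, 2.0, 3.0, 2.0], [2.0, 5.0, 0.0, 1.0, 5.0, 0.0]]
```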