This document is relevant for: Inf2, Trn1, Trn2
nki.isa.range_select
- nki.isa.range_select(*, on_true_tile, comp_op0, comp_op1, bound0, bound1, reduce_cmd=reduce_cmd.idle, reduce_res=None, reduce_op=<function amax>, range_start=0, on_false_value=<property object>, mask=None, dtype=None, **kwargs)
Select elements from `on_true_tile` by comparing their indices against bounds, using the Vector Engine.

For each element in `on_true_tile`, the instruction compares its free-dimension index plus `range_start` against `bound0` and `bound1` using the specified comparison operators (`comp_op0` and `comp_op1`). If both comparisons evaluate to True, the element is copied to the output; otherwise, `on_false_value` is written instead.

The instruction also performs the reduction specified by `reduce_op` on the results and stores the reduction result in `reduce_res`.

Note on numerical stability:
In self-attention, this instruction sequence is common: `range_select` (VectorE) -> `reduce_res` -> `activation` (ScalarE). When `range_select` outputs a full row of the fill value (`on_false_value`), take care to avoid NaN in the `activation` instruction that subtracts `reduce_res` (the row max) from the output of `range_select`:

- If `dtype` and `reduce_res` are both FP32, no NaN occurs, because FP32_MIN - FP32_MIN = 0 and exponentiation of 0 is stable (exactly 1.0).
- If `dtype` is FP16/BF16/FP8, the fill value in the output tile becomes -INF, because the HW downcasts FP32_MIN to the smaller dtype. In this case, make sure `reduce_res` uses an FP32 `dtype` to avoid NaN in `activation`. This works because `activation` always upcasts input tiles to FP32 for its math: -INF - FP32_MIN = -INF, and exponentiation of -INF is stable (exactly 0.0). If `reduce_res` were also downcast, it would hold -INF as well, and -INF - (-INF) would produce NaN.
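The three cases above can be reproduced on the host with NumPy. This is a sketch, not a model of the hardware datapath. It makes the following assumptions:

- FP32_MIN equals `np.finfo(np.float32).min`, which is assumed to match `nl.fp32.min`.
- Only the FP16 downcast is shown, because NumPy has no BF16/FP8 types.
- "activation" is reduced to `exp(x - max)` computed in FP32.

```python
import numpy as np

# Assumption: FP32_MIN is the most negative finite float32; nl.fp32.min is
# assumed to match np.finfo(np.float32).min.
FP32_MIN = np.finfo(np.float32).min

# A full row of fill values, as produced when nothing in the row is selected.
row_fp32 = np.full(4, FP32_MIN, dtype=np.float32)

# Case 1: output tile and reduce_res are both FP32.
max_fp32 = row_fp32.max(keepdims=True)              # reduce_res == FP32_MIN
case1 = np.exp(row_fp32 - max_fp32)                 # FP32_MIN - FP32_MIN = 0 -> 1.0

# Case 2: output tile downcast to FP16 (FP32_MIN overflows to -inf),
# reduce_res kept in FP32, activation math done in FP32.
with np.errstate(over="ignore"):
    row_fp16 = row_fp32.astype(np.float16)
case2 = np.exp(row_fp16.astype(np.float32) - max_fp32)  # -inf - FP32_MIN = -inf -> 0.0

# Counter-example: reduce_res also downcast, so the row max itself is -inf.
max_fp16 = row_fp16.max(keepdims=True)
with np.errstate(invalid="ignore"):
    case3 = np.exp(row_fp16.astype(np.float32) - max_fp16.astype(np.float32))  # NaN
```

Only case 3 produces NaN. This is why `reduce_res` should stay in FP32 whenever the output `dtype` is narrower.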
Constraints:

- The comparison operators must be one of: `np.equal`, `np.less`, `np.less_equal`, `np.greater`, `np.greater_equal`.
- Partition dimension sizes must match across `on_true_tile`, `bound0`, and `bound1`, and `bound0` and `bound1` must have one element per partition.
- `on_true_tile` must be one of the FP dtypes, and `bound0`/`bound1` must be FP32.
- The comparison with `bound0`, `bound1`, and the free-dimension index is done in FP32. Make sure `range_start` + free-dimension index stays within the 2^24 range. FP32 has a 24-bit significand, so not every integer above 2^24 is exactly representable.

Estimated instruction cost:
max(MIN_II, N) Vector Engine cycles, where N is the number of elements per partition in `on_true_tile`, and MIN_II is the minimum instruction initiation interval for small input tiles. MIN_II is roughly 64 engine cycles.
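The 2^24 index limit in the constraints can be illustrated with NumPy. This is a host-side sketch that only mimics the FP32 arithmetic. The `range_start` and `bound1` values are hypothetical, chosen to sit at the limit. Once the index crosses 2^24, neighbouring indices round to the same FP32 value, and the mask silently differs from the exact-integer result.

```python
import numpy as np

LIMIT = 2**24  # FP32 significand: 23 stored bits + 1 implicit bit

# Hypothetical values chosen to sit right at the limit.
range_start = LIMIT
bound1 = LIMIT + 1

# Exact integer indices versus the same indices computed in FP32.
exact_idx = range_start + np.arange(4)
fp32_idx = np.float32(range_start) + np.arange(4, dtype=np.float32)

# "index < bound1" evaluated both ways.
exact_mask = exact_idx < bound1
fp32_mask = fp32_idx < np.float32(bound1)  # bound1 itself rounds down to 2**24
```

With exact integers, only the first element passes the comparison. In FP32, `bound1` rounds down to 2^24 and the first two indices collapse to the same value, so no element passes.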
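NumPy equivalent:

```python
import numpy as np

# Free-dimension index of every element, offset by range_start
# (the same index row is broadcast to every partition).
indices = np.zeros(on_true_tile.shape)
indices[:] = range_start + np.arange(on_true_tile[0].size)

# Keep an element only when both comparisons hold.
mask = comp_op0(indices, bound0) & comp_op1(indices, bound1)
select_out_tile = np.where(mask, on_true_tile, on_false_value)

# Per-partition reduction along the free dimension.
reduce_tile = reduce_op(select_out_tile, axis=1, keepdims=True)
```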
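To experiment with these semantics on the host, the NumPy equivalent can be wrapped into a self-contained function. This is a sketch with the following assumptions:

- `range_select_ref` is a hypothetical helper, not part of NKI.
- FP32_MIN is modelled as `np.finfo(np.float32).min`, assumed to match `nl.fp32.min`.
- Indices are evaluated in FP32, following the constraint that comparisons are done in FP32.

```python
import numpy as np


def range_select_ref(on_true_tile, comp_op0, comp_op1, bound0, bound1,
                     range_start=0, on_false_value=np.finfo(np.float32).min,
                     reduce_op=np.max):
    """Hypothetical NumPy model of range_select on a (partition, free) tile.

    Returns (select_out_tile, reduce_tile). Not an NKI API.
    """
    on_true_tile = np.asarray(on_true_tile, dtype=np.float32)
    # One bound per partition: shape (P, 1) so it broadcasts along the free dim.
    bound0 = np.asarray(bound0, dtype=np.float32).reshape(-1, 1)
    bound1 = np.asarray(bound1, dtype=np.float32).reshape(-1, 1)

    # Free-dimension index + range_start, in FP32, shared by all partitions.
    free_size = on_true_tile.shape[1]
    indices = (range_start + np.arange(free_size)).astype(np.float32)
    indices = np.broadcast_to(indices, on_true_tile.shape)

    mask = comp_op0(indices, bound0) & comp_op1(indices, bound1)
    select_out_tile = np.where(mask, on_true_tile, np.float32(on_false_value))
    reduce_tile = reduce_op(select_out_tile, axis=1, keepdims=True)
    return select_out_tile, reduce_tile


# Keep elements whose index i satisfies bound0 <= i < bound1, per partition.
on_true = np.array([[10, 20, 30, 40],
                    [5, 6, 7, 8]], dtype=np.float32)
bound0 = np.array([[1], [0]], dtype=np.float32)
bound1 = np.array([[3], [2]], dtype=np.float32)

selected, row_max = range_select_ref(on_true, np.greater_equal, np.less,
                                     bound0, bound1)
```

In this example, row 0 keeps indices 1 and 2 (values 20 and 30), and row 1 keeps indices 0 and 1 (values 5 and 6). Every other slot holds FP32_MIN, so `row_max` is `[[30.], [6.]]`. Passing `range_start=1` shifts the selected window one position to the left in each row, because each element's effective index grows by one.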
- Parameters:
on_true_tile – input tile containing elements to select from
on_false_value – constant value to use when the selection condition is False. Due to HW constraints, this must be the FP32_MIN bit pattern (nl.fp32.min, as used in the examples).
comp_op0 – first comparison operator
comp_op1 – second comparison operator
bound0 – tile with one element per partition for first comparison
bound1 – tile with one element per partition for second comparison
reduce_op – reduction operator to apply across the selected output. Currently only np.max is supported.
reduce_res – optional tile to store reduction results.
reduce_cmd – command for the reduction accumulation register; defaults to reduce_cmd.idle. The chained example below uses reduce_cmd.reset_reduce to reset the register before accumulating and reduce_cmd.reduce to keep accumulating into it.
range_start – starting base offset for the free-dimension index array of on_true_tile. Defaults to 0, and must be a compile-time integer.
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tile.
- Returns:
output tile with selected elements
Example:
```python
import neuronxcc.nki as nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np

...
##################################################################
# Example 1:
# Select elements where
# bound0 <= range_start + index < bound1 and compute max reduction
#
# on_false_value must be nl.fp32.min
##################################################################
on_true_tile = nl.load(on_true[...])
bound0_tile = nl.load(bound0[...])
bound1_tile = nl.load(bound1[...])

reduce_res_tile = nl.ndarray((on_true.shape[0], 1), dtype=nl.float32,
                             buffer=nl.sbuf)
result = nl.ndarray(on_true.shape, dtype=nl.float32, buffer=nl.sbuf)

result[...] = nisa.range_select(
    on_true_tile=on_true_tile,
    comp_op0=compare_op0,
    comp_op1=compare_op1,
    bound0=bound0_tile,
    bound1=bound1_tile,
    reduce_cmd=nisa.reduce_cmd.reset_reduce,
    reduce_res=reduce_res_tile,
    reduce_op=np.max,
    range_start=range_start,
    on_false_value=nl.fp32.min
)

nl.store(select_res[...], value=result[...])
nl.store(reduce_result[...], value=reduce_res_tile[...])
```
Alternatively, reduce_cmd can be used to chain multiple range_select calls on the same accumulation register, so the reduction accumulates across all of them. For example:

```python
import neuronxcc.nki as nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np

...
##################################################################
# Example 2.a: Initialize reduction with first range_select
# Notice we don't pass reduce_res since the accumulation
# register keeps track of the accumulation until we're ready to
# read it. Also we use reset_reduce in order to "clobber" or zero
# out the accumulation register before we start accumulating.
#
# Note: Since the type of these tensors is fp32, we use nl.fp32.min
# for on_false_value due to HW constraints.
##################################################################
on_true_tile = nl.load(on_true[...])
bound0_tile = nl.load(bound0[...])
bound1_tile = nl.load(bound1[...])

reduce_res_sbuf = nl.ndarray((on_true.shape[0], 1), dtype=np.float32,
                             buffer=nl.sbuf)
result_sbuf = nl.ndarray(on_true.shape, dtype=np.float32, buffer=nl.sbuf)

result_sbuf[...] = nisa.range_select(
    on_true_tile=on_true_tile,
    comp_op0=compare_op0,
    comp_op1=compare_op1,
    bound0=bound0_tile,
    bound1=bound1_tile,
    reduce_cmd=nisa.reduce_cmd.reset_reduce,
    reduce_op=np.max,
    range_start=range_start,
    on_false_value=nl.fp32.min
)

##################################################################
# Example 2.b: Chain multiple range_select operations
# with reduction in an affine loop. Adding ones just lets us ensure
# the reduction gets updated with new values.
##################################################################
ones = nl.full(on_true.shape, fill_value=1, dtype=np.float32, buffer=nl.sbuf)

# we are going to loop as if we're tiling on the partition dimension
iteration_step_size = on_true_tile.shape[0]

# Perform chained operations using an affine loop index for range_start
for i in range(1, 2):
    # Update input values
    on_true_tile[...] = nl.add(on_true_tile, ones)

    # Continue reduction with updated values
    # notice, we still don't have reduce_res specified
    result_sbuf[...] = nisa.range_select(
        on_true_tile=on_true_tile,
        comp_op0=compare_op0,
        comp_op1=compare_op1,
        bound0=bound0_tile,
        bound1=bound1_tile,
        reduce_cmd=nisa.reduce_cmd.reduce,
        reduce_op=np.max,
        # we can also use index expressions for setting the start of the range
        range_start=range_start + (i * iteration_step_size),
        on_false_value=nl.fp32.min
    )

range_start = range_start + (2 * iteration_step_size)

##################################################################
# Example 2.c: Final iteration, we actually want the results to
# return to the user so we pass reduce_res argument so the
# reduction will be written from the accumulation
# register to reduce_res_tile
##################################################################
on_true_tile[...] = nl.add(on_true_tile, ones)

result_sbuf[...] = nisa.range_select(
    on_true_tile=on_true_tile,
    comp_op0=compare_op0,
    comp_op1=compare_op1,
    bound0=bound0_tile,
    bound1=bound1_tile,
    reduce_cmd=nisa.reduce_cmd.reduce,
    reduce_res=reduce_res_sbuf[...],
    reduce_op=np.max,
    range_start=range_start,
    on_false_value=nl.fp32.min
)

nl.store(select_res[...], value=result_sbuf[...])
nl.store(reduce_result[...], value=reduce_res_sbuf[...])
```
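Conceptually, the chained calls in Example 2 keep one running max per partition across all three range_select calls. Only the final call, the one given reduce_res, writes that max out. This is a hypothetical host-side sketch of that behaviour, under the following assumptions:

- The reduction is a max, with the reset/continue semantics described in the Example 2 comments.
- `AccumulationRegisterModel` is illustrative only and is not an NKI API.
- Each array passed in stands for the select output of one call.

```python
import numpy as np


class AccumulationRegisterModel:
    """Hypothetical per-partition accumulator mirroring reset_reduce/reduce."""

    def __init__(self, num_partitions):
        # Assumed initial state for this model only.
        self.value = np.full((num_partitions, 1), -np.inf, dtype=np.float32)

    def reset_reduce(self, partial):
        # Discard previous contents and start from this call's row max.
        self.value = partial.max(axis=1, keepdims=True)

    def reduce(self, partial):
        # Fold this call's row max into the running value.
        self.value = np.maximum(self.value, partial.max(axis=1, keepdims=True))


tile = np.array([[1, 5, 2],
                 [7, 0, 3]], dtype=np.float32)
reg = AccumulationRegisterModel(num_partitions=2)

reg.reset_reduce(tile)   # like Example 2.a
tile = tile + 1          # like adding ones in Example 2.b
reg.reduce(tile)
tile = tile + 1          # like Example 2.c
reg.reduce(tile)
reduce_res = reg.value.copy()  # the final call is the one that exposes the result
```

Because every step adds one to the inputs, the last call contributes the largest values, and `reduce_res` ends up as `[[7.], [9.]]`.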