This document is relevant for: Inf2, Trn1, Trn2
nki.isa.select_reduce
- nki.isa.select_reduce(*, dst, predicate, on_true, on_false, reduce_res=None, reduce_cmd=reduce_cmd.idle, reduce_op=<function amax>, reverse_pred=False, mask=None, dtype=None, **kwargs)
Selectively copy elements from either on_true or on_false to the destination tile based on a predicate, using the Vector Engine, with optional reduction (max). The operation can be expressed in NumPy as:
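```python
import numpy as np

# Select (a nonzero predicate element counts as True; reverse_pred flips it):
pred = (predicate == 0) if reverse_pred else (predicate != 0)
result = np.where(pred, on_true, on_false)
# With Reduce (one max value per partition):
reduction_result = np.max(result, axis=1, keepdims=True)
```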
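To experiment with these semantics on the host, the expression above can be wrapped into a small NumPy reference function. This is a sketch and not part of the NKI API. `select_reduce_ref` and its `with_reduce` flag are hypothetical names, and treating any nonzero predicate element as True is an assumption that mirrors `np.where`.

```python
import numpy as np


def select_reduce_ref(predicate, on_true, on_false, reverse_pred=False, with_reduce=False):
    """Hypothetical NumPy reference model of select_reduce (not the NKI API)."""
    # Nonzero integer predicate elements count as True; reverse_pred flips them.
    cond = predicate != 0
    if reverse_pred:
        cond = ~cond  # logical NOT on a boolean array
    # on_false may be a scalar or a (P, 1) column; np.where broadcasts it per partition.
    result = np.where(cond, on_true, on_false)
    if with_reduce:
        # One max per partition (row), kept as a (P, 1) column like reduce_res.
        return result, np.max(result, axis=1, keepdims=True)
    return result


pred = np.array([[1, 0, 1], [0, 0, 1]], dtype=np.int8)
vals = np.array([[5, 6, 7], [1, 2, 3]], dtype=np.float32)
dst, red = select_reduce_ref(pred, vals, -np.inf, with_reduce=True)
# dst: [[5, -inf, 7], [-inf, -inf, 3]]   red: [[7], [3]]
```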
Memory constraints:
- Both on_true and predicate are permitted to be in SBUF.
- Either on_true or predicate may be in PSUM, but not both simultaneously.
- The destination dst can be in either SBUF or PSUM.
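The placement rules above can be encoded as a quick sanity check. This is a hypothetical helper (`placement_is_valid` is not an NKI function) that represents buffers as the plain strings `"sbuf"` and `"psum"`. It assumes SBUF and PSUM are the only legal locations for these three operands.

```python
# Hypothetical encoding of the select_reduce memory-placement rules.
ALLOWED = {"sbuf", "psum"}


def placement_is_valid(on_true_buf: str, predicate_buf: str, dst_buf: str) -> bool:
    # Every operand must live in SBUF or PSUM.
    if not {on_true_buf, predicate_buf, dst_buf} <= ALLOWED:
        return False
    # on_true and predicate may not both be in PSUM at the same time.
    return not (on_true_buf == "psum" and predicate_buf == "psum")


print(placement_is_valid("psum", "sbuf", "sbuf"))  # True
print(placement_is_valid("psum", "psum", "sbuf"))  # False
```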
Shape and data type constraints:
- on_true, dst, and predicate must have identical shapes (same number of partitions and elements per partition).
- on_true can be any supported dtype except tfloat32, int32, and uint32.
- on_false must be either a scalar or a vector of shape (on_true.shape[0], 1). If on_false is a scalar, its dtype must be float32.
- predicate dtype can be any supported integer type: int8, uint8, int16, or uint16.
- reduce_res must be a vector of shape (on_true.shape[0], 1), and its dtype must be a float type.
- reduce_op only supports max.
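The shape and dtype rules can also be checked against NumPy stand-ins for the tiles before writing a kernel. The following is a minimal sketch under stated assumptions:
- `check_select_reduce_args` is a hypothetical helper.
- `tfloat32` has no NumPy counterpart, so it is not checked.
- A plain Python float is accepted as a scalar `on_false` on the assumption that it is passed to the instruction as float32.

```python
import numpy as np

# Allowed predicate dtypes, per the constraints above.
PREDICATE_DTYPES = {np.dtype(t) for t in (np.int8, np.uint8, np.int16, np.uint16)}
# on_true exclusions; tfloat32 has no NumPy dtype, so it cannot be checked here.
EXCLUDED_ON_TRUE = {np.dtype(np.int32), np.dtype(np.uint32)}


def check_select_reduce_args(dst, predicate, on_true, on_false, reduce_res=None):
    """Hypothetical checker: return a list of violated constraints (empty if none)."""
    errors = []
    if not (on_true.shape == dst.shape == predicate.shape):
        errors.append("on_true, dst and predicate must have identical shapes")
    if on_true.dtype in EXCLUDED_ON_TRUE:
        errors.append(f"on_true dtype {on_true.dtype} is not supported")
    if predicate.dtype not in PREDICATE_DTYPES:
        errors.append(f"predicate dtype {predicate.dtype} is not an allowed integer type")
    column = (on_true.shape[0], 1)
    if np.ndim(on_false) == 0:
        # Scalar on_false must be float32. Assumption: a plain Python float is
        # passed as float32, so it is accepted; NumPy scalars are checked exactly.
        if isinstance(on_false, np.generic):
            ok = on_false.dtype == np.float32
        else:
            ok = isinstance(on_false, float)
        if not ok:
            errors.append("scalar on_false must be float32")
    elif np.shape(on_false) != column:
        errors.append(f"vector on_false must have shape {column}")
    if reduce_res is not None:
        if reduce_res.shape != column:
            errors.append(f"reduce_res must have shape {column}")
        if not np.issubdtype(reduce_res.dtype, np.floating):
            errors.append("reduce_res must have a float dtype")
    return errors
```

Returning a list rather than raising on the first problem lets a single call report every violated rule at once.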
Behavior:
- Where predicate is True: the corresponding elements from on_true are copied to dst.
- Where predicate is False: the corresponding elements from on_false are copied to dst.
- When reduction is enabled, the max value from each partition of the result is computed and stored in reduce_res.
Accumulator behavior:
The Vector Engine maintains internal accumulator registers that can be controlled via the reduce_cmd parameter:
- nisa.reduce_cmd.reset_reduce: Reset accumulators to -inf, then accumulate the current results.
- nisa.reduce_cmd.reduce: Continue accumulating without resetting (useful for multi-step reductions).
- nisa.reduce_cmd.idle: No accumulation performed (default).
Note
Even when reduce_cmd is set to idle, the accumulator state may still be modified. Always use reset_reduce after any operations that ran in idle mode to ensure consistent behavior.
Note
The accumulator registers are shared with other Vector Engine accumulation instructions, such as nki.isa.range_select.
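The interaction between reduce_cmd values can be illustrated with a toy model of the per-partition accumulators. Here the model runs a multi-step reduction over two column chunks. This is an idealized sketch, not hardware-exact:
- `ReduceCmd` and `MaxAccumulator` are hypothetical names.
- The model assumes reduce_res receives the running accumulated max.
- It treats `idle` as a strict no-op, whereas the note above warns that real hardware may still modify the accumulator state in that mode.

```python
import numpy as np
from enum import Enum


class ReduceCmd(Enum):
    """Hypothetical stand-ins for the nisa.reduce_cmd values."""
    IDLE = "idle"
    RESET_REDUCE = "reset_reduce"
    REDUCE = "reduce"


class MaxAccumulator:
    """Idealized model of per-partition max accumulators (not hardware-exact)."""

    def __init__(self, partitions: int):
        self.acc = np.full((partitions, 1), -np.inf, dtype=np.float32)

    def step(self, result: np.ndarray, cmd: ReduceCmd) -> np.ndarray:
        if cmd is ReduceCmd.RESET_REDUCE:
            # Reset to -inf before accumulating the current result.
            self.acc[:] = -np.inf
        if cmd in (ReduceCmd.RESET_REDUCE, ReduceCmd.REDUCE):
            # Fold this step's per-partition max into the running max.
            self.acc = np.maximum(self.acc, result.max(axis=1, keepdims=True))
        # IDLE is modeled as a no-op; real hardware may still touch the state.
        # Assumption: reduce_res receives the running accumulated max.
        return self.acc.copy()


# Multi-step reduction: reset on the first chunk, continue on the second.
full = np.array([[1, 9, 3, 4, 5, 6], [7, 2, 0, 8, 1, 1]], dtype=np.float32)
acc = MaxAccumulator(partitions=2)
acc.step(full[:, :3], ReduceCmd.RESET_REDUCE)
out = acc.step(full[:, 3:], ReduceCmd.REDUCE)
print(out.ravel())  # [9. 8.]
```

Continuing with reduce after unrelated work would carry over the old maxima in this model. This is the same hazard the reset_reduce guidance above is meant to avoid.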
- Parameters:
  - dst – The destination tile to write the selected values to.
  - predicate – Tile that determines which value to select (on_true or on_false).
  - on_true – Tile to select from when predicate is True.
  - on_false – Value to use when predicate is False; can be a scalar value or a vector tile of shape (on_true.shape[0], 1).
  - reduce_res – (optional) Tile to store reduction results; must have shape (on_true.shape[0], 1).
  - reduce_cmd – (optional) Controls accumulator behavior using nisa.reduce_cmd values; defaults to idle.
  - reduce_op – (optional) Reduction operator to apply (only np.max is supported).
  - reverse_pred – (optional) Reverse the meaning of the predicate condition; defaults to False.
  - mask – (optional) A compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details).
  - dtype – (optional) Data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tile.
Example 1: Basic selection
```python
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl

##################################################################
# Example 1: Basic usage of select_reduce
# Create source data, predicate, and destination tensors
##################################################################
# predicate_data and on_true_data are the kernel's input tensors in HBM

# Create output tensor for result
result_tensor = nl.ndarray(on_true_data.shape, dtype=nl.float32, buffer=nl.hbm)

# Load input data to SBUF
predicate = nl.load(predicate_data[...])
on_true = nl.load(on_true_data[...])

# Create destination tensor
dst = nl.ndarray(on_true_data.shape, dtype=nl.float32, buffer=nl.sbuf)

# Perform select operation - copy from on_true where predicate is true
# and set to fp32.min where predicate is false
nisa.select_reduce(
    dst=dst,
    predicate=predicate,
    on_true=on_true,
    on_false=nl.fp32.min,
)

# Store result to HBM
nl.store(result_tensor, value=dst)
```
Example 2: Selection with reduction
```python
##################################################################
# Example 2: Using select_reduce with reduction
# Perform selection and compute max reduction per partition
##################################################################
# Create output tensors for results
result_tensor = nl.ndarray(on_true_data.shape, dtype=nl.float32, buffer=nl.hbm)
reduce_tensor = nl.ndarray((on_true_data.shape[0], 1), dtype=nl.float32, buffer=nl.hbm)

# Load input data to SBUF
predicate = nl.load(predicate_data)
on_true = nl.load(on_true_data)
on_false = nl.load(on_false_data)

# Create destination tensor
dst = nl.ndarray(on_true_data.shape, dtype=nl.float32, buffer=nl.sbuf)

# Create tensor for reduction results
reduce_res = nl.ndarray((on_true_data.shape[0], 1), dtype=nl.float32, buffer=nl.sbuf)

# Perform select operation with reduction
nisa.select_reduce(
    dst=dst,
    predicate=predicate,
    on_true=on_true,
    on_false=on_false,
    reduce_cmd=nisa.reduce_cmd.reset_reduce,
    reduce_res=reduce_res,
    reduce_op=nl.max
)

# Store results to HBM
nl.store(result_tensor, value=dst)
nl.store(reduce_tensor, value=reduce_res)
```
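When validating Example 2, a host-side NumPy oracle for the expected dst and reduce_res can help. This sketch makes two assumptions: on_false_data has shape (P, 1), and results are compared in float32. `example2_expected` is a hypothetical helper, not part of NKI. The usage shows that a partition whose predicate is entirely False reduces to its own on_false value.

```python
import numpy as np


def example2_expected(predicate_data, on_true_data, on_false_data):
    """Hypothetical host-side oracle for Example 2's dst and reduce_res."""
    # Nonzero predicate selects on_true; the (P, 1) on_false column is
    # broadcast across each partition's free dimension.
    dst = np.where(predicate_data != 0, on_true_data, on_false_data).astype(np.float32)
    # One max per partition, shaped (P, 1) like reduce_res.
    reduce_res = dst.max(axis=1, keepdims=True)
    return dst, reduce_res


pred = np.array([[0, 1, 0], [0, 0, 0]], dtype=np.int8)
on_true = np.array([[3, 4, 5], [6, 7, 8]], dtype=np.float32)
on_false = np.array([[-1], [2]], dtype=np.float32)
dst, reduce_res = example2_expected(pred, on_true, on_false)
# dst rows: [-1, 4, -1] and [2, 2, 2]; reduce_res: [[4], [2]]
```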
Example 3: Selection with reversed predicate
```python
##################################################################
# Example 3: Using select_reduce with reverse_pred option
# Reverse the meaning of the predicate
##################################################################
# Create output tensor for result
result_tensor = nl.ndarray(on_true_data.shape, dtype=nl.float32, buffer=nl.hbm)

# Load input data to SBUF
predicate = nl.load(predicate_data[...])
on_true = nl.load(on_true_data[...])

# Create destination tensor
dst = nl.ndarray(on_true_data.shape, dtype=nl.float32, buffer=nl.sbuf)

# Perform select operation with reverse_pred=True
# This will select on_true where predicate is FALSE
nisa.select_reduce(
    dst=dst,
    predicate=predicate,
    on_true=on_true,
    on_false=nl.fp32.min,
    reverse_pred=True  # Reverse the meaning of the predicate
)

# Store result to HBM
nl.store(result_tensor, value=dst)
```
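The expected outputs of Examples 1 and 3 differ only in how the predicate is interpreted. The sketch below uses hypothetical helper names and assumes nl.fp32.min equals `np.finfo(np.float32).min`, the most negative finite float32. It shows that, with a 0/1 predicate, reverse_pred=True gives the same result as running Example 1 on `1 - predicate`. In other words, reverse_pred avoids building an inverted mask tile in the kernel.

```python
import numpy as np

# Assumption: nl.fp32.min is the most negative finite float32 value.
FP32_MIN = np.finfo(np.float32).min


def example1_expected(predicate_data, on_true_data):
    """Hypothetical oracle for Example 1: keep on_true where predicate != 0."""
    return np.where(predicate_data != 0, on_true_data, FP32_MIN).astype(np.float32)


def example3_expected(predicate_data, on_true_data):
    """Hypothetical oracle for Example 3 (reverse_pred=True): keep on_true where predicate == 0."""
    return np.where(predicate_data == 0, on_true_data, FP32_MIN).astype(np.float32)


pred = np.array([[1, 0], [0, 1]], dtype=np.int8)
vals = np.array([[10, 20], [30, 40]], dtype=np.float32)
# With a 0/1 predicate, reversing it is the same as selecting on 1 - predicate.
assert np.array_equal(example3_expected(pred, vals), example1_expected(1 - pred, vals))
```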