This document is relevant for: Inf2, Trn1, Trn2

nki.isa.select_reduce

nki.isa.select_reduce(*, dst, predicate, on_true, on_false, reduce_res=None, reduce_cmd=reduce_cmd.idle, reduce_op=<function amax>, reverse_pred=False, mask=None, dtype=None, **kwargs)

Selectively copy elements from either on_true or on_false to the destination tile based on a predicate, using the Vector Engine, with an optional max reduction.

The operation can be expressed in NumPy as:

# Select:
predicate = np.logical_not(predicate) if reverse_pred else predicate
result = np.where(predicate, on_true, on_false)

# With Reduce:
reduction_result = np.max(result, axis=1, keepdims=True)
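The expression above can be tried on small arrays. The following is a minimal NumPy sketch rather than NKI code; the helper name select_reduce_reference is hypothetical, and it assumes an integer predicate where any nonzero value counts as True.

```python
import numpy as np

def select_reduce_reference(predicate, on_true, on_false, reverse_pred=False):
    """Hypothetical NumPy model of the select step plus the per-partition max."""
    pred = predicate.astype(bool)  # assumption: nonzero means True
    if reverse_pred:
        pred = np.logical_not(pred)
    # on_false may be a scalar or a (P, 1) column; broadcasting handles both.
    result = np.where(pred, on_true, on_false)
    # One max per partition (row), kept as a (P, 1) column.
    reduction_result = np.max(result, axis=1, keepdims=True)
    return result, reduction_result

predicate = np.array([[1, 0, 1], [0, 0, 1]], dtype=np.uint8)
on_true = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
on_false = np.array([[-1.0], [-2.0]], dtype=np.float32)  # one value per partition

result, reduction_result = select_reduce_reference(predicate, on_true, on_false)
print(result)            # rows: [1, -1, 3] and [-2, -2, 6]
print(reduction_result)  # rows: [3] and [6]
```

Passing reverse_pred=True to the same helper swaps which positions take their value from on_true.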

Memory constraints:

  • Both on_true and predicate are permitted to be in SBUF

  • Either on_true or predicate may be in PSUM, but not both simultaneously

  • The destination dst can be in either SBUF or PSUM

Shape and data type constraints:

  • on_true, dst, and predicate must have identical shapes (same number of partitions and elements per partition)

  • on_true can be any supported dtype except tfloat32, int32, uint32

  • on_false dtype must be float32 if on_false is a scalar

  • on_false must be either a scalar or a vector of shape (on_true.shape[0], 1)

  • predicate dtype can be any of the supported integer types: int8, uint8, int16, uint16

  • reduce_res must be a vector of shape (on_true.shape[0], 1)

  • reduce_res dtype must be a float type

  • reduce_op only supports max
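Some of the shape and dtype constraints above can be checked on NumPy stand-ins before writing a kernel. The sketch below is a hypothetical host-side helper, not part of NKI; the name check_select_reduce_args is an assumption, and it does not model the on_true dtype exclusions or the float32 requirement for a scalar on_false.

```python
import numpy as np

# Predicate dtypes listed in the constraints above.
ALLOWED_PREDICATE_DTYPES = {np.dtype(t) for t in (np.int8, np.uint8, np.int16, np.uint16)}

def check_select_reduce_args(dst, predicate, on_true, on_false, reduce_res=None):
    """Hypothetical helper: return a list of violated constraints (empty if none)."""
    problems = []
    if not (dst.shape == predicate.shape == on_true.shape):
        problems.append("on_true, dst, and predicate must have identical shapes")
    if predicate.dtype not in ALLOWED_PREDICATE_DTYPES:
        problems.append("predicate dtype must be int8, uint8, int16, or uint16")
    column_shape = (on_true.shape[0], 1)
    # A scalar has zero dimensions; anything else must be a (P, 1) column.
    if np.ndim(on_false) != 0 and np.shape(on_false) != column_shape:
        problems.append("on_false must be a scalar or have shape (on_true.shape[0], 1)")
    if reduce_res is not None:
        if reduce_res.shape != column_shape:
            problems.append("reduce_res must have shape (on_true.shape[0], 1)")
        if not np.issubdtype(reduce_res.dtype, np.floating):
            problems.append("reduce_res dtype must be a float type")
    return problems

dst = np.zeros((4, 8), dtype=np.float32)
predicate = np.ones((4, 8), dtype=np.uint8)
on_true = np.ones((4, 8), dtype=np.float32)
print(check_select_reduce_args(dst, predicate, on_true, np.float32(0.0)))  # no violations
```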

Behavior:

  • Where predicate is True: The corresponding elements from on_true are copied to dst

  • Where predicate is False: The corresponding elements from on_false are copied to dst

  • When reduction is enabled, the max value from each partition of the result is computed and stored in reduce_res

Accumulator behavior:

The Vector Engine maintains internal accumulator registers that can be controlled via the reduce_cmd parameter:

  • nisa.reduce_cmd.reset_reduce: Reset accumulators to -inf, then accumulate the current results

  • nisa.reduce_cmd.reduce: Continue accumulating without resetting (useful for multi-step reductions)

  • nisa.reduce_cmd.idle: No accumulation performed (default)
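To illustrate how reset_reduce and reduce combine across several calls, here is a toy NumPy model of a per-partition max accumulator. It is a sketch under stated assumptions, not a description of the hardware: the class MaxAccumulatorModel and the string commands are hypothetical stand-ins for the nisa.reduce_cmd values, and idle is modeled as leaving the accumulator untouched, which the real Vector Engine does not guarantee.

```python
import numpy as np

class MaxAccumulatorModel:
    """Toy model of per-partition max accumulators (hypothetical, not the hardware)."""

    def __init__(self, num_partitions):
        self.acc = np.full((num_partitions, 1), -np.inf, dtype=np.float32)

    def step(self, result, reduce_cmd):
        tile_max = np.max(result, axis=1, keepdims=True)
        if reduce_cmd == "reset_reduce":
            # Reset to -inf, then accumulate the current tile.
            self.acc = np.maximum(np.full_like(self.acc, -np.inf), tile_max)
        elif reduce_cmd == "reduce":
            # Keep the running max from earlier calls.
            self.acc = np.maximum(self.acc, tile_max)
        elif reduce_cmd == "idle":
            pass  # simplification: real hardware may still modify the state
        else:
            raise ValueError(f"unknown reduce_cmd: {reduce_cmd}")
        return self.acc.copy()

tile_a = np.array([[1.0, 5.0], [2.0, 0.0]], dtype=np.float32)
tile_b = np.array([[3.0, 4.0], [7.0, 1.0]], dtype=np.float32)

model = MaxAccumulatorModel(num_partitions=2)
first = model.step(tile_a, "reset_reduce")  # max over tile_a only
second = model.step(tile_b, "reduce")       # running max over tile_a and tile_b
print(first.ravel(), second.ravel())
```

In this model, a multi-step reduction starts with one reset_reduce call and continues with reduce calls, so the final accumulator holds the max over all tiles seen since the reset.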

Note

Even when reduce_cmd is set to idle, the accumulator state may still be modified. Always use reset_reduce after any operations that ran with idle mode to ensure consistent behavior.

Note

The accumulator registers are shared with other Vector Engine accumulation instructions such as nki.isa.range_select.

Parameters:
  • dst – The destination tile to write the selected values to

  • predicate – Tile that determines which value to select (on_true or on_false)

  • on_true – Tile to select from when predicate is True

  • on_false – Value to use when predicate is False, can be a scalar value or a vector tile of (on_true.shape[0], 1)

  • reduce_res – (optional) Tile to store reduction results, must have shape (on_true.shape[0], 1)

  • reduce_cmd – (optional) Control accumulator behavior using nisa.reduce_cmd values, defaults to idle

  • reduce_op – (optional) Reduction operator to apply (only np.max is supported)

  • reverse_pred – (optional) Reverse the meaning of the predicate condition, defaults to False

  • mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)

  • dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tile.

Example 1: Basic selection

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl

##################################################################
# Example 1: Basic usage of select_reduce
# Load the predicate and source data, select, and store the result
# (predicate_data and on_true_data are assumed to be kernel inputs in HBM)
##################################################################
# Create output tensor for result
result_tensor = nl.ndarray(on_true_data.shape, dtype=nl.float32, buffer=nl.hbm)

# Load input data to SBUF
predicate = nl.load(predicate_data[...])
on_true = nl.load(on_true_data[...])

# Create destination tensor
dst = nl.ndarray(on_true_data.shape, dtype=nl.float32, buffer=nl.sbuf)

# Perform select operation - copy from on_true where predicate is true
# and set to fp32.min where predicate is false
nisa.select_reduce(
    dst=dst,
    predicate=predicate,
    on_true=on_true,
    on_false=nl.fp32.min,
)

# Store result to HBM
nl.store(result_tensor, value=dst)

Example 2: Selection with reduction

##################################################################
# Example 2: Using select_reduce with reduction
# Perform selection and compute max reduction per partition
##################################################################
# Create output tensors for results
result_tensor = nl.ndarray(on_true_data.shape, dtype=nl.float32, buffer=nl.hbm)
reduce_tensor = nl.ndarray((on_true_data.shape[0], 1), dtype=nl.float32, buffer=nl.hbm)

# Load input data to SBUF
predicate = nl.load(predicate_data)
on_true = nl.load(on_true_data)
on_false = nl.load(on_false_data)

# Create destination tensor
dst = nl.ndarray(on_true_data.shape, dtype=nl.float32, buffer=nl.sbuf)

# Create tensor for reduction results
reduce_res = nl.ndarray((on_true_data.shape[0], 1), dtype=nl.float32, buffer=nl.sbuf)

# Perform select operation with reduction
nisa.select_reduce(
    dst=dst,
    predicate=predicate,
    on_true=on_true,
    on_false=on_false,
    reduce_cmd=nisa.reduce_cmd.reset_reduce,
    reduce_res=reduce_res,
    reduce_op=nl.max
)

# Store results to HBM
nl.store(result_tensor, value=dst)
nl.store(reduce_tensor, value=reduce_res)

Example 3: Selection with reversed predicate

##################################################################
# Example 3: Using select_reduce with reverse_pred option
# Reverse the meaning of the predicate
##################################################################
# Create output tensor for result
result_tensor = nl.ndarray(on_true_data.shape, dtype=nl.float32, buffer=nl.hbm)

# Load input data to SBUF
predicate = nl.load(predicate_data[...])
on_true = nl.load(on_true_data[...])

# Create destination tensor
dst = nl.ndarray(on_true_data.shape, dtype=nl.float32, buffer=nl.sbuf)

# Perform select operation with reverse_pred=True
# This will select on_true where predicate is FALSE
nisa.select_reduce(
    dst=dst,
    predicate=predicate,
    on_true=on_true,
    on_false=nl.fp32.min,
    reverse_pred=True  # Reverse the meaning of the predicate
)

# Store result to HBM
nl.store(result_tensor, value=dst)
