This document is relevant for: Inf2, Trn1, Trn2
nki.isa.affine_select
- nki.isa.affine_select(pred, on_true_tile, on_false_value, *, mask=None, dtype=None, **kwargs)
Select elements between an input tile on_true_tile and a scalar value on_false_value according to a boolean predicate tile, using the GpSimd Engine. The predicate tile is calculated on the fly in the engine by evaluating an affine expression element by element, as indicated in pred.
pred must meet the following requirements:
- It must not depend on any runtime variables that can’t be resolved at compile time.
- It can’t be multiple masks combined using logical operators such as & and |.
For a complex predicate that doesn’t meet the above requirements, consider using nl.where.
The input tile on_true_tile, the calculated boolean predicate tile expressed by pred, and the returned output tile of this instruction must have the same shape. If the predicate value at a given position is True, the corresponding output element takes the element from on_true_tile at the same position. If the predicate value at a given position is False, the corresponding output element takes the value of on_false_value.
A common use case for affine_select is to apply a causal mask on the attention scores for transformer decoder models.
This instruction allows any float or 8-bit/16-bit integer data types for both the input data tile and output tile (see Supported Data Types for more information). The output tile data type is specified using the dtype field. If dtype is not specified, the output data type will be the same as the data type of the input tile on_true_tile. However, the data type of on_false_value must be float32, regardless of the input/output tile data types.
Estimated instruction cost:
GPSIMD_START + N GpSimd Engine cycles, where N is the number of elements per partition in on_true_tile and GPSIMD_START is the instruction startup overhead on GpSimdE, roughly 150 engine cycles.
- Parameters:
pred – an affine expression that defines the boolean predicate
on_true_tile – an input tile for selection with a True predicate value
on_false_value – a scalar value for selection with a False predicate value
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
dtype – (optional) data type to cast the output type to (see Supported Data Types for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see NKI Type Promotion for more information)
- Returns:
an output tile with values selected from either on_true_tile or on_false_value according to the following equation:
output[x] = (pred[x] > 0) ? on_true_tile[x] : on_false_value
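The selection rule above can be emulated on the host in plain Python to sanity-check expected results. The sketch below is not NKI code and does not run on the GpSimd Engine. The names affine_select_ref and pred_fn are hypothetical and introduced only for illustration, and the predicate is modeled as an ordinary Python function of the (row, column) indices rather than an NKI affine expression. The lower-triangular predicate keeps the diagonal, which is an assumption of this sketch.

```python
def affine_select_ref(pred_fn, on_true_tile, on_false_value):
    """Host-side reference for the selection rule (hypothetical helper, not an NKI API).

    pred_fn(ix, iy) stands in for the affine predicate; ix is the row
    (partition) index and iy is the column (free) index.
    """
    # The instruction expects a float32 scalar; a Python float stands in for it here.
    on_false_value = float(on_false_value)
    return [
        [elem if pred_fn(ix, iy) else on_false_value for iy, elem in enumerate(row)]
        for ix, row in enumerate(on_true_tile)
    ]


# Lower-triangular predicate of the kind used for causal masking,
# applied to a small 4x4 tile: keep positions where iy <= ix.
scores = [[float(4 * r + c) for c in range(4)] for r in range(4)]
masked = affine_select_ref(lambda ix, iy: iy <= ix, scores, -9984.0)

for row in masked:
    print(row)
```

Positions where the predicate is False receive on_false_value, and all other positions keep the value from on_true_tile, so the output always has the same shape as the input tile.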
Example:
    import neuronxcc.nki.isa as nisa
    import neuronxcc.nki.language as nl

    ##################################################################
    # Example 1: Take tile a of shape [128, 128] and replace its
    # upper triangle with -9984.0
    ##################################################################
    ix, iy = nl.mgrid[0:128, 0:128]
    a = nl.load(a_tensor[ix, iy])

    b = nisa.affine_select(pred=(iy < ix), on_true_tile=a[ix, iy], on_false_value=-9984.0)

    nl.store(b_tensor[ix, iy], b)
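As a rough worked example of the cost estimate given earlier (GPSIMD_START + N GpSimd Engine cycles), the sketch below plugs in the [128, 128] tile from Example 1. The helper name estimate_affine_select_cycles is hypothetical, the 150-cycle startup figure is the approximate value quoted in this document, and N = 128 assumes the second tile dimension is the per-partition dimension.

```python
GPSIMD_START_CYCLES = 150  # approximate startup overhead quoted in this document


def estimate_affine_select_cycles(elements_per_partition, startup=GPSIMD_START_CYCLES):
    """Rough estimate: GPSIMD_START + N GpSimd Engine cycles (hypothetical helper)."""
    if elements_per_partition < 0:
        raise ValueError("elements_per_partition must be non-negative")
    return startup + elements_per_partition


# Tile shape [128, 128]: assumed to be 128 partitions with 128 elements each.
tile_shape = (128, 128)
cycles = estimate_affine_select_cycles(tile_shape[1])
print(cycles)  # 278
```

Because the startup overhead is only approximate, the result is best read as an order-of-magnitude guide rather than a measured latency.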