nki.isa.affine_select
- nki.isa.affine_select(dst, pattern, offset, channel_multiplier, on_true_tile, on_false_value, cmp_op=<ufunc 'equal'>, name=None)
Select elements between an input tile on_true_tile and a scalar value on_false_value according to a boolean predicate tile, using the GpSimd Engine.

The predicate tile is calculated on the fly in the engine by evaluating an affine expression element by element. The affine expression is defined by a pattern, an offset, and a channel_multiplier, similar to nisa.iota. The pattern field is a list of lists of the form [[step_w, num_w], [step_z, num_z], [step_y, num_y], [step_x, num_x]]. When a pattern with fewer than four dimensions is provided, the NKI compiler automatically pads the remaining dimensions with a size of 1.

Given a 4D pattern (padded if needed), the instruction generates a predicate using the following pseudocode:

    num_partitions = dst.shape[0]
    [[step_w, num_w], [step_z, num_z], [step_y, num_y], [step_x, num_x]] = pattern

    for channel_id in range(num_partitions):
        for w in range(num_w):
            for z in range(num_z):
                for y in range(num_y):
                    for x in range(num_x):
                        affine_value = (offset + (channel_id * channel_multiplier)
                                        + (w * step_w) + (z * step_z)
                                        + (y * step_y) + (x * step_x))
                        predicate = cmp_op(affine_value, 0)  # Compare with 0 using cmp_op

                        if predicate:
                            dst[channel_id, w, z, y, x] = on_true_tile[channel_id, w, z, y, x]
                        else:
                            dst[channel_id, w, z, y, x] = on_false_value

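The predicate-generation part of the pseudocode above can be mirrored in plain Python to inspect which affine values a given pattern produces. This is a host-side sketch for reasoning only, not NKI code. The helper names pad_pattern and affine_values are hypothetical, and padding with leading [0, 1] entries is an assumption about where the compiler places the size-1 dimensions. A size-1 dimension only ever has index 0, so neither its position nor its step changes the generated values.

```python
from itertools import product


def pad_pattern(pattern):
    """Pad a pattern of [step, num] pairs to 4D with size-1 dimensions.

    Assumption: padding is added as leading [0, 1] entries. A size-1
    dimension contributes nothing to the affine value wherever it sits.
    """
    if not 1 <= len(pattern) <= 4:
        raise ValueError("pattern must have between 1 and 4 [step, num] pairs")
    padding = [[0, 1] for _ in range(4 - len(pattern))]
    return padding + [list(pair) for pair in pattern]


def affine_values(num_partitions, pattern, offset, channel_multiplier):
    """Return, per partition, the affine values in w/z/y/x loop order."""
    padded = pad_pattern(pattern)
    steps = [step for step, _ in padded]
    ranges = [range(num) for _, num in padded]
    rows = []
    for channel_id in range(num_partitions):
        base = offset + channel_id * channel_multiplier
        # product() varies the last index (x) fastest, like the nested loops.
        rows.append([
            base + sum(i * s for i, s in zip(idx, steps))
            for idx in product(*ranges)
        ])
    return rows


# 2 partitions, a 2D pattern with 2 x 3 elements per partition.
values = affine_values(
    num_partitions=2,
    pattern=[[10, 2], [1, 3]],
    offset=-1,
    channel_multiplier=100,
)
print(values[0])  # [-1, 0, 1, 9, 10, 11]
print(values[1])  # [99, 100, 101, 109, 110, 111]
```

With the default equality comparison against 0, only the second element of partition 0 in this example would be taken from on_true_tile; every other element would receive on_false_value.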
The pseudocode assumes, for simplicity, that dst has the same size as the pattern in every dimension x/y/z/w. However, the instruction allows any sizes in the free dimensions, as long as the number of elements per partition in dst matches the product num_w * num_z * num_y * num_x.

A common use case for affine_select is to apply a causal mask to the attention scores of transformer decoder models.

Memory types.

The output dst tile must be in SBUF. The input on_true_tile must also be in SBUF.

Data types.

The input on_true_tile and the output dst tile can be any valid NKI data type (see Supported Data Types for more information). If the data type of on_true_tile differs from that of dst, the input elements in on_true_tile, if selected, are first cast to FP32 before being converted to the output data type of dst. The on_false_value must be float32, regardless of the input/output tile data types.

Layout.

The partition dimension determines the number of active channels for parallel pattern generation and selection. The input tile on_true_tile, the calculated boolean predicate tile, and the output tile must have the same partition dimension size.

Tile size.

- The partition dimension size of dst and on_true_tile must be the same and must not exceed 128.
- The number of elements per partition of dst and on_true_tile must not exceed the physical size of each SBUF partition.
- The total number of elements in pattern (num_w * num_z * num_y * num_x) must match the number of elements per partition in the dst and on_true_tile tiles.
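The causal-mask use case mentioned earlier can be illustrated with a small host-side emulation. The sketch below is plain Python, not NKI code. It assumes queries are laid out along the partition dimension and keys along the free dimension. It uses operator.ge as a stand-in for a greater-or-equal comparison such as np.greater_equal (the text here only names np.equal as the default), and negative infinity as an illustrative fill value. The helper emulate_affine_select is hypothetical and handles only a 1D pattern.

```python
import operator

MAX_PARTITIONS = 128  # partition-dimension limit stated in the text


def emulate_affine_select(pattern, offset, channel_multiplier,
                          on_true_tile, on_false_value, cmp_op=operator.eq):
    """Emulate affine_select for a 1D pattern [[step_x, num_x]].

    on_true_tile is a list of rows: one row per partition, one entry per
    free-dimension element. Returns the emulated dst as a new list of rows.
    """
    (step_x, num_x), = pattern
    if len(on_true_tile) > MAX_PARTITIONS:
        raise ValueError("partition dimension must not exceed 128")
    if any(len(row) != num_x for row in on_true_tile):
        raise ValueError("elements per partition must equal num_x")
    dst = []
    for channel_id, row in enumerate(on_true_tile):
        out_row = []
        for x, value in enumerate(row):
            affine_value = offset + channel_id * channel_multiplier + x * step_x
            out_row.append(value if cmp_op(affine_value, 0) else on_false_value)
        dst.append(out_row)
    return dst


# 4 queries (partitions) x 4 keys (free dimension) of attention scores.
scores = [[float(10 * q + k) for k in range(4)] for q in range(4)]

# Keep a score when query_index - key_index >= 0, i.e. key <= query.
masked = emulate_affine_select(
    pattern=[[-1, 4]],        # step_x = -1 over 4 keys
    offset=0,
    channel_multiplier=1,     # +1 per query (partition) index
    on_true_tile=scores,
    on_false_value=float("-inf"),
    cmp_op=operator.ge,
)
for row in masked:
    print(row)
```

Each printed row keeps the scores up to and including the diagonal and replaces the rest with the fill value, which is the lower-triangular shape a causal mask requires.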
- Parameters:
dst – the output tile in SBUF to store the selected values
pattern – a list of [step, num] to describe up to 4D tensor sizes and strides for affine expression generation
offset – an int32 offset value to be added to every generated affine value
channel_multiplier – an int32 multiplier to be applied to the channel (partition) ID
on_true_tile – an input tile for selection with a True predicate value
on_false_value – a scalar value for selection with a False predicate value
cmp_op – comparison operator to use for predicate evaluation (default: np.equal)
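To see how offset and cmp_op shape the predicate, the following plain-Python sketch prints the predicate of a 4 x 4 tile as rows of 0/1 characters. It is a host-side illustration under stated assumptions, not NKI code: predicate_rows is a hypothetical helper, and the functions from the operator module stand in for NumPy comparison ufuncs such as np.equal and np.greater_equal.

```python
import operator


def predicate_rows(num_partitions, num_x, step_x, offset,
                   channel_multiplier, cmp_op=operator.eq):
    """Return one string per partition: '1' where the predicate is True."""
    rows = []
    for channel_id in range(num_partitions):
        bits = ""
        for x in range(num_x):
            affine_value = offset + channel_id * channel_multiplier + x * step_x
            bits += "1" if cmp_op(affine_value, 0) else "0"
        rows.append(bits)
    return rows


# affine_value = offset + channel_id - x for every case below.
common = dict(num_partitions=4, num_x=4, step_x=-1, channel_multiplier=1)

# Default equality: True only where x == channel_id.
diagonal = predicate_rows(offset=0, **common)
# Greater-or-equal: True where x <= channel_id.
lower = predicate_rows(offset=0, cmp_op=operator.ge, **common)
# Offset of -1 shifts the boundary: True where x <= channel_id - 1.
shifted = predicate_rows(offset=-1, cmp_op=operator.ge, **common)

print(diagonal)  # ['1000', '0100', '0010', '0001']
print(lower)     # ['1000', '1100', '1110', '1111']
print(shifted)   # ['0000', '1000', '1100', '1110']
```

The default comparison selects a single diagonal, so a triangular mask needs an inequality comparison, and offset moves the boundary of the selected region without changing its slope.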