nki.isa.affine_select

nki.isa.affine_select(dst, pattern, offset, channel_multiplier, on_true_tile, on_false_value, cmp_op=<ufunc 'equal'>, name=None)

Select elements from either an input tile on_true_tile or a scalar value on_false_value according to a boolean predicate tile, using the GpSimd Engine.

The predicate tile is computed on the fly in the engine by evaluating an affine expression element by element. The affine expression is defined by pattern, offset, and channel_multiplier, similar to nisa.iota. The pattern field is a list of lists in the form [[step_w, num_w], [step_z, num_z], [step_y, num_y], [step_x, num_x]]. When a pattern with fewer than four dimensions is provided, the NKI compiler automatically pads the remaining dimensions with a size of 1.
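Padding is harmless because a dimension with num = 1 only ever uses loop index 0, so its step never contributes to the affine value. The text does not say where the compiler inserts the padded entries; the minimal sketch below assumes they go in front (outermost) and that the order of the provided entries is preserved. Under that assumption the position of the padding does not change the flattened sequence of affine values. Both pad_pattern and affine_values are hypothetical helpers written for illustration, not NKI functions.

```python
# Hypothetical helpers (not part of NKI) illustrating pattern padding.
def pad_pattern(pattern):
    """Pad to [[step_w, num_w], [step_z, num_z], [step_y, num_y], [step_x, num_x]]."""
    if not 1 <= len(pattern) <= 4:
        raise ValueError("pattern must have between 1 and 4 [step, num] entries")
    # Assumption: padding is prepended. Each padded entry has num=1, so its
    # loop index is always 0 and its step (0 here) never affects the result.
    padding = [[0, 1] for _ in range(4 - len(pattern))]
    return padding + [list(p) for p in pattern]


def affine_values(pattern, offset=0, channel_id=0, channel_multiplier=0):
    """Affine values for a single partition, in flattened w/z/y/x loop order."""
    (sw, nw), (sz, nz), (sy, ny), (sx, nx) = pad_pattern(pattern)
    return [
        offset + channel_id * channel_multiplier + w * sw + z * sz + y * sy + x * sx
        for w in range(nw)
        for z in range(nz)
        for y in range(ny)
        for x in range(nx)
    ]


print(pad_pattern([[1, 3]]))            # [[0, 1], [0, 1], [0, 1], [1, 3]]
print(affine_values([[4, 2], [1, 3]]))  # [0, 1, 2, 4, 5, 6]
```

In the second call, the two provided entries act as y (step 4, size 2) and x (step 1, size 3), producing two rows of three consecutive values.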

Given a 4D pattern (padded if needed), the instruction generates the predicate according to the following pseudocode:

num_partitions = dst.shape[0]
[[step_w, num_w], [step_z, num_z], [step_y, num_y], [step_x, num_x]] = pattern

for channel_id in range(num_partitions):
  for w in range(num_w):
    for z in range(num_z):
      for y in range(num_y):
        for x in range(num_x):
          affine_value = (offset + channel_id * channel_multiplier
                          + w * step_w + z * step_z + y * step_y + x * step_x)

          predicate = cmp_op(affine_value, 0)  # Compare with 0 using cmp_op

          if predicate:
              dst[channel_id, w, z, y, x] = on_true_tile[channel_id, w, z, y, x]
          else:
              dst[channel_id, w, z, y, x] = on_false_value

For simplicity, the pseudocode above assumes that dst has the shape [num_partitions, num_w, num_z, num_y, num_x]. However, the instruction allows any shape in the free dimensions, as long as the number of elements per partition in dst equals num_w * num_z * num_y * num_x.
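The pseudocode can be turned into a runnable NumPy reference model, which may be handy for computing expected outputs on the host. This is a sketch under stated assumptions: affine_select_ref is a hypothetical name, it operates on NumPy arrays rather than NKI tiles, and it assumes the free dimensions of dst are matched against the w/z/y/x loop in row-major (C) order, whereas the text above only requires the element counts to agree.

```python
import numpy as np


def affine_select_ref(dst, pattern, offset, channel_multiplier,
                      on_true_tile, on_false_value, cmp_op=np.equal):
    """Hypothetical NumPy model of the documented semantics (no NKI calls).

    Writes the result into the NumPy array dst in place.
    """
    # Pad to 4D; entries with num=1 contribute nothing to the affine value.
    full = [[0, 1] for _ in range(4 - len(pattern))] + [list(p) for p in pattern]
    steps = np.array([p[0] for p in full], dtype=np.int64)
    nums = [p[1] for p in full]

    num_partitions = dst.shape[0]
    free_size = int(np.prod(dst.shape[1:]))
    if int(np.prod(nums)) != free_size:
        raise ValueError("num_w * num_z * num_y * num_x must equal elements per partition")
    if on_true_tile.shape[0] != num_partitions or on_true_tile[0].size != free_size:
        raise ValueError("on_true_tile must match dst in partitions and elements per partition")

    # Free-dimension term, enumerated with w outermost and x innermost.
    idx = np.indices(nums).reshape(4, -1)      # shape (4, free_size)
    free_term = steps @ idx                     # shape (free_size,)
    channel_term = np.arange(num_partitions) * channel_multiplier
    affine = offset + channel_term[:, None] + free_term[None, :]

    predicate = cmp_op(affine, 0)
    # Selected inputs pass through FP32; on_false_value is a float32 scalar.
    true_vals = on_true_tile.reshape(num_partitions, free_size).astype(np.float32)
    out = np.where(predicate, true_vals, np.float32(on_false_value))
    dst[...] = out.reshape(dst.shape).astype(dst.dtype)
    return dst


dst = np.empty((2, 3), dtype=np.float32)
src = np.arange(6, dtype=np.float32).reshape(2, 3)
# affine_value = -1 + channel_id + x; np.equal keeps elements where it is 0.
affine_select_ref(dst, [[1, 3]], offset=-1, channel_multiplier=1,
                  on_true_tile=src, on_false_value=-1.0)
# dst now holds [[-1, 1, -1], [3, -1, -1]]
```

Vectorizing with np.indices keeps the loop order identical to the nested loops in the pseudocode, since reshaping its output in C order enumerates x fastest and w slowest.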

A common use case for affine_select is applying a causal mask to the attention scores in transformer decoder models.
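For a [queries, keys] score tile with queries on the partition dimension, a causal mask keeps a score only when the query position is at or after the key position. With the tile starting at query q_start and key k_start, that condition is (q_start - k_start) + 1 * channel_id + (-1) * x >= 0, which suggests the parameter choice pattern=[[-1, num_keys]], offset=q_start - k_start, channel_multiplier=1, cmp_op=np.greater_equal. The NumPy sketch below models only the resulting predicate and selection; causal_mask_tile is a hypothetical helper, and the -inf fill value is an illustrative choice.

```python
import numpy as np


def causal_mask_tile(scores, q_start, k_start, fill=-np.inf):
    """Hypothetical helper: models affine_select applying a causal mask.

    Assumed parameter mapping:
      pattern=[[-1, num_keys]], offset=q_start - k_start,
      channel_multiplier=1, cmp_op=np.greater_equal
    """
    num_q, num_k = scores.shape
    offset = q_start - k_start
    channel_id = np.arange(num_q)[:, None]  # partition index -> query row
    x = np.arange(num_k)[None, :]           # free index -> key column
    affine = offset + channel_id * 1 + x * (-1)
    keep = np.greater_equal(affine, 0)      # query position >= key position
    return np.where(keep, scores, np.float32(fill)).astype(np.float32)


# Diagonal tile: everything above the main diagonal is masked.
print(causal_mask_tile(np.ones((4, 4), dtype=np.float32), q_start=0, k_start=0))
# Off-diagonal tile: queries 4..5 against keys 2..5.
print(causal_mask_tile(np.ones((2, 4), dtype=np.float32), q_start=4, k_start=2))
```

Folding the tile position into offset lets the same pattern and channel_multiplier be reused for every tile of a larger attention matrix.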

Memory types.

The output dst tile must be in SBUF. The input on_true_tile must also be in SBUF.

Data types.

The input on_true_tile and the output dst tile can be any valid NKI data type (see Supported Data Types for more information). If the data type of on_true_tile differs from that of dst, the selected elements of on_true_tile are first cast to FP32 and then converted to the data type of dst. The on_false_value must be float32, regardless of the input and output tile data types.
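One consequence of the FP32 intermediate step is that integer inputs wider than FP32's 24-bit significand can lose precision when dst has a different data type. The sketch below reproduces that cast path with NumPy on the host; convert_selected is a hypothetical helper, and it assumes the hardware's int32-to-FP32 conversion agrees with NumPy for these values (for 2**24 + 1, both round-to-nearest-even and truncation give 2**24).

```python
import numpy as np


def convert_selected(values, dst_dtype):
    """Hypothetical helper: models the documented path for selected elements
    when dtypes differ (on_true_tile dtype -> FP32 -> dst dtype)."""
    return np.asarray(values).astype(np.float32).astype(dst_dtype)


src = np.array([7, 16777217], dtype=np.int32)  # 16777217 == 2**24 + 1
via_fp32 = convert_selected(src, np.uint32)
direct = src.astype(np.uint32)

print(via_fp32.tolist())  # [7, 16777216]  (2**24 + 1 is not representable in FP32)
print(direct.tolist())    # [7, 16777217]  (a direct integer cast keeps the value)
```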

Layout.

The partition dimension determines the number of active channels for parallel pattern generation and selection. The input tile on_true_tile, the computed boolean predicate tile, and the output tile dst must have the same partition dimension size.

Tile size.

  • The partition dimension size of dst and on_true_tile must be the same and must not exceed 128.

  • The number of elements per partition of dst and on_true_tile must not exceed the physical size of each SBUF partition.

  • The product of the sizes in pattern (num_w * num_z * num_y * num_x) must equal the number of elements per partition in the dst and on_true_tile tiles.
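The three Tile size rules can be expressed as a small host-side pre-flight check. This is a hypothetical sketch, not an NKI API: check_affine_select_sizes is an illustrative name, tile shapes are plain tuples, and the physical SBUF partition size is passed in by the caller because it depends on the target hardware. The 192 * 1024 bytes used below is an assumed placeholder, not a figure taken from this page.

```python
import math

ASSUMED_PARTITION_BYTES = 192 * 1024  # placeholder; check your target's docs


def check_affine_select_sizes(dst_shape, on_true_shape, pattern,
                              dst_itemsize, src_itemsize, partition_bytes):
    """Hypothetical check; returns the violated Tile size rules (empty if none)."""
    errors = []
    # Rule 1: equal partition sizes, at most 128.
    if dst_shape[0] != on_true_shape[0]:
        errors.append("dst and on_true_tile partition sizes differ")
    if dst_shape[0] > 128:
        errors.append("partition dimension exceeds 128")
    # Rule 2: each tile's per-partition footprint fits in one SBUF partition.
    dst_elems = math.prod(dst_shape[1:])
    src_elems = math.prod(on_true_shape[1:])
    if (dst_elems * dst_itemsize > partition_bytes
            or src_elems * src_itemsize > partition_bytes):
        errors.append("a tile does not fit in one SBUF partition")
    # Rule 3: product of pattern sizes equals elements per partition.
    pattern_elems = math.prod(num for _, num in pattern)
    if pattern_elems != dst_elems or pattern_elems != src_elems:
        errors.append("product of pattern sizes != elements per partition")
    return errors


print(check_affine_select_sizes((128, 512), (128, 512), [[-1, 512]], 4, 4,
                                ASSUMED_PARTITION_BYTES))  # []
print(check_affine_select_sizes((256, 64), (256, 32), [[1, 64]], 4, 4,
                                ASSUMED_PARTITION_BYTES))
```

The second call violates two rules: 256 partitions exceed the limit, and the 32-element on_true_tile does not match the 64 elements described by the pattern.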

Parameters:
  • dst – the output tile in SBUF to store the selected values

  • pattern – a list of up to four [step, num] pairs describing the strides and sizes used to generate the affine expression

  • offset – an int32 offset value to be added to every generated affine value

  • channel_multiplier – an int32 multiplier to be applied to the channel (partition) ID

  • on_true_tile – the input tile whose elements are selected where the predicate is True

  • on_false_value – the float32 scalar value selected where the predicate is False

  • cmp_op – comparison operator to use for predicate evaluation (default: np.equal)