This document is relevant for: Inf2, Trn1, Trn2

nki.isa.nc_match_replace8

nki.isa.nc_match_replace8(*, data, vals, imm, dst_idx=None, mask=None, dtype=None, **kwargs)

Replace the first occurrence of each value in vals with imm in data using the Vector engine, and return the replaced tensor. If a dst_idx tile is provided, the indices of the matched values are written to dst_idx.

This instruction reads the input data, replaces the first occurrence of each of the given values (from the vals tensor) with the specified immediate constant, and optionally writes the indices of the matched values to dst_idx. The free dimensions of both data and vals are flattened for the operation; however, the original shapes are preserved, so the replaced output tensor has the shape of data and dst_idx has the shape of vals. The partition dimension defines the parallelization boundary: match, replace, and index generation operations execute independently within each partition.
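A minimal NumPy sketch of the flattening just described. The shapes are illustrative only, and the sketch assumes row-major (C-order) flattening, which is what the NumPy equivalent below uses via reshape.

```python
import numpy as np

# Illustrative shapes: 2 partitions; data has free dims (3, 4) = 12 elements,
# vals has free dims (2, 4) = exactly 8 elements per partition.
P, B, M = 2, 3, 4
data = np.arange(P * B * M, dtype=np.float32).reshape(P, B, M)
vals = np.zeros((P, 2, 4), dtype=np.float32)

# Flatten all free dimensions; axis 0 (the partition axis) is kept as-is.
data_flat = data.reshape(P, -1)  # shape (2, 12)
vals_flat = vals.reshape(P, -1)  # shape (2, 8)

# Under row-major flattening, element (p, b, m) sits at flat position b * M + m.
p, b, m = 1, 2, 3
assert data_flat[p, b * M + m] == data[p, b, m]

# Per-partition results are reshaped back: the replaced tensor to data.shape,
# and the indices (one per value in vals) to vals.shape.
replaced = data_flat.reshape(data.shape)
indices = np.zeros(vals_flat.shape, dtype=np.uint32).reshape(vals.shape)
```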

The data tensor can be up to 5-dimensional, while the vals tensor can be up to 3-dimensional. The vals tensor must have exactly 8 elements per partition. The data tensor must have no more than 16,384 elements per partition. The replaced output will have the same shape as the input data tensor. data and vals must have the same number of partitions. Both input tensors can come from SBUF or PSUM.
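The limits above can be checked on the host before a kernel is built. This is a sketch only: check_match_replace8_shapes is a hypothetical helper (not part of NKI), and both shapes are assumed to list the partition dimension first.

```python
import math

def check_match_replace8_shapes(data_shape, vals_shape):
    """Hypothetical helper: return the documented shape limits that are violated.

    Both shapes include the partition dimension as their first entry.
    An empty list means every check passed.
    """
    errors = []
    if len(data_shape) > 5:
        errors.append("data has more than 5 dimensions")
    if len(vals_shape) > 3:
        errors.append("vals has more than 3 dimensions")
    if data_shape[0] != vals_shape[0]:
        errors.append("data and vals have different partition counts")
    # Elements per partition = product of the free (non-partition) dimensions.
    vals_per_partition = math.prod(vals_shape[1:])
    data_per_partition = math.prod(data_shape[1:])
    if vals_per_partition != 8:
        errors.append("vals must have exactly 8 elements per partition")
    if data_per_partition > 16384:
        errors.append("data has more than 16,384 elements per partition")
    return errors

print(check_match_replace8_shapes((128, 4, 512), (128, 8)))    # []
print(check_match_replace8_shapes((128, 64, 512), (128, 2, 2)))
```

The second call reports two problems: vals holds only 4 elements per partition, and data holds 32,768.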

Behavior is undefined if the vals tensor contains values that are not present in the data tensor.
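Since this case is undefined, it can be worth validating inputs on the host while testing. A hedged NumPy sketch; vals_present is a hypothetical helper. It assumes, following the NumPy equivalent below (which replaces each match before searching for the next value), that a value repeated k times in vals needs at least k occurrences in data. The examples below obtain vals from nisa.max8 on the same tile, so their values are drawn from data.

```python
import numpy as np

def vals_present(data, vals):
    """Hypothetical helper: True if every partition of data contains each
    value of vals at least as many times as it appears in vals."""
    data_2d = data.reshape(data.shape[0], -1)
    vals_2d = vals.reshape(vals.shape[0], -1)
    for p in range(data_2d.shape[0]):
        uniq, needed = np.unique(vals_2d[p], return_counts=True)
        for v, n in zip(uniq, needed):
            if np.count_nonzero(data_2d[p] == v) < n:
                return False
    return True

data = np.array([[1., 2., 3., 4., 5., 6., 7., 8., 9., 9.]], dtype=np.float32)
print(vals_present(data, np.array([[1., 2., 3., 4., 5., 6., 9., 9.]])))   # True
print(vals_present(data, np.array([[1., 2., 3., 4., 5., 6., 7., 10.]])))  # False
print(vals_present(data, np.array([[1., 1., 2., 3., 4., 5., 6., 7.]])))   # False
```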

If provided, a mask is applied to the data tensor.

Estimated instruction cost:

max(MIN_II, N) engine cycles, where:

  • N is the number of elements per partition in the data tensor

  • MIN_II is the minimum instruction initiation interval for small input tiles. MIN_II is roughly 64 engine cycles.
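A worked example of the estimate. Because MIN_II is a minimum initiation interval, it acts as a floor, so this sketch computes max(MIN_II, N). The value 64 is the rough figure quoted above rather than an exact hardware constant, and estimated_cycles is a hypothetical helper.

```python
MIN_II = 64  # "roughly 64 engine cycles" per the text; an approximation

def estimated_cycles(n_per_partition, min_ii=MIN_II):
    """Hypothetical helper: rough engine-cycle estimate for one instruction."""
    # Small tiles are bounded below by the initiation interval; large tiles
    # cost about one cycle per element in each partition.
    return max(min_ii, n_per_partition)

for n in (16, 64, 4096, 16384):
    print(n, estimated_cycles(n))
# 16 -> 64, 64 -> 64, 4096 -> 4096, 16384 -> 16384
```

In other words, tiles with fewer than about 64 elements per partition all cost roughly the same, so very small tiles make poor use of the engine.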

NumPy equivalent:

# Let's assume we work with NumPy, ``data`` and ``vals`` are arrays
# (with shape[0] being the partition axis), and imm is a constant float32 value.

import numpy as np

# Get original shapes
data_shape = data.shape
vals_shape = vals.shape

# Flatten the free dimensions to 2D while preserving the partition dimension.
# copy() keeps the caller's ``data`` unmodified (reshape may return a view).
data_2d = data.reshape(data_shape[0], -1).copy()
vals_2d = vals.reshape(vals_shape[0], -1)

# Initialize output array for indices
indices = np.zeros(vals_2d.shape, dtype=np.uint32)

for i in range(data_2d.shape[0]):
  for j in range(vals_2d.shape[1]):
    val = vals_2d[i, j]
    # Find first occurrence of val in data_2d[i, :]
    matches = np.where(data_2d[i, :] == val)[0]
    if matches.size > 0:
      indices[i, j] = matches[0]  # Take first match
      data_2d[i, matches[0]] = imm

output = data_2d.reshape(data.shape)
indices = indices.reshape(vals.shape) # Computed only if ``dst_idx`` is specified
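The same logic wrapped as a function, with a small usage example. match_replace8_ref is a hypothetical name, and this is a behavioral sketch of the NumPy equivalent above, not the hardware implementation. The example shows one consequence of replacing each match before searching for the next value: a value listed twice in vals matches two successive occurrences in data.

```python
import numpy as np

def match_replace8_ref(data, vals, imm):
    """Hypothetical NumPy reference; returns (replaced_data, indices)."""
    data_2d = data.reshape(data.shape[0], -1).copy()  # flatten free dims
    vals_2d = vals.reshape(vals.shape[0], -1)
    indices = np.zeros(vals_2d.shape, dtype=np.uint32)
    for i in range(data_2d.shape[0]):            # each partition independently
        for j in range(vals_2d.shape[1]):
            matches = np.where(data_2d[i, :] == vals_2d[i, j])[0]
            if matches.size > 0:
                indices[i, j] = matches[0]       # first remaining occurrence
                data_2d[i, matches[0]] = imm     # replace it before the next search
    return data_2d.reshape(data.shape), indices.reshape(vals.shape)

# One partition in which 5.0 appears twice in both data and vals.
data = np.array([[5., 1., 5., 2., 3., 4., 6., 7., 8., 0.]], dtype=np.float32)
vals = np.array([[5., 5., 8., 7., 6., 4., 3., 2.]], dtype=np.float32)
out, idx = match_replace8_ref(data, vals, imm=-1.0)
# idx == [[0, 2, 8, 7, 6, 5, 4, 3]]: the two 5.0 entries hit positions 0 and 2
# out keeps only the unmatched values 1.0 (position 1) and 0.0 (position 9)
```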
Parameters:
  • data – the data tensor to modify

  • dst_idx – (optional) the destination tile to which the flattened indices of the matched values are written

  • vals – tensor containing the 8 values per partition to replace

  • imm – float32 constant to replace matched values with

  • mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)

  • dtype – (optional) data type to cast the output type to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tile.

Returns:

the modified data tensor

Example:

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import neuronxcc.nki.typing as nt


##################################################################
# Example 1: Generate tile a of random floating point values,
# get the 8 largest values in each row, then replace their first
# occurrences with -inf:
##################################################################
N = 4
M = 16
data_tile = nl.rand((N, M))
max_vals = nisa.max8(src=data_tile)

result = nisa.nc_match_replace8(data=data_tile[:, :], vals=max_vals, imm=float('-inf'))
result_tensor = nl.ndarray([N, M], dtype=nl.float32, buffer=nl.shared_hbm)
nl.store(result_tensor, value=result)
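For intuition, the effect of Example 1 can be mimicked in NumPy on the host. In this sketch np.sort stands in for nisa.max8 (assumed here to return the 8 largest values of each row; their order does not affect which positions are replaced), and a seeded NumPy generator replaces nl.rand.

```python
import numpy as np

N, M = 4, 16
rng = np.random.default_rng(0)
data = rng.random((N, M), dtype=np.float32)

# Stand-in for nisa.max8: the 8 largest values in each row.
max_vals = np.sort(data, axis=1)[:, ::-1][:, :8]

result = data.copy()
for i in range(N):
    for v in max_vals[i]:
        k = np.flatnonzero(result[i] == v)[0]  # first remaining occurrence
        result[i, k] = -np.inf

# Each row now holds exactly 8 -inf entries; the survivors are its 8 smallest values.
```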
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import neuronxcc.nki.typing as nt


##################################################################
# Example 2: Read the 8 largest values in each row of the tensor,
# replace the first occurrence with imm, write indices, and return
# the replaced output.
##################################################################
n, m = in_tensor.shape

dst_idx = nl.ndarray((n, 8), dtype=idx_tensor.dtype)

ix, iy = nl.mgrid[0:n, 0:8]

inp_tile: nt.tensor[n, m] = nl.load(in_tensor)
max_vals: nt.tensor[n, 8] = nisa.max8(src=inp_tile)

out_tile = nisa.nc_match_replace8(
  dst_idx=dst_idx[ix, iy], data=inp_tile[:, :], vals=max_vals, imm=imm
)
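The indices written to dst_idx point into the original (pre-replacement) data, so gathering data at those positions reproduces vals. The NumPy sketch below assumes each row's values are distinct and that vals lists the 8 largest in descending order (an assumption about max8's output order, not something stated here); under those assumptions dst_idx coincides with the first 8 positions of a descending argsort.

```python
import numpy as np

# Two partitions, 10 distinct values per row (illustrative data).
inp = np.array([[3., 9., 1., 7., 5., 8., 2., 6., 4., 0.],
                [0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]], dtype=np.float32)

# Assumption: vals holds each row's 8 largest values in descending order.
order = np.argsort(-inp, axis=1, kind="stable")[:, :8]
vals = np.take_along_axis(inp, order, axis=1)

# With distinct values, the first occurrence of vals[i, j] is order[i, j],
# so this is what dst_idx would hold (uint32, as in the NumPy equivalent).
dst_idx = order.astype(np.uint32)

# Gathering the original data at dst_idx reproduces vals.
assert np.array_equal(np.take_along_axis(inp, dst_idx.astype(np.intp), axis=1), vals)
```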
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import neuronxcc.nki.typing as nt


##################################################################
# Example 3: Read the 8 largest values in each row of the tensor,
# after applying the specified mask, replace the first occurrence
# with imm, write indices, and return the replaced output.
##################################################################
n, m = in_tensor.shape

idx_tile = nisa.memset(shape=(n, 8), value=0, dtype=nl.uint32)

ix, iy = nl.mgrid[0:n, 0:m]
inp_tile: nt.tensor[n, m] = nl.load(in_tensor)
max_vals: nt.tensor[n, 8] = nisa.max8(src=inp_tile[ix, iy], mask=(ix < n // 2 and iy < m // 2))

out_tile = nisa.nc_match_replace8(
  dst_idx=idx_tile[:, :],
  data=inp_tile[ix, iy],
  vals=max_vals,
  imm=imm,
  mask=(ix < n // 2 and iy < m // 2),  # mask applies to `data`
)
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import neuronxcc.nki.typing as nt


##################################################################
# Example 4: Read the 8 largest values in each row of the tensor,
# replace the first occurrence with 0.0, write indices, and return 
# the replaced output.
##################################################################
n, b, m = data_tensor.shape

out_tensor = nl.ndarray([n, b, m], dtype=data_tensor.dtype, buffer=nl.hbm)
idx_tensor = nl.ndarray([n, 8], dtype=nl.uint32, buffer=nl.hbm)

imm = 0.0
idx_tile = nisa.memset(shape=(n, 8), value=0, dtype=nl.uint32)
out_tile = nisa.memset(shape=(n, b, m), value=0, dtype=data_tensor.dtype)

iq, ir, iw = nl.mgrid[0:n, 0:b, 0:m]
ip, io = nl.mgrid[0:n, 0:8]

inp_tile = nl.load(data_tensor[iq, ir, iw])
max_vals: nt.tensor[n, 8] = nisa.max8(src=inp_tile)

out_tile[iq, ir, iw] = nisa.nc_match_replace8(
  dst_idx=idx_tile[ip, io],
  data=inp_tile[iq, ir, iw],
  vals=max_vals[ip, io],
  imm=imm,
)
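Because dst_idx holds flattened indices, recovering the (b, m) position of a match in a 3-D tile such as the one in Example 4 takes one np.unravel_index call. A NumPy sketch with illustrative shapes and a made-up flat index, assuming row-major flattening as in the NumPy equivalent.

```python
import numpy as np

# Illustrative 3-D tile: n partitions, free dims (b, m) as in Example 4.
n, b, m = 2, 3, 5
data = np.arange(n * b * m, dtype=np.float32).reshape(n, b, m)

# Suppose partition 1 reports flat index 13 in its dst_idx row (made-up value).
p, flat = 1, 13

# Convert the flat free-dimension index back to (b, m) coordinates.
bi, mi = np.unravel_index(flat, (b, m))  # 13 = 2 * 5 + 3 -> (2, 3)
assert data[p, bi, mi] == data.reshape(n, -1)[p, flat]
```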

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import neuronxcc.nki.typing as nt


##################################################################
# Example 5: Read the 8 largest values in each row of the tensor,
# replace the first occurrence with 0.0 in-place and write indices.
##################################################################
n, b, m = data_tensor.shape

out_tensor = nl.ndarray([n, b, m], dtype=data_tensor.dtype, buffer=nl.hbm)
idx_tensor = nl.ndarray([n, 8], dtype=nl.uint32, buffer=nl.hbm)

imm = 0.0
idx_tile = nisa.memset(shape=(n, 8), value=0, dtype=nl.uint32)

iq, ir, iw = nl.mgrid[0:n, 0:b, 0:m]
ip, io = nl.mgrid[0:n, 0:8]

inp_tile = nl.load(data_tensor[iq, ir, iw])
max_vals: nt.tensor[n, 8] = nisa.max8(src=inp_tile)

inp_tile[iq, ir, iw] = nisa.nc_match_replace8(
  dst_idx=idx_tile[ip, io],
  data=inp_tile[iq, ir, iw],
  vals=max_vals[ip, io],
  imm=imm,
)