This document is relevant for: Inf2, Trn1, Trn2

nki.isa.max8

nki.isa.max8(*, src, mask=None, dtype=None, **kwargs)

Find the 8 largest values in each partition of the source tile.

This instruction reads the input elements, converts them to fp32 internally, and outputs the 8 largest values of each partition in descending order. By default, the output has the same dtype as the input tile.

The source tile can be up to 5-dimensional, while the output tile is always 2-dimensional. The number of elements read per partition must be between 8 and 16,384 inclusive. The output will always contain exactly 8 elements per partition. The source and output must have the same partition dimension size:

  • source: [par_dim, …]

  • output: [par_dim, 8]
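The semantics described above can be modeled on the host with a short pure-Python reference, which may help when checking kernel results off-device. This is a sketch under stated assumptions, not the hardware implementation: the function name `max8_reference` is hypothetical, the tile is represented as one Python list per partition with the free dimensions already flattened, and the internal fp32 conversion is emulated by round-tripping each value through `struct`.

```python
import heapq
import struct


def _to_fp32(x: float) -> float:
    # Round a Python float (fp64) to the nearest fp32 value, emulating the
    # internal fp32 conversion. struct raises OverflowError for values outside
    # the fp32 range; how the hardware handles such inputs is not modeled here.
    return struct.unpack("f", struct.pack("f", x))[0]


def max8_reference(tile: list[list[float]]) -> list[list[float]]:
    """Hypothetical host-side model of nki.isa.max8.

    `tile` holds one list per partition containing that partition's
    elements. Returns a [par_dim][8] nested list in which each row holds
    that partition's 8 largest values in descending order.
    """
    out = []
    for p, row in enumerate(tile):
        n = len(row)
        # Documented constraint: 8 <= elements per partition <= 16,384.
        if not 8 <= n <= 16_384:
            raise ValueError(f"partition {p} has {n} elements; need 8..16384")
        # heapq.nlargest returns the k largest items, largest first.
        out.append(heapq.nlargest(8, (_to_fp32(v) for v in row)))
    return out


# One partition holding 0..15 yields its top 8 values, largest first.
print(max8_reference([list(range(16))]))
# [[15.0, 14.0, 13.0, 12.0, 11.0, 10.0, 9.0, 8.0]]
```

Using `heapq.nlargest` keeps the sketch short and returns values already in descending order; it says nothing about how the engine actually selects the maxima.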

Estimated instruction cost:

N engine cycles, where:

  • N is the number of elements per partition in the source tile
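Because the source tile can have up to four free dimensions after the partition dimension, N is the product of those free dimensions. The arithmetic can be sketched as follows; the helper name `max8_estimated_cycles` is hypothetical, and it only restates the documented estimate together with the documented shape limits, without modeling any other hardware timing effects.

```python
from math import prod


def max8_estimated_cycles(src_shape: tuple[int, ...]) -> int:
    # Hypothetical helper. src_shape[0] is the partition dimension; the
    # remaining entries are free dimensions. The documented estimate is
    # N engine cycles, with N = elements per partition = product of the
    # free dimensions.
    if not 1 <= len(src_shape) <= 5:
        raise ValueError("the source tile can be at most 5-dimensional")
    n = prod(src_shape[1:])
    # Documented constraint: 8 <= N <= 16,384.
    if not 8 <= n <= 16_384:
        raise ValueError(f"{n} elements per partition; must be 8..16384")
    return n


print(max8_estimated_cycles((32, 128)))        # 128 (the shape used in the example)
print(max8_estimated_cycles((128, 4, 8, 16)))  # 512
```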

Parameters:
  • src – the source tile to find maximum values from

  • mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)

  • dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tile.

Returns:

a 2D tile containing the 8 largest values per partition in descending order with shape [par_dim, 8]

Example:

from neuronxcc import nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl


@nki.jit
def max8_example():
    ##################################################################
    # Example 1: Generate tile a of 32 * 128 random floating point values
    # and get the 8 largest values in each partition (row):
    ##################################################################
    expr_a = nl.rand((32, 128))
    a = nisa.max8(src=expr_a)  # shape [32, 8], descending per partition

    # Store the [32, 8] result to an output tensor in shared HBM
    a_tensor = nl.ndarray([32, 8], dtype=nl.float32, buffer=nl.shared_hbm)
    nl.store(a_tensor, value=a)
    return a_tensor
