This document is relevant for: Inf2, Trn1, Trn2
nki.isa.max8
- nki.isa.max8(*, src, mask=None, dtype=None, **kwargs)
Find the 8 largest values in each partition of the source tile.
This instruction reads the input elements, converts them to fp32 internally, and outputs the 8 largest values of each partition in descending order. By default, the output has the same data type as the input tile.
The source tile can be up to 5-dimensional, while the output tile is always 2-dimensional. The number of elements read per partition must be between 8 and 16,384 inclusive. The output will always contain exactly 8 elements per partition. The source and output must have the same partition dimension size:
source: [par_dim, …]
output: [par_dim, 8]
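To make these semantics concrete, here is a minimal NumPy reference model. It is a sketch under stated assumptions, not the hardware implementation: `max8_reference` is a hypothetical name, duplicate values are assumed to be kept (each occurrence counts toward the 8), and NaN handling and the `mask` parameter are not modeled.

```python
import numpy as np


def max8_reference(src, dtype=None):
    """Hypothetical NumPy model of the documented nki.isa.max8 behavior.

    src: array of shape [par_dim, ...] with at most 5 dimensions.
    Returns an array of shape [par_dim, 8] holding each partition's
    8 largest values in descending order.
    """
    src = np.asarray(src)
    if src.ndim > 5:
        raise ValueError("source tile must have at most 5 dimensions")
    par_dim = src.shape[0]
    # Flatten the free dimensions so each row holds one partition's
    # elements, and convert to fp32 as the instruction does internally.
    flat = src.reshape(par_dim, -1).astype(np.float32)
    n = flat.shape[1]
    if not 8 <= n <= 16384:
        raise ValueError(f"expected 8..16384 elements per partition, got {n}")
    # Sort each row ascending, keep the last 8, then reverse to descending.
    top8 = np.sort(flat, axis=1)[:, -8:][:, ::-1]
    # Output dtype defaults to the input dtype unless overridden.
    out_dtype = src.dtype if dtype is None else dtype
    return top8.astype(out_dtype)
```

For example, a source tile of shape [128, 4, 64] is treated as 128 partitions of 256 elements each, and the result has shape [128, 8].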
Estimated instruction cost: $N$ engine cycles, where $N$ is the number of elements per partition in the source tile.
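The cost formula depends only on the number of elements per partition, so it can be computed from the source shape alone. The helper below is a hypothetical sketch of that arithmetic (the name `estimated_max8_cycles` is illustrative, not part of NKI); it also applies the 8 to 16,384 element limit on the elements read per partition.

```python
from math import prod


def estimated_max8_cycles(src_shape):
    """Hypothetical helper: estimated engine cycles for one max8 call.

    src_shape is [par_dim, ...]; N is the product of the free
    (non-partition) dimensions, i.e. the elements read per partition.
    """
    if not 1 <= len(src_shape) <= 5:
        raise ValueError("source tile must have 1 to 5 dimensions")
    n = prod(src_shape[1:])
    if not 8 <= n <= 16384:
        raise ValueError(f"expected 8..16384 elements per partition, got {n}")
    # The estimate is N cycles; par_dim does not appear in the formula.
    return n


print(estimated_max8_cycles((32, 128)))     # 128
print(estimated_max8_cycles((128, 4, 64)))  # 256
```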
- Parameters:
src – the source tile to find maximum values from
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
dtype – (optional) data type to cast the output to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tile.
- Returns:
a 2D tile containing the 8 largest values per partition in descending order with shape [par_dim, 8]
Example:
```python
import neuronxcc.nki as nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl


@nki.jit
def max8_example_kernel():
    ##################################################################
    # Example 1: Generate tile a of 32 * 128 random floating point values
    # and get the 8 largest values in each row:
    ##################################################################
    expr_a = nl.rand((32, 128))
    a = nisa.max8(src=expr_a)
    # Store the [32, 8] result to a shared HBM tensor and return it
    a_tensor = nl.ndarray([32, 8], dtype=nl.float32, buffer=nl.shared_hbm)
    nl.store(a_tensor, value=a)
    return a_tensor
```