This document is relevant for: Inf2, Trn1, Trn1n

nki.language

Memory operations

load

Load a tensor from device memory (HBM) into on-chip memory (SBUF).

store

Store into a tensor on device memory (HBM) from on-chip memory (SBUF).

load_transpose2d

Load a tensor from device memory (HBM) and 2D-transpose the data before storing into on-chip memory (SBUF).

atomic_rmw

Perform an atomic read-modify-write operation on HBM data: dst = op(dst, value).

copy

Create a copy of the src tile.
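The `dst = op(dst, value)` update described for `atomic_rmw` and the layout change described for `load_transpose2d` can be modeled in plain Python. This is a host-side sketch of the semantics only, not NKI code: the helper names `rmw` and `transpose2d` are hypothetical, Python lists stand in for HBM and SBUF, and atomicity is not modeled.

```python
import operator

def rmw(dst, index, value, op=operator.add):
    """Model of a read-modify-write: dst[index] = op(dst[index], value)."""
    dst[index] = op(dst[index], value)
    return dst

def transpose2d(rows):
    """Return the 2D transpose of a list of equal-length rows."""
    return [list(col) for col in zip(*rows)]

# A list standing in for a small HBM tensor.
hbm = [10, 20, 30]
rmw(hbm, 1, 5)  # adds 5 to element 1 in place

# A 2x3 source becomes a 3x2 tile after the transpose.
tile = transpose2d([[1, 2, 3], [4, 5, 6]])
```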

Creation operations

ndarray

Create a new tensor of given shape and dtype on the specified buffer.

zeros

Create a new tensor of given shape and dtype on the specified buffer, filled with zeros.

zeros_like

Create a new tensor of zeros with the same shape and type as a given tensor.

ones

Create a new tensor of given shape and dtype on the specified buffer, filled with ones.

full

Create a new tensor of given shape and dtype on the specified buffer, filled with initial value.

rand

Generate a tile of given shape and dtype, filled with random values that are sampled from a uniform distribution between 0 and 1.

random_seed

Set a user-specified seed for the hardware random number generator.

shared_constant

Create a new tensor filled with the data specified by data array.

Math operations

add

Add the inputs, element-wise.

subtract

Subtract the inputs, element-wise.

multiply

Multiply the inputs, element-wise.

divide

Divide the inputs, element-wise.

power

Elements of x raised to powers of y, element-wise.

maximum

Maximum of the inputs, element-wise.

minimum

Minimum of the inputs, element-wise.

max

Maximum of elements along the specified axis (or axes) of the input.

min

Minimum of elements along the specified axis (or axes) of the input.

mean

Arithmetic mean along the specified axis (or axes) of the input.

var

Variance along the specified axis (or axes) of the input.

sum

Sum of elements along the specified axis (or axes) of the input.

prod

Product of elements along the specified axis (or axes) of the input.

all

Whether all elements along the specified axis (or axes) evaluate to True.

abs

Absolute value of the input, element-wise.

negative

Numerical negative of the input, element-wise.

sign

Sign of the numbers of the input, element-wise.

trunc

Truncated value of the input, element-wise.

floor

Floor of the input, element-wise.

ceil

Ceiling of the input, element-wise.

exp

Exponential of the input, element-wise.

log

Natural logarithm of the input, element-wise.

cos

Cosine of the input, element-wise.

sin

Sine of the input, element-wise.

tanh

Hyperbolic tangent of the input, element-wise.

arctan

Inverse tangent of the input, element-wise.

sqrt

Non-negative square-root of the input, element-wise.

rsqrt

Reciprocal of the square-root of the input, element-wise.

sigmoid

Logistic sigmoid activation function on the input, element-wise.

relu

Rectified Linear Unit activation function on the input, element-wise.

gelu

Gaussian Error Linear Unit activation function on the input, element-wise.

gelu_dx

Derivative of Gaussian Error Linear Unit (gelu) on the input, element-wise.

gelu_apprx_tanh

Gaussian Error Linear Unit activation function on the input, element-wise, with tanh approximation.

erf

Error function of the input, element-wise.

erf_dx

Derivative of the Error function (erf) on the input, element-wise.

softplus

Softplus activation function on the input, element-wise.

mish

Mish activation function on the input, element-wise.

square

Square of the input, element-wise.

softmax

Softmax activation function on the input, normalized along the specified axis.

rms_norm

Apply Root Mean Square Layer Normalization.

dropout

Randomly zeroes some of the elements of the input tile given a probability rate.

matmul

x @ y matrix multiplication of x and y.

transpose

Transposes a 2D tile between its partition and free dimension.
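Several of the activation and normalization entries above are easier to compare with their textbook formulas written out. The sketch below uses only Python's `math` module and the standard definitions; it is not NKI code, and the hardware implementations may differ in precision, rounding, and default parameters (the `eps` default here is an assumption).

```python
import math

def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_apprx_tanh(x):
    """GELU with the common tanh approximation."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

def softplus(x):
    """log(1 + exp(x)); this naive form overflows for very large x."""
    return math.log1p(math.exp(x))

def mish(x):
    """x * tanh(softplus(x))."""
    return x * math.tanh(softplus(x))

def softmax(row):
    """Softmax over one row, shifted by the row maximum for stability."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def rms_norm(row, weight, eps=1e-6):
    """Scale each element by weight / sqrt(mean(row^2) + eps)."""
    rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
    return [w * v / rms for v, w in zip(row, weight)]

probs = softmax([1.0, 2.0, 3.0])
normed = rms_norm([3.0, 4.0], [1.0, 1.0])
```

Note that softmax and RMS normalization reduce over a whole row before scaling, so unlike the other functions here they cannot be evaluated one element at a time.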

Bitwise operations

bitwise_and

Bit-wise AND of the two inputs, element-wise.

bitwise_or

Bit-wise OR of the two inputs, element-wise.

bitwise_xor

Bit-wise XOR of the two inputs, element-wise.

invert

Bit-wise NOT of the input, element-wise.
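As a quick reference for the element-wise bitwise semantics listed above, here is a plain-Python model over lists of integers. It is a sketch, not NKI code: the function `invert_u8` is a hypothetical helper that assumes 8-bit unsigned elements, since the result of a bit-wise NOT depends on the element width.

```python
def bitwise_and(x, y):
    return [a & b for a, b in zip(x, y)]

def bitwise_or(x, y):
    return [a | b for a, b in zip(x, y)]

def bitwise_xor(x, y):
    return [a ^ b for a, b in zip(x, y)]

def invert_u8(x):
    """Bit-wise NOT, masked to 8 bits to model unsigned 8-bit elements."""
    return [~a & 0xFF for a in x]

x = [0b1100, 0b1010]
y = [0b1010, 0b0110]

and_out = bitwise_and(x, y)
or_out = bitwise_or(x, y)
xor_out = bitwise_xor(x, y)
not_out = invert_u8(x)
```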

Logical operations

equal

Element-wise boolean result of x == y.

not_equal

Element-wise boolean result of x != y.

greater

Element-wise boolean result of x > y.

greater_equal

Element-wise boolean result of x >= y.

less

Element-wise boolean result of x < y.

less_equal

Element-wise boolean result of x <= y.

logical_and

Element-wise boolean result of x AND y.

logical_or

Element-wise boolean result of x OR y.

logical_xor

Element-wise boolean result of x XOR y.

logical_not

Element-wise boolean result of NOT x.

Tensor manipulation operations

arange

Return contiguous values within a given interval, used for indexing a tensor to define a tile.

mgrid

Same as NumPy mgrid: "An instance which returns a dense (or fleshed out) mesh-grid when indexed, so that each returned argument has the same shape."

expand_dims

Expand the shape of a tile.
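The "dense mesh-grid" wording for `mgrid` is clearer with a small model. The sketch below builds, in plain Python, the two index grids that a 2D mesh-grid of shape (m, n) would contain, then uses them to pick a tile out of a larger tensor. The helper names `mgrid2d` and `gather` are hypothetical, and nested lists stand in for tensors; this is not NKI code.

```python
def mgrid2d(m, n):
    """Return (rows, cols): two m-by-n grids of row and column indices."""
    rows = [[i] * n for i in range(m)]
    cols = [list(range(n)) for _ in range(m)]
    return rows, cols

def gather(tensor, rows, cols):
    """Pick tensor[r][c] for every (r, c) pair in the index grids."""
    return [
        [tensor[r][c] for r, c in zip(row_idx, col_idx)]
        for row_idx, col_idx in zip(rows, cols)
    ]

rows, cols = mgrid2d(2, 3)

# A 3x4 tensor; the grids select its top-left 2x3 tile.
tensor = [[10 * r + c for c in range(4)] for r in range(3)]
tile = gather(tensor, rows, cols)
```

Both grids have the same shape as the tile they address, which is what "each returned argument has the same shape" refers to.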

Sorting/Searching operations

where

Return elements chosen from x or y depending on condition.
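The selection rule for `where` can be sketched in plain Python: each output element comes from `x` when the matching condition element is true and from `y` otherwise. This is a model over flat lists, not NKI code, and the comparison used to build the mask stands in for an element-wise `greater`.

```python
def where(condition, x, y):
    """Element-wise select: x[i] if condition[i] else y[i]."""
    return [a if c else b for c, a, b in zip(condition, x, y)]

x = [1, -2, 3, -4]
mask = [v > 0 for v in x]          # element-wise x > 0
clamped = where(mask, x, [0] * len(x))
```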

Collective communication operations

all_reduce

Apply reduce operation over multiple SPMD programs.

Iterators

static_range

Create a sequence of numbers for use as loop iterators in NKI, resulting in a fully unrolled loop.

affine_range

Create a sequence of numbers for use as parallel loop iterators in NKI.

sequential_range

Create a sequence of numbers for use as sequential loop iterators in NKI.
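The distinction between parallel and sequential loop iterators comes down to whether one iteration depends on the result of another. The sketch below illustrates the two loop shapes with Python's built-in `range` as a stand-in; it is not NKI code, and mapping the first loop to `affine_range` and the second to `sequential_range` is an interpretation of the descriptions above rather than something this page states.

```python
data = [1, 2, 3, 4]

# Independent iterations: each output depends only on its own input,
# so the loop gives the same result in any order (a parallel-style loop).
scaled = [0] * len(data)
for i in range(len(data)):
    scaled[i] = 2 * data[i]

scaled_reversed_order = [0] * len(data)
for i in reversed(range(len(data))):
    scaled_reversed_order[i] = 2 * data[i]

# Loop-carried dependency: iteration i needs the running total from
# iteration i - 1, so the order matters (a sequential-style loop).
prefix = [0] * len(data)
running = 0
for i in range(len(data)):
    running += data[i]
    prefix[i] = running
```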

Memory Hierarchy

par_dim

Mark a dimension explicitly as a partition dimension.

psum

PSUM - Only visible to each individual kernel instance in the SPMD grid

sbuf

State Buffer - Only visible to each individual kernel instance in the SPMD grid

hbm

HBM - Alias of private_hbm

private_hbm

HBM - Only visible to each individual kernel instance in the SPMD grid

shared_hbm

Shared HBM - Visible to all kernel instances in the SPMD grid

Others

program_id

Index of the current SPMD program along the given axis in the launch grid.

num_programs

Number of SPMD programs along the given axes in the launch grid.

program_ndim

Number of dimensions in the SPMD launch grid.

device_print

Print a message with a string prefix followed by the value of a tile x.

loop_reduce

Apply reduce operation over a loop.
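To show how a program index and a program count are typically combined in an SPMD launch, here is a host-side simulation of a one-dimensional grid. The functions `kernel` and `launch` are hypothetical stand-ins written for this sketch; the `pid` and `nprog` arguments play the roles of the values that `program_id` and `num_programs` would return, and the programs run one after another here rather than in parallel.

```python
def kernel(pid, nprog, data, out):
    """One program instance: square its own contiguous chunk of data."""
    chunk = len(data) // nprog       # assumes len(data) divides evenly
    start = pid * chunk
    for i in range(start, start + chunk):
        out[i] = data[i] * data[i]

def launch(grid_size, data):
    """Run every program in a 1D grid, sequentially, on the host."""
    out = [None] * len(data)
    for pid in range(grid_size):
        kernel(pid, grid_size, data, out)
    return out

result = launch(4, list(range(8)))
```

Each program writes a disjoint slice of the output, so the result does not depend on the order in which the programs run.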

Data Types

tfloat32

32-bit custom floating-point number

bfloat16

16-bit custom floating-point number

float8_e4m3

8-bit custom floating-point number
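The precision cost of `bfloat16` can be seen with a short standard-library sketch. bfloat16 keeps the sign bit and all 8 exponent bits of an IEEE float32 but only the top 7 of its 23 mantissa bits, so a value can be modeled by zeroing the low 16 bits of its float32 encoding. The helper `to_bfloat16` is hypothetical, and it truncates for simplicity; real hardware conversions commonly round to nearest instead.

```python
import math
import struct

def to_bfloat16(x):
    """Model a float32 -> bfloat16 conversion by truncating to 16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    bits &= 0xFFFF0000                                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

one = to_bfloat16(1.0)        # exactly representable
pi_bf16 = to_bfloat16(math.pi)
```

With truncation, pi becomes 3.140625, which shows that roughly two to three significant decimal digits survive.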

Constants

tile_size

Tile size constants.
