This document is relevant for: Trn2, Trn3

nki.language#

The nki.language module provides high-level constructs for writing NKI kernels. It includes tensor creation, indexing, type casting, math operations, and loop constructs that the NKI compiler translates into efficient hardware instructions.

Creation operations#

ndarray

Create a new tensor of given shape and dtype on the specified buffer.

zeros

Create a new tensor of given shape and dtype on the specified buffer, filled with zeros.

ones

Create a new tensor of given shape and dtype on the specified buffer, filled with ones.

full

Create a new tensor of given shape and dtype on the specified buffer, filled with initial value.

zeros_like

Create a new tensor of zeros with the same shape and type as a given tensor.

empty_like

Create a new tensor with the same shape and type as a given tensor.

shared_identity_matrix

Create an identity matrix in SBUF with the specified data type.

rand

Create a new tensor of given shape and dtype on the specified buffer, filled with random values.

random_seed

Set the random seed for random number generation.

Tensor operations#

load

Load a tensor from device memory (HBM) into on-chip memory (SBUF).

load_transpose2d

Load a tensor from device memory (HBM) and 2D-transpose the data before storing into on-chip memory (SBUF).

store

Store into a tensor on device memory (HBM) from on-chip memory (SBUF).

copy

Create a copy of the input tile.

matmul

Matrix multiplication of x and y (x @ y).

transpose

Transposes a 2D tile between its partition and free dimension.

Math operations#

abs

Absolute value of the input, element-wise.

add

Add the inputs, element-wise.

arctan

Inverse tangent of the input, element-wise.

ceil

Ceiling of the input, element-wise.

cos

Cosine of the input, element-wise.

exp

Exponential of the input, element-wise.

floor

Floor of the input, element-wise.

log

Natural logarithm of the input, element-wise.

maximum

Maximum of the inputs, element-wise.

minimum

Minimum of the inputs, element-wise.

multiply

Multiply the inputs, element-wise.

negative

Numerical negative of the input, element-wise.

power

Elements of x raised to powers of y, element-wise.

reciprocal

Reciprocal of the input, element-wise.

rsqrt

Reciprocal of the square-root of the input, element-wise.

sign

Sign of the input, element-wise.

sin

Sine of the input, element-wise.

sqrt

Non-negative square-root of the input, element-wise.

square

Square of the input, element-wise.

subtract

Subtract the inputs, element-wise.

tan

Tangent of the input, element-wise.

tanh

Hyperbolic tangent of the input, element-wise.

trunc

Truncated value of the input, element-wise.
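The rounding entries above (ceil, floor, trunc) differ only on negative non-integers, and rsqrt and sign are easy to misread from a one-line summary. The following is a host-side reference sketch in plain Python, not NKI kernel code. The helper names are hypothetical, and returning 0 for sign(0) is an assumption borrowed from NumPy.

```python
import math

def rsqrt(x: float) -> float:
    """Reciprocal of the square root: 1 / sqrt(x)."""
    return 1.0 / math.sqrt(x)

def sign(x: float) -> float:
    """-1.0, 0.0, or +1.0 depending on the sign of x (0 maps to 0 by assumption)."""
    return float((x > 0) - (x < 0))

def rounding_table(values):
    """Return (value, floor, ceil, trunc) rows so the three can be compared."""
    return [(v, math.floor(v), math.ceil(v), math.trunc(v)) for v in values]

# floor rounds toward -inf, ceil toward +inf, trunc toward zero.
for row in rounding_table([-1.5, -0.5, 0.5, 1.5]):
    print(row)
```

For -1.5 the three functions give -2, -1, and -1 respectively, which is the case where floor and trunc disagree.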

Activation and Backpropagation functions#

relu

ReLU activation, element-wise.

sigmoid

Sigmoid activation, element-wise.

silu

SiLU (Swish) activation, element-wise.

silu_dx

Derivative of SiLU activation, element-wise.

gelu

GELU activation, element-wise.

gelu_dx

Derivative of GELU activation, element-wise.

gelu_apprx_sigmoid

GELU approximation using sigmoid, element-wise.

gelu_apprx_sigmoid_dx

Derivative of sigmoid-approximated GELU, element-wise.

gelu_apprx_tanh

GELU approximation using tanh, element-wise.

mish

Mish activation, element-wise.

softplus

Softplus activation, element-wise.

softmax

Softmax activation function on the input, element-wise.

erf

Error function, element-wise.

erf_dx

Derivative of error function, element-wise.
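The activation entries above name several closely related functions. As a rough scalar reference, the sketch below uses the textbook formulas. The constants 0.044715 and 1.702 are the commonly published approximation constants, and it is an assumption that the NKI implementations match them exactly. The softmax helper assumes normalization along a single row. All helper names are hypothetical and this is plain Python, not kernel code.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gelu(x):
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    """Widely used tanh approximation of GELU (constants are an assumption)."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

def gelu_sigmoid(x):
    """Widely used sigmoid approximation of GELU: x * sigmoid(1.702 * x)."""
    return x * sigmoid(1.702 * x)

def silu(x):
    """SiLU (Swish): x * sigmoid(x)."""
    return x * sigmoid(x)

def silu_dx(x):
    """Analytic derivative of SiLU with respect to x."""
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))

def erf_dx(x):
    """Derivative of erf: 2 / sqrt(pi) * exp(-x^2)."""
    return 2.0 / math.sqrt(math.pi) * math.exp(-x * x)

def softmax(row):
    """Softmax over one row, subtracting the max for numerical stability."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

print(gelu(1.0), gelu_tanh(1.0), gelu_sigmoid(1.0))
```

The two approximations trade accuracy for cheaper arithmetic; near x = 1 the tanh form is noticeably closer to the exact value than the sigmoid form.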

Normalization and Regularization functions#

dropout

Randomly zeroes some of the elements of the input tile given a probability rate.

rms_norm

Apply Root Mean Square Layer Normalization.
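For orientation, RMS normalization divides each element by the root mean square of its row and then applies a per-element scale. The sketch below is a plain-Python reference for one row. The `weight` and `eps` parameters, and adding `eps` inside the square root, are assumptions drawn from the usual formulation; the actual nl.rms_norm signature is not given on this page.

```python
import math

def rms_norm(row, weight, eps=1e-6):
    """Reference RMSNorm for one row: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(v * v for v in row) / len(row)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(row, weight)]

out = rms_norm([3.0, 4.0], weight=[1.0, 1.0], eps=0.0)
print(out)
```

With unit weights and eps set to 0, the mean of the squared outputs is 1, which is a quick sanity check for any implementation.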

Reduction operations#

all

Whether all elements along the specified axis (or axes) evaluate to True.

max

Maximum of elements along the specified axis (or axes) of the input.

mean

Arithmetic mean along the specified axis (or axes) of the input.

min

Minimum of elements along the specified axis (or axes) of the input.

prod

Product of elements along the specified axis (or axes) of the input.

sum

Sum of elements along the specified axis (or axes) of the input.

var

Variance along the specified axis (or axes) of the input.
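To make "along the specified axis" concrete, here is a small plain-Python sketch that reduces a 2D list along axis 0 or axis 1. The helper names are hypothetical, and the variance helper divides by N (population variance), which is an assumption rather than something this page states.

```python
def reduce_axis(tile, axis, op):
    """Reduce a 2D list: axis=1 collapses each row, axis=0 collapses each column."""
    if axis == 1:
        return [op(row) for row in tile]
    if axis == 0:
        return [op(col) for col in zip(*tile)]
    raise ValueError("axis must be 0 or 1")

def mean(values):
    values = list(values)
    return sum(values) / len(values)

def var(values):
    """Population variance (divide by N); the divisor is an assumption."""
    values = list(values)
    mu = mean(values)
    return sum((v - mu) ** 2 for v in values) / len(values)

tile = [[1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0]]

print(reduce_axis(tile, 1, sum))   # one value per row
print(reduce_axis(tile, 0, max))   # one value per column
print(reduce_axis(tile, 1, mean))
```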

Comparison operations#

equal

Return (x == y) element-wise.

not_equal

Return (x != y) element-wise.

less

Return (x < y) element-wise.

less_equal

Return (x <= y) element-wise.

greater

Return (x > y) element-wise.

greater_equal

Return (x >= y) element-wise.

Logical operations#

logical_and

Compute the logical AND of two tiles element-wise.

logical_or

Compute the logical OR of two tiles element-wise.

logical_xor

Compute the logical XOR of two tiles element-wise.

logical_not

Compute the logical NOT element-wise.

Bitwise operations#

bitwise_and

Compute the bitwise AND of two tiles element-wise.

bitwise_or

Compute the bitwise OR of two tiles element-wise.

bitwise_xor

Compute the bitwise XOR of two tiles element-wise.

invert

Compute the bitwise NOT element-wise.

left_shift

Left shift the bits of x by y positions element-wise.

right_shift

Right shift the bits of x by y positions element-wise.
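Python integers are unbounded, so reproducing fixed-width bitwise behavior on the host needs an explicit mask. The sketch below models unsigned values of a given width. Dropping bits shifted past the top, and treating right shift as a plain logical shift on unsigned inputs, are assumptions; this page does not specify overflow or signed-shift behavior. Helper names are hypothetical.

```python
def invert(x, bits=8):
    """Bitwise NOT within a fixed unsigned width."""
    mask = (1 << bits) - 1
    return ~x & mask

def left_shift(x, y, bits=8):
    """Shift left by y, discarding bits that leave the fixed width (assumed)."""
    mask = (1 << bits) - 1
    return (x << y) & mask

def right_shift(x, y):
    """Shift right by y for a non-negative (unsigned) value."""
    return x >> y

print(bin(invert(0b00001111)))
print(bin(left_shift(0b10000001, 1)))
print(right_shift(0b10000000, 3))
```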

Tensor manipulation operations#

broadcast_to

Broadcast a tile to a new shape following numpy broadcasting rules.

ds

Create a dynamic slice for tensor indexing.

expand_dims

Expand the shape of a tile.
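The NumPy broadcasting rule referenced above can be stated in a few lines: shapes are right-aligned, and each source dimension must either equal the target dimension or be 1. The following plain-Python sketch checks that rule and applies it to a 2D list. The function names are hypothetical and nothing here is NKI API.

```python
def check_broadcast(src, dst):
    """Raise ValueError unless shape `src` broadcasts to shape `dst`."""
    if len(src) > len(dst):
        raise ValueError("source has more dimensions than target")
    padded = (1,) * (len(dst) - len(src)) + tuple(src)
    for s, d in zip(padded, dst):
        if s != 1 and s != d:
            raise ValueError(f"cannot broadcast dimension {s} to {d}")
    return tuple(dst)

def broadcast_to_2d(tile, shape):
    """Broadcast a 2D list to `shape` by repeating size-1 dimensions."""
    rows, cols = len(tile), len(tile[0])
    out_rows, out_cols = check_broadcast((rows, cols), shape)
    return [
        [tile[i if rows > 1 else 0][j if cols > 1 else 0] for j in range(out_cols)]
        for i in range(out_rows)
    ]

print(broadcast_to_2d([[1], [2]], (2, 3)))
print(broadcast_to_2d([[1, 2, 3]], (2, 3)))
```

In this model, expand_dims corresponds to inserting a size-1 dimension into the shape, which is what makes a tile eligible for broadcasting along that dimension.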

Indexing operations#

where

Return elements chosen from x or y depending on condition.

gather_flattened

Gather elements from data tensor using indices after flattening.
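As a small illustration of how where composes with a comparison, the sketch below builds ReLU from an element-wise "greater than" mask. It operates on flat Python lists as a stand-in for tiles; the helper names mirror the entries on this page but are plain-Python stand-ins, not the NKI functions.

```python
def greater(x, y):
    """Element-wise x > y, returning a list of booleans."""
    return [a > b for a, b in zip(x, y)]

def where(condition, x, y):
    """Pick from x where condition is True, otherwise from y."""
    return [a if c else b for c, a, b in zip(condition, x, y)]

x = [-2.0, 0.5, 3.0]
zeros = [0.0] * len(x)
relu_x = where(greater(x, zeros), x, zeros)
print(relu_x)  # [0.0, 0.5, 3.0]
```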

Iterators#

affine_range

Create a sequence for fully unrolled loop iteration.

dynamic_range

Create a sequence for dynamic loop iteration.

sequential_range

Create a sequence for fully unrolled loop iteration.

static_range

Create a sequence for fully unrolled loop iteration.

Memory Hierarchy#

psum

Memory region constants for NKI tensors.

sbuf

Memory region constants for NKI tensors.

hbm

Memory region constants for NKI tensors.

private_hbm

Memory region constants for NKI tensors.

shared_hbm

Memory region constants for NKI tensors.

is_psum

Check if buffer is PSUM.

is_sbuf

Check if buffer is SBUF.

is_hbm

Check if buffer is any HBM type.

is_on_chip

Check if buffer is on-chip (SBUF or PSUM).

Others#

device_print

Print a message with a string prefix followed by the value of a tile.

no_reorder

Prevent the scheduler from reordering operations in this region.

program_id

Index of the current SPMD program along the given axis in the launch grid.

num_programs

Number of SPMD programs along the given axes in the launch grid.

program_ndim

Number of dimensions in the SPMD launch grid.
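One way to picture program_id, num_programs, and program_ndim is a serial loop over every point in the launch grid, with each iteration seeing its own coordinates. The sketch below simulates that on the host. The `launch` helper, the `ctx` dictionary, and `row_range_kernel` are all hypothetical, and `num_programs` is simplified here to take a single axis.

```python
from itertools import product

def launch(grid, kernel):
    """Run `kernel` once per grid point, serially, mimicking an SPMD launch."""
    results = {}
    for ids in product(*(range(n) for n in grid)):
        ctx = {
            "program_id": lambda axis, ids=ids: ids[axis],
            "num_programs": lambda axis, grid=grid: grid[axis],
            "program_ndim": lambda grid=grid: len(grid),
        }
        results[ids] = kernel(ctx)
    return results

def row_range_kernel(ctx, total_rows=128):
    """Each program claims a contiguous slice of rows based on its index."""
    pid = ctx["program_id"](0)
    n = ctx["num_programs"](0)
    rows_per_program = total_rows // n
    start = pid * rows_per_program
    return (start, start + rows_per_program)

out = launch((4,), row_range_kernel)
print(out)
```

With a 1D grid of 4 programs and 128 rows, each simulated program gets a disjoint block of 32 rows, which is the usual pattern for splitting work by program index.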

Data Types#

bool_

Boolean (True or False) stored as a byte

int8

8-bit signed integer number

int16

16-bit signed integer number

int32

32-bit signed integer number

uint8

8-bit unsigned integer number

uint16

16-bit unsigned integer number

uint32

32-bit unsigned integer number

float16

16-bit floating-point number

float32

32-bit floating-point number

bfloat16

16-bit floating-point number (1S,8E,7M)

tfloat32

32-bit floating-point number (1S,8E,10M)

float8_e4m3

8-bit floating-point number (1S,4E,3M)

float8_e5m2

8-bit floating-point number (1S,5E,2M)

float8_e4m3fn

8-bit floating-point number (1S,4E,3M) with no inf; NaN is represented by 0bS111'1111

float8_e5m2_x4

4x packed float8_e5m2 elements, custom data type for nki.isa.nc_matmul_mx on NeuronCore-v4

float8_e4m3fn_x4

4x packed float8_e4m3fn elements, custom data type for nki.isa.nc_matmul_mx on NeuronCore-v4

float4_e2m1fn_x4

4x packed float4_e2m1fn elements, custom data type for nki.isa.nc_matmul_mx on NeuronCore-v4
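The (sign, exponent, mantissa) bit counts listed above can be checked on the host with a little bit manipulation. The sketch below shows bfloat16 as the top 16 bits of a float32 encoding, and decodes a float8_e4m3fn byte. It assumes truncation when narrowing to bfloat16 (hardware may round differently) and assumes an exponent bias of 7 for e4m3fn, as in the OCP FP8 layout. The function names are hypothetical.

```python
import struct

def float32_to_bfloat16_bits(x):
    """Top 16 bits of the float32 encoding of x (truncation, no rounding)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16

def bfloat16_bits_to_float(bits):
    """Widen a 16-bit bfloat16 pattern back to a Python float."""
    (x,) = struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))
    return x

def decode_e4m3fn(byte):
    """Decode one float8_e4m3fn byte, assuming exponent bias 7."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")                    # 0bS111'1111 is the only NaN
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal range
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

print(hex(float32_to_bfloat16_bits(1.0)))
print(bfloat16_bits_to_float(0x4049))
print(decode_e4m3fn(0x7E))
```

Because the all-ones exponent is not reserved for infinities, the largest finite e4m3fn value under these assumptions is 448 (byte 0x7E).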

Constants#

tile_size

Hardware tile size constants (pmax, psum_fmax, gemm_stationary_fmax, etc.)
