nki.language

Creation operations

ndarray

Create a new tensor of the given shape and dtype on the specified buffer.

zeros

Create a new tensor of the given shape and dtype on the specified buffer, filled with zeros.
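Because a tensor is fully described by its shape and dtype, its memory footprint follows directly from them, which is handy when budgeting on-chip buffers. The plain-Python sketch below is not NKI code: the helpers `tensor_nbytes` and `bytes_per_partition` and the `ITEMSIZE` table are hypothetical, and the per-partition figure assumes the first dimension is the partition dimension, as in NKI's on-chip tile layout.

```python
from math import prod

# Hypothetical table of bytes per element for some dtypes listed under "Data Types".
ITEMSIZE = {
    "bool_": 1, "uint8": 1, "int8": 1, "float8_e4m3": 1, "float8_e5m2": 1,
    "uint16": 2, "int16": 2, "float16": 2, "bfloat16": 2,
    "uint32": 4, "int32": 4, "float32": 4,
}

def tensor_nbytes(shape, dtype):
    """Total bytes of a dense tensor with the given shape and dtype."""
    return prod(shape) * ITEMSIZE[dtype]

def bytes_per_partition(shape, dtype):
    """Bytes held by each partition, assuming shape[0] is the partition dimension."""
    return prod(shape[1:]) * ITEMSIZE[dtype]

print(tensor_nbytes((128, 512), "float32"))         # 262144
print(bytes_per_partition((128, 512), "bfloat16"))  # 1024
```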

Tensor manipulation operations

ds

Construct a dynamic slice for simple tensor indexing.
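A dynamic slice names a contiguous run of indices by its start and its length rather than by start and end. The plain-Python model below assumes that `ds(start, size)` selects indices `start` through `start + size - 1`, which makes it equivalent to Python's built-in `slice(start, start + size)`; `ds_model` is a hypothetical stand-in, not the NKI function.

```python
def ds_model(start, size):
    """Model of a dynamic slice: `size` consecutive indices beginning at `start`."""
    # Assumption: ds(start, size) covers start .. start + size - 1.
    if size < 0:
        raise ValueError("size must be non-negative")
    return slice(start, start + size)

row = list(range(10))
tile_width = 3
tile_index = 2
# Select the third tile of width 3, i.e. elements 6, 7, 8.
print(row[ds_model(tile_index * tile_width, tile_width)])  # [6, 7, 8]
```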

Iterators

static_range

Create a sequence of numbers for use as loop iterators in NKI, resulting in a fully unrolled loop.

affine_range

Create a sequence of numbers for use as parallel loop iterators in NKI.

sequential_range

Create a sequence of numbers for use as sequential loop iterators in NKI.
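The practical difference between affine_range and sequential_range is whether one iteration may depend on another. Our understanding is that affine_range is intended for loops without a loop-carried dependency, whose iterations may be reordered or overlapped, while sequential_range is for loops where iteration i needs the result of iteration i - 1. The plain-Python model below (not NKI code; `elementwise` and `running_sum` are hypothetical) shows why the distinction matters: running the iterations in reverse leaves an independent loop unchanged but breaks a running sum.

```python
def elementwise(order):
    """Doubles each element; iteration i touches only index i (no loop-carried dependency)."""
    x = [1, 2, 3, 4]
    out = [0] * len(x)
    for i in order:
        out[i] = 2 * x[i]
    return out

def running_sum(order):
    """Prefix sum; iteration i reads out[i - 1], written by the previous iteration."""
    x = [1, 2, 3, 4]
    out = [0] * len(x)
    for i in order:
        out[i] = x[i] + (out[i - 1] if i > 0 else 0)
    return out

forward = list(range(4))
backward = forward[::-1]

# Independent iterations: any execution order gives the same result.
print(elementwise(forward), elementwise(backward))  # [2, 4, 6, 8] [2, 4, 6, 8]
# Dependent iterations: only program order gives the intended prefix sum.
print(running_sum(forward), running_sum(backward))  # [1, 3, 6, 10] [1, 2, 3, 4]
```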

Memory Hierarchy

psum

PSUM (partial sum buffer) - Only visible to each individual kernel instance in the SPMD grid

sbuf

State Buffer - Only visible to each individual kernel instance in the SPMD grid

hbm

HBM - Alias of private_hbm

private_hbm

HBM - Only visible to each individual kernel instance in the SPMD grid

shared_hbm

Shared HBM - Visible to all kernel instances in the SPMD grid

Others

program_id

Index of the current SPMD program along the given axis in the launch grid.

num_programs

Number of SPMD programs along the given axes in the launch grid.

program_ndim

Number of dimensions in the SPMD launch grid.
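A common use of these three functions is to split work across the launch grid so that each program handles its own contiguous block. The sketch below simulates an SPMD launch in plain Python instead of calling NKI: `launch`, `kernel`, and `rows_for_program` are hypothetical helpers, where `pids[axis]` stands in for program_id(axis), `grid[axis]` for num_programs(axis), and `len(grid)` for program_ndim().

```python
from itertools import product

def rows_for_program(pid, nprog, n_rows):
    """Contiguous block of rows owned by program `pid` out of `nprog` programs."""
    per_prog = -(-n_rows // nprog)  # ceiling division
    start = pid * per_prog
    end = min(start + per_prog, n_rows)
    return range(start, end)  # empty if this program has no rows left

def launch(grid, kernel):
    """Call `kernel` once per point of the launch grid, keyed by its program ids."""
    return {pids: kernel(pids, grid) for pids in product(*(range(n) for n in grid))}

def kernel(pids, grid):
    # pids[0] ~ program_id(0), grid[0] ~ num_programs(0)
    return list(rows_for_program(pids[0], grid[0], n_rows=10))

print(launch((4,), kernel))
# {(0,): [0, 1, 2], (1,): [3, 4, 5], (2,): [6, 7, 8], (3,): [9]}
```

The ceiling split gives every program except possibly the last the same number of rows; other partitioning schemes (for example, strided assignment) work just as well with the same three quantities.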

Data Types

bool_

Boolean (True or False) stored as a byte

uint8

8-bit unsigned integer number

uint16

16-bit unsigned integer number

uint32

32-bit unsigned integer number

int8

8-bit signed integer number

int16

16-bit signed integer number

int32

32-bit signed integer number

float4_e2m1fn_x4

4x packed float4_e2m1fn elements, custom data type for nki.isa.nc_matmul_mx on NeuronCore-v4

float8_e4m3

8-bit floating-point number (1S,4E,3M)

float8_e4m3fn_x4

4x packed float8_e4m3fn elements, custom data type for nki.isa.nc_matmul_mx on NeuronCore-v4

float8_e5m2

8-bit floating-point number (1S,5E,2M)

float8_e5m2_x4

4x packed float8_e5m2 elements, custom data type for nki.isa.nc_matmul_mx on NeuronCore-v4

float16

16-bit floating-point number (1S,5E,10M)

bfloat16

16-bit floating-point number (1S,8E,7M)

float32

32-bit floating-point number (1S,8E,23M)

tfloat32

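32-bit floating-point number with reduced precision (1S,8E,10M): float32's 8-bit exponent with a 10-bit mantissa, so only 19 of the 32 bits are significant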
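The (1S,xE,yM) notation gives the number of sign, exponent, and mantissa bits. The plain-Python sketch below decodes such bit patterns, assuming an IEEE-754-style exponent bias of 2^(E-1) - 1 (true of float16, bfloat16, and the common OCP FP8 formats), and shows bfloat16 and tfloat32 as float32 with the low mantissa bits cut off. `decode_fp` and `float32_bits` are hypothetical helpers, not NKI APIs. The decoder rejects the all-ones exponent code, because formats such as float8_e4m3 and float8_e4m3fn disagree about how that code encodes infinities and NaN.

```python
import struct

def decode_fp(bits, exp_bits, man_bits):
    """Decode a sign/exponent/mantissa bit pattern with an IEEE-style bias.

    Handles zero, subnormal, and normal values only; the all-ones exponent
    code is rejected because its meaning differs between formats.
    """
    bias = (1 << (exp_bits - 1)) - 1
    sign = -1.0 if (bits >> (exp_bits + man_bits)) & 1 else 1.0
    exp = (bits >> man_bits) & ((1 << exp_bits) - 1)
    man = bits & ((1 << man_bits) - 1)
    if exp == (1 << exp_bits) - 1:
        raise ValueError("all-ones exponent: special values differ between formats")
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * (man / (1 << man_bits)) * 2.0 ** (1 - bias)
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

def float32_bits(x):
    """Raw 32-bit pattern of x after rounding to float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

print(decode_fp(0x38, 4, 3))     # float8_e4m3 pattern -> 1.0
print(decode_fp(0xC0, 5, 2))     # float8_e5m2 pattern -> -2.0
print(decode_fp(0x3C00, 5, 10))  # float16 pattern     -> 1.0

# bfloat16 keeps the top 16 bits of a float32, tfloat32 the top 19.
# This truncates (rounds toward zero); real conversions usually round to nearest.
x = 1.0 + 2**-7 + 2**-9
print(decode_fp(float32_bits(x) >> 16, 8, 7))   # 1.0078125 (2**-9 term lost)
print(decode_fp(float32_bits(x) >> 13, 8, 10))  # 1.009765625 (exact)
```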

Constants

tile_size

Tile size constants: hardware limits on tile dimensions, such as the maximum partition dimension of a tile.
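Tile size limits are typically used to split a large dimension into hardware-sized tiles. The plain-Python sketch below assumes a maximum partition dimension of 128, which we understand to be the value exposed as `nl.tile_size.pmax`; treat both the `PMAX` constant and the `tiles` helper as illustrative assumptions and check the constants in your SDK version.

```python
PMAX = 128  # assumed partition-dimension limit; verify against nl.tile_size in your SDK

def tiles(extent, tile=PMAX):
    """(start, size) pairs covering range(extent) with tiles of at most `tile` rows."""
    return [(start, min(tile, extent - start)) for start in range(0, extent, tile)]

print(tiles(300))  # [(0, 128), (128, 128), (256, 44)]
```

Each (start, size) pair describes one tile, and the last tile is shorter whenever the extent is not a multiple of the limit, so kernels that tile this way need to handle that remainder.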