This document is relevant for: Trn2, Trn3
nki.language#
The nki.language module provides high-level constructs for writing NKI kernels.
It includes tensor creation, indexing, type casting, math operations, and loop constructs
that the NKI compiler translates into efficient hardware instructions.
Creation operations#
Create a new tensor of given shape and dtype on the specified buffer. |
|
Create a new tensor of given shape and dtype on the specified buffer, filled with zeros. |
|
Create a new tensor of given shape and dtype on the specified buffer, filled with ones. |
|
Create a new tensor of given shape and dtype on the specified buffer, filled with initial value. |
|
Create a new tensor of zeros with the same shape and type as a given tensor. |
|
Create a new tensor with the same shape and type as a given tensor. |
|
Create an identity matrix in SBUF with the specified data type. |
|
Create a new tensor of given shape and dtype on the specified buffer, filled with random values. |
|
Set the random seed for random number generation. |
Tensor operations#
Load a tensor from device memory (HBM) into on-chip memory (SBUF). |
|
Load a tensor from device memory (HBM) and 2D-transpose the data before storing into on-chip memory (SBUF). |
|
Store into a tensor on device memory (HBM) from on-chip memory (SBUF). |
|
Create a copy of the input tile. |
|
x @ y matrix multiplication of x and y. |
|
Transposes a 2D tile between its partition and free dimension. |
Math operations#
Absolute value of the input, element-wise. |
|
Add the inputs, element-wise. |
|
Inverse tangent of the input, element-wise. |
|
Ceiling of the input, element-wise. |
|
Cosine of the input, element-wise. |
|
Exponential of the input, element-wise. |
|
Floor of the input, element-wise. |
|
Natural logarithm of the input, element-wise. |
|
Maximum of the inputs, element-wise. |
|
Minimum of the inputs, element-wise. |
|
Multiply the inputs, element-wise. |
|
Numerical negative of the input, element-wise. |
|
Elements of x raised to powers of y, element-wise. |
|
Reciprocal of the input, element-wise. |
|
Reciprocal of the square-root of the input, element-wise. |
|
Sign of the numbers of the input, element-wise. |
|
Sine of the input, element-wise. |
|
Non-negative square-root of the input, element-wise. |
|
Square of the input, element-wise. |
|
Subtract the inputs, element-wise. |
|
Tangent of the input, element-wise. |
|
Hyperbolic tangent, element-wise. |
|
Truncated value of the input, element-wise. |
Activation and Backpropagation functions#
ReLU activation, element-wise. |
|
Sigmoid activation, element-wise. |
|
SiLU (Swish) activation, element-wise. |
|
Derivative of SiLU activation, element-wise. |
|
GELU activation, element-wise. |
|
Derivative of GELU activation, element-wise. |
|
GELU approximation using sigmoid, element-wise. |
|
Derivative of sigmoid-approximated GELU, element-wise. |
|
GELU approximation using tanh, element-wise. |
|
Mish activation, element-wise. |
|
Softplus activation, element-wise. |
|
Softmax activation function on the input, element-wise. |
|
Error function, element-wise. |
|
Derivative of error function, element-wise. |
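The GELU variants listed above differ only in how they approximate the Gaussian CDF. As a rough illustration, the following pure-Python sketch evaluates the commonly published scalar formulas. The function names here are hypothetical and are not the NKI signatures; the constants (0.044715, 1.702) are the textbook values and are an assumption, so the hardware implementation may differ in precision or rounding.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid: 1 / (1 + e^-x)."""
    return 1.0 / (1.0 + math.exp(-x))


def silu(x: float) -> float:
    """SiLU (Swish): x * sigmoid(x)."""
    return x * sigmoid(x)


def softplus(x: float) -> float:
    """Softplus: log(1 + e^x)."""
    return math.log1p(math.exp(x))


def mish(x: float) -> float:
    """Mish: x * tanh(softplus(x))."""
    return x * math.tanh(softplus(x))


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi expressed through erf."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh_approx(x: float) -> float:
    """Tanh-based GELU approximation (textbook constants, assumed)."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def gelu_sigmoid_approx(x: float) -> float:
    """Sigmoid-based GELU approximation (textbook constant, assumed)."""
    return x * sigmoid(1.702 * x)
```

Comparing the three GELU functions at a few sample points is a quick way to judge whether an approximation is acceptable for a given model.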
Normalization and Regularization functions#
Randomly zeroes some of the elements of the input tile given a probability rate. |
|
Apply Root Mean Square Layer Normalization. |
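For readers unfamiliar with RMS normalization, the sketch below shows the usual per-row formula in plain Python. It is a reference for the math only: the function name `rms_norm_row`, the `weight` parameter, and the placement and default of `eps` are assumptions for illustration, not the NKI signature.

```python
import math


def rms_norm_row(row, weight, eps=1e-6):
    """Scale one row by the reciprocal of its root mean square.

    out[i] = row[i] / sqrt(mean(row^2) + eps) * weight[i]
    """
    mean_sq = sum(v * v for v in row) / len(row)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(row, weight)]


# With unit weights and eps = 0, the output has a mean square of exactly 1.
normalized = rms_norm_row([3.0, 4.0], [1.0, 1.0], eps=0.0)
```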
Reduction operations#
Whether all elements along the specified axis (or axes) evaluate to True. |
|
Maximum of elements along the specified axis (or axes) of the input. |
|
Arithmetic mean along the specified axis (or axes) of the input. |
|
Minimum of elements along the specified axis (or axes) of the input. |
|
Product of elements along the specified axis (or axes) of the input. |
|
Sum of elements along the specified axis (or axes) of the input. |
|
Variance along the specified axis (or axes) of the input. |
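To make the axis semantics concrete, here is a small pure-Python model of reducing a 2D tile along its last axis, so that each row collapses to one value. This is a sketch under stated assumptions: the helper names are hypothetical, the tile is modeled as a list of rows, and the variance shown is the population variance (dividing by the element count), which may not match the convention NKI uses.

```python
def reduce_rows(tile, op):
    """Apply a reduction `op` independently to every row of a 2D tile."""
    return [op(row) for row in tile]


def mean(row):
    """Arithmetic mean of one row."""
    return sum(row) / len(row)


def variance(row):
    """Population variance of one row (divides by len(row); assumed)."""
    m = mean(row)
    return sum((v - m) ** 2 for v in row) / len(row)


tile = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
row_sums = reduce_rows(tile, sum)
row_maxes = reduce_rows(tile, max)
row_vars = reduce_rows(tile, variance)
```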
Comparison operations#
Return (x == y) element-wise. |
|
Return (x != y) element-wise. |
|
Return (x < y) element-wise. |
|
Return (x <= y) element-wise. |
|
Return (x > y) element-wise. |
|
Return (x >= y) element-wise. |
Logical operations#
Compute the logical AND of two tiles element-wise. |
|
Compute the logical OR of two tiles element-wise. |
|
Compute the logical XOR of two tiles element-wise. |
|
Compute the logical NOT element-wise. |
Bitwise operations#
Compute the bitwise AND of two tiles element-wise. |
|
Compute the bitwise OR of two tiles element-wise. |
|
Compute the bitwise XOR of two tiles element-wise. |
|
Compute the bitwise NOT element-wise. |
|
Left shift the bits of x by y positions element-wise. |
|
Right shift the bits of x by y positions element-wise. |
Tensor manipulation operations#
Broadcast a tile to a new shape following numpy broadcasting rules. |
|
Create a dynamic slice for tensor indexing. |
|
Expand the shape of a tile. |
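The NumPy broadcasting rules referenced above can be stated compactly: shapes are aligned from the trailing dimension, and two dimensions are compatible when they are equal or one of them is 1. The following sketch computes the resulting shape in plain Python. The helper name `broadcast_shape` is hypothetical and only models the shape rule, not any NKI call.

```python
def broadcast_shape(a, b):
    """Return the shape two operands broadcast to under NumPy rules."""
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        # Missing leading dimensions are treated as size 1.
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        result.append(max(da, db))
    return tuple(reversed(result))


# A per-partition column of shape (128, 1) stretches across the free dimension.
combined = broadcast_shape((128, 1), (128, 512))
```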
Indexing operations#
Return elements chosen from x or y depending on condition. |
|
Gather elements from data tensor using indices after flattening. |
Iterators#
Create a sequence for fully unrolled loop iteration. |
|
Create a sequence for dynamic loop iteration. |
|
Create a sequence for fully unrolled loop iteration. |
|
Create a sequence for fully unrolled loop iteration. |
Memory Hierarchy#
Memory region constants for NKI tensors. |
|
Memory region constants for NKI tensors. |
|
Memory region constants for NKI tensors. |
|
Memory region constants for NKI tensors. |
|
Memory region constants for NKI tensors. |
|
Check if buffer is PSUM. |
|
Check if buffer is SBUF. |
|
Check if buffer is any HBM type. |
|
Check if buffer is on-chip (SBUF or PSUM). |
Others#
Print a message with a string prefix followed by the value of a tile. |
|
Prevent the scheduler from reordering operations in this region. |
|
Index of the current SPMD program along the given axis in the launch grid. |
|
Number of SPMD programs along the given axes in the launch grid. |
|
Number of dimensions in the SPMD launch grid. |
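The three SPMD queries above are typically used together so that each program instance can work out which slice of the data it owns. The sketch below simulates that pattern on the host in plain Python. Everything in it is hypothetical scaffolding (`simulate_launch`, `row_range_kernel`, the 512-row workload); it does not call NKI and only illustrates the index arithmetic.

```python
from itertools import product


def simulate_launch(grid, kernel):
    """Call `kernel` once per program index in an N-dimensional launch grid."""
    return [kernel(pid, grid) for pid in product(*(range(n) for n in grid))]


def row_range_kernel(pid, grid, total_rows=512):
    """Return the half-open row range owned by this program instance.

    pid[0] plays the role of the program index along axis 0, and
    grid[0] plays the role of the number of programs along axis 0.
    """
    rows_per_program = total_rows // grid[0]
    start = pid[0] * rows_per_program
    return (start, start + rows_per_program)


# Four programs along one axis split 512 rows into four blocks of 128.
ranges = simulate_launch((4,), row_range_kernel)
```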
Data Types#
Boolean (True or False) stored as a byte |
|
8-bit signed integer number |
|
16-bit signed integer number |
|
32-bit signed integer number |
|
8-bit unsigned integer number |
|
16-bit unsigned integer number |
|
32-bit unsigned integer number |
|
16-bit floating-point number |
|
32-bit floating-point number |
|
16-bit floating-point number (1S,8E,7M) |
|
32-bit floating-point number (1S,8E,10M) |
|
8-bit floating-point number (1S,4E,3M) |
|
8-bit floating-point number (1S,5E,2M) |
|
8-bit floating-point number (1S,4E,3M), no inf, NaN represented by 0bS111'1111 |
|
4x packed float8_e5m2 elements, custom data type for nki.isa.nc_matmul_mx on NeuronCore-v4 |
|
4x packed float8_e4m3fn elements, custom data type for nki.isa.nc_matmul_mx on NeuronCore-v4 |
|
4x packed float4_e2m1fn elements, custom data type for nki.isa.nc_matmul_mx on NeuronCore-v4 |
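The (sign, exponent, mantissa) bit counts in the table determine how each format trades range for precision. As an illustration, the sketch below decodes two of the layouts in plain Python: the (1S,8E,7M) 16-bit format, modeled as the upper half of an IEEE 754 float32, and the (1S,4E,3M) 8-bit format with no infinities and NaN at `0bS111'1111`. The helper names are hypothetical, the exponent bias of 7 is an assumption based on the common definition of that format, and the float32 conversion truncates rather than rounds, for simplicity.

```python
import math
import struct


def float_to_bfloat16_bits(x: float) -> int:
    """Keep the top 16 bits of the float32 encoding (truncation, no rounding)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16


def bfloat16_bits_to_float(b: int) -> float:
    """Expand a 16-bit (1S,8E,7M) pattern back to a Python float."""
    (x,) = struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))
    return x


def decode_float8_e4m3fn(byte: int) -> float:
    """Decode an 8-bit (1S,4E,3M) value with no inf and NaN at 0bS111'1111."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    if exponent == 0xF and mantissa == 0x7:
        return math.nan
    if exponent == 0:
        # Subnormal: no implicit leading one, exponent fixed at 1 - bias.
        return sign * mantissa * 2.0 ** -9
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)
```

Because the all-ones exponent is not reserved for infinities in this 8-bit layout, its largest finite value is `decode_float8_e4m3fn(0x7E)`.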
Constants#
Hardware tile size constants (pmax, psum_fmax, gemm_stationary_fmax, etc.) |
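Tile size limits such as `pmax` are usually consumed by splitting a larger dimension into chunks that each fit within the limit. The sketch below shows that arithmetic in plain Python. The value 128 used for `P_MAX` is an assumption for illustration only, and `tile_ranges` is a hypothetical helper; in a real kernel the limit would be read from the constants above.

```python
P_MAX = 128  # assumed partition-dimension limit, for illustration only


def tile_ranges(extent, tile):
    """Split [0, extent) into half-open ranges of at most `tile` elements."""
    return [(start, min(start + tile, extent)) for start in range(0, extent, tile)]


# A 300-row tensor needs three tiles; the last one is partial.
row_tiles = tile_ranges(300, P_MAX)
```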