Source code for nki

"""Auto-generated stub file"""
from enum import Enum
import nki.language as nl
import ml_dtypes

def jit(func=None, mode="auto", **kwargs):
    r"""Compile a top-level NKI function to run on NeuronDevices.

    This decorator tries to automatically detect the current framework and
    compile the function as a custom operator. To bypass the framework
    detection logic, you can specify the ``mode`` parameter explicitly.

    You might need to explicitly set the target platform using the
    ``NEURON_PLATFORM_TARGET_OVERRIDE`` environment variable. Supported
    values are "trn1"/"gen2", "trn2"/"gen3", and "trn3"/"gen4".

    :param func: Function that defines the custom operation.
    :param mode: Compilation mode. Supported values are "jax", "torchxla",
        and "auto". (Default: "auto".)

    .. code-block:: python
        :caption: Writing an addition kernel using ``@nki.jit``

        @nki.jit()
        def nki_tensor_add_kernel(a_input, b_input):
            # Check both input tensor shapes are the same for element-wise
            # operation.
            assert a_input.shape == b_input.shape

            # Check the first dimension's size to ensure it does not exceed
            # on-chip memory tile size, since this simple kernel does not
            # tile inputs.
            assert a_input.shape[0] <= nl.tile_size.pmax

            # Allocate space for the input tensors in SBUF and copy the
            # inputs from HBM to SBUF with DMA copy.
            a_tile = nl.ndarray(dtype=a_input.dtype, shape=a_input.shape,
                                buffer=nl.sbuf)
            nisa.dma_copy(dst=a_tile, src=a_input)

            b_tile = nl.ndarray(dtype=b_input.dtype, shape=b_input.shape,
                                buffer=nl.sbuf)
            nisa.dma_copy(dst=b_tile, src=b_input)

            # Allocate space for the result and use tensor_tensor to perform
            # element-wise addition. Note: the first argument of
            # 'tensor_tensor' is the destination tensor.
            c_tile = nl.ndarray(dtype=a_input.dtype, shape=a_input.shape,
                                buffer=nl.sbuf)
            nisa.tensor_tensor(dst=c_tile, data1=a_tile, data2=b_tile,
                               op=nl.add)

            # Create a tensor in HBM and copy the result into HBM.
            c_output = nl.ndarray(dtype=a_input.dtype, shape=a_input.shape,
                                  buffer=nl.hbm)
            nisa.dma_copy(dst=c_output, src=c_tile)

            # Return kernel output as function output.
            return c_output
    """
    # Auto-generated stub: the body is intentionally elided, so calling this
    # placeholder directly does nothing and returns None.
    ...