# Source code for nki

"""Auto-generated stub file"""
from enum import Enum
import nki.language as nl
import ml_dtypes

def jit(func=None, mode="auto", **kwargs):
    r"""This decorator compiles a top-level NKI function to run on NeuronDevices.

    This decorator tries to automatically detect the current framework and
    compile the function as a custom operator. To bypass the framework
    detection logic, you can specify the ``mode`` parameter explicitly.

    You might need to explicitly set the target platform using the
    ``NEURON_PLATFORM_TARGET_OVERRIDE`` environment variable. Supported
    values are "trn1"/"gen2", "trn2"/"gen3", and "trn3"/"gen4".

    :param func: Function that defines the custom operation.
    :param mode: Compilation mode. Supported values are "jax", "torchxla",
        and "auto". (Default: "auto".)

    .. code-block:: python
       :caption: Writing an addition kernel using ``@nki.jit``

       @nki.jit()
       def nki_tensor_add_kernel(a_input, b_input):
           # Check both input tensor shapes are the same for element-wise operation.
           assert a_input.shape == b_input.shape

           # Check the first dimension's size to ensure it does not exceed on-chip
           # memory tile size, since this simple kernel does not tile inputs.
           assert a_input.shape[0] <= nl.tile_size.pmax

           # Allocate space for the input tensors in SBUF and copy the inputs from HBM
           # to SBUF with DMA copy.
           a_tile = nl.ndarray(dtype=a_input.dtype, shape=a_input.shape, buffer=nl.sbuf)
           nisa.dma_copy(dst=a_tile, src=a_input)
           b_tile = nl.ndarray(dtype=b_input.dtype, shape=b_input.shape, buffer=nl.sbuf)
           nisa.dma_copy(dst=b_tile, src=b_input)

           # Allocate space for the result and use tensor_tensor to perform
           # element-wise addition. Note: the first argument of 'tensor_tensor'
           # is the destination tensor.
           c_tile = nl.ndarray(dtype=a_input.dtype, shape=a_input.shape, buffer=nl.sbuf)
           nisa.tensor_tensor(dst=c_tile, data1=a_tile, data2=b_tile, op=nl.add)

           # Create a tensor in HBM and copy the result into HBM.
           c_output = nl.ndarray(dtype=a_input.dtype, shape=a_input.shape, buffer=nl.hbm)
           nisa.dma_copy(dst=c_output, src=c_tile)

           # Return kernel output as function output.
           return c_output
    """
    # Auto-generated stub: the real implementation ships with the NKI runtime.
    ...
def simulate(kernel):
    """Create a CPU-simulated version of an NKI kernel.

    .. warning::
        This API is experimental and may change in future releases.

    See :ref:`nki-simulate` for full documentation including target platform
    selection, precise floating-point mode, debugging, and known limitations.

    Example:

    .. code-block:: python

        @nki.jit
        def my_kernel(a, b):
            ...

        # Explicit simulation
        result = nki.simulate(my_kernel)(a_np, b_np)

        # With LNC2
        result = nki.simulate(my_kernel[2])(a_np, b_np)

    Args:
        kernel: NKI kernel function, typically decorated with ``@nki.jit``.
            If a plain function is passed, it is automatically wrapped.

    Returns:
        A callable that, when invoked with NumPy arrays, executes the kernel
        on CPU and returns NumPy array results.
    """
    # Auto-generated stub: the real implementation ships with the NKI runtime.
    ...