This document is relevant for: Inf2, Trn1, Trn2
nki.jit

nki.jit(func=None, mode='auto', **kwargs)
This decorator compiles a function to run on NeuronDevices.
This decorator tries to automatically detect the current framework and compile the function as a custom operator of that framework. To bypass the framework detection logic, you may specify the mode parameter explicitly.

Parameters:
func – The function that defines the custom op.
mode – The compilation mode; possible values: “jax”, “torchxla”, “baremetal”, “benchmark”, “simulation”, and “auto”.
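Because func defaults to None, the decorator also accepts keyword arguments, so the mode can be fixed at decoration time instead of relying on auto-detection. A minimal sketch, assuming "baremetal" mode is appropriate for standalone testing (the copy_kernel name is only illustrative):

from neuronxcc import nki
import neuronxcc.nki.language as nl

# Passing mode explicitly skips framework detection; "baremetal" is one of
# the values listed above.
@nki.jit(mode="baremetal")
def copy_kernel(in_tensor):
    # Output buffer in shared HBM so the host can read the result
    out_tensor = nl.ndarray(in_tensor.shape, dtype=in_tensor.dtype, buffer=nl.shared_hbm)
    data = nl.load(in_tensor)    # HBM -> on-chip memory
    nl.store(out_tensor, data)   # on-chip memory -> HBM
    return out_tensor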
The following example defines a complete kernel that adds two tensors element-wise and returns the result:

from neuronxcc import nki
import neuronxcc.nki.language as nl

@nki.jit
def nki_tensor_tensor_add(a_tensor, b_tensor):
    # Allocate the output tensor in shared HBM so the result is visible to the host
    c_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, buffer=nl.shared_hbm)

    # Load the inputs from HBM into on-chip memory
    a = nl.load(a_tensor)
    b = nl.load(b_tensor)

    # Element-wise addition
    c = a + b

    # Store the result back to HBM and return it
    nl.store(c_tensor, c)
    return c_tensor
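Host-side invocation depends on the resolved mode. The sketch below assumes the kernel above runs in "baremetal" mode (for example by decorating it with @nki.jit(mode="baremetal")), where plain NumPy arrays are passed in and a NumPy array is returned; the torchxla and jax modes take framework tensors instead.

import numpy as np

# Assumption: nki_tensor_tensor_add was compiled for baremetal execution,
# so it accepts and returns NumPy arrays directly.
a = np.random.rand(128, 512).astype(np.float32)
b = np.random.rand(128, 512).astype(np.float32)

c = nki_tensor_tensor_add(a, b)   # executes on the NeuronDevice
assert np.allclose(c, a + b)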