This document is relevant for: Inf2, Trn1, Trn2

nki.jit

nki.jit(func=None, mode='auto', **kwargs)

This decorator compiles a function to run on NeuronDevices.

By default, this decorator attempts to detect the current framework automatically and compiles the function as a custom operator for that framework. To bypass the framework detection logic, specify the mode parameter explicitly.

Parameters:
  • func – The function that defines the custom operator

  • mode – The compilation mode. Possible values: “jax”, “torchxla”, “baremetal”, “benchmark”, “simulation”, and “auto” (the default)
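
For example, to skip framework detection and compile for a specific target (here, the “simulation” mode), pass mode explicitly. This is a minimal sketch: the kernel name and its trivial copy body are illustrative assumptions, not part of this API; only the decorator signature and the mode strings above come from this page.

from neuronxcc import nki
import neuronxcc.nki.language as nl

# Explicitly select a compilation mode instead of relying on auto-detection.
@nki.jit(mode="simulation")
def nki_copy(in_tensor):
  # Hypothetical minimal kernel body, used only to illustrate the decorator call.
  out_tensor = nl.ndarray(in_tensor.shape, dtype=in_tensor.dtype, buffer=nl.shared_hbm)
  nl.store(out_tensor, nl.load(in_tensor))
  return out_tensor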

Listing 11 An Example
from neuronxcc import nki
import neuronxcc.nki.language as nl

@nki.jit
def nki_tensor_tensor_add(a_tensor, b_tensor):
  # Allocate the output tensor in shared HBM so it can be returned to the caller
  c_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, buffer=nl.shared_hbm)

  # Load the inputs from device memory (HBM) into on-chip memory
  a = nl.load(a_tensor)
  b = nl.load(b_tensor)

  # Element-wise addition of the two tiles
  c = a + b

  # Store the result back to HBM
  nl.store(c_tensor, c)

  return c_tensor
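
Once decorated, the kernel is called like a regular function from the host framework. The snippet below is a sketch of the PyTorch/XLA path (“torchxla” mode, or auto-detected); the tensor shapes, dtype, and the torch_xla import reflect a typical torch-neuronx environment and are assumptions rather than part of this API reference.

import torch
from torch_xla.core import xla_model as xm

device = xm.xla_device()
a = torch.rand((128, 512), dtype=torch.bfloat16).to(device=device)
b = torch.rand((128, 512), dtype=torch.bfloat16).to(device=device)

# Invoke the NKI kernel like any other framework operator
c = nki_tensor_tensor_add(a, b)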

This document is relevant for: Inf2, Trn1, Trn2