This document is relevant for: Trn2, Trn3

nki.jit

nki.jit(fn=None, **kwargs)

Just-in-time compile a top-level NKI function to run on NeuronDevices.

The returned callable compiles the function as a custom operator for the framework in use. It determines the framework by inspecting the types of its arguments:

  • torch.Tensor: uses PyTorch integration.

  • jax.Array: uses JAX integration.

  • np.ndarray: compiles and executes a standalone kernel, without a framework.
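The dispatch rule in the list above can be illustrated with a small pure-Python sketch. This is not NKI's implementation: `detect_framework` and the stand-in classes are hypothetical names, and matching on the top-level module of each argument's type is an assumption made so that the example runs without PyTorch, JAX, or NumPy installed.

```python
def detect_framework(*args):
    """Return 'pytorch', 'jax', 'numpy', or None based on argument types.

    Hypothetical helper: it looks at the top-level module that defines
    each argument's type instead of importing any framework.
    """
    for arg in args:
        root = type(arg).__module__.split(".")[0]
        if root == "torch":
            return "pytorch"
        if root in ("jax", "jaxlib"):
            return "jax"
        if root == "numpy":
            return "numpy"
    return None


# Stand-in classes that only imitate the defining module of the real types.
class FakeTorchTensor:
    pass


class FakeNumpyArray:
    pass


FakeTorchTensor.__module__ = "torch"
FakeNumpyArray.__module__ = "numpy"

torch_choice = detect_framework(FakeTorchTensor())
numpy_choice = detect_framework(FakeNumpyArray())
unknown_choice = detect_framework(3, "text")
```

With real inputs, the same function could be called on the arguments a kernel is about to receive to see which integration path the list above describes.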

You might need to explicitly set the target platform using the NEURON_PLATFORM_TARGET_OVERRIDE environment variable. Supported values:

  • trn1|inf2|gen2

  • trn2|gen3

  • trn3|gen4
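As a minimal sketch, the override can also be set from Python through `os.environ`. The helper name `set_platform_target` is hypothetical, the accepted values are copied from the list above, and setting the variable before the kernel is first compiled is an assumption about when it is read.

```python
import os

# Values taken from the list above; each bullet groups alternative spellings.
SUPPORTED_TARGETS = {"trn1", "inf2", "gen2", "trn2", "gen3", "trn3", "gen4"}


def set_platform_target(target):
    """Validate `target` and export it as NEURON_PLATFORM_TARGET_OVERRIDE."""
    if target not in SUPPORTED_TARGETS:
        raise ValueError(f"unsupported platform target: {target!r}")
    os.environ["NEURON_PLATFORM_TARGET_OVERRIDE"] = target


# Assumption: call this before the first kernel launch in the process.
set_platform_target("trn2")
```

Exporting the variable in the shell before starting the process has the same effect and avoids any dependence on import order.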

The LNC (Logical NeuronCore) degree can be set at the call site using bracket syntax: kernel[lnc](args). The default is LNC=1. The LNC value must match the NEURON_LOGICAL_NC_CONFIG environment variable set for the Neuron Runtime; a mismatch causes a runtime error. For example, if NEURON_LOGICAL_NC_CONFIG=1, the kernel must be launched with kernel[1](...) or kernel(...).
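The bracket syntax and the matching rule can be modeled with a toy class. `ToyKernel` and `toy_jit` are hypothetical stand-ins, not NKI's Kernel or nki.jit: they only show how `kernel[lnc](args)` can bind an LNC degree through `__getitem__` and how a mismatch with NEURON_LOGICAL_NC_CONFIG would surface. Skipping the check when the variable is unset is an assumption of this sketch.

```python
import os


class ToyKernel:
    """Toy model of a launchable kernel; not the NKI Kernel class."""

    def __init__(self, fn, lnc=1):
        self.fn = fn
        self.lnc = lnc  # default LNC=1, as stated above

    def __getitem__(self, lnc):
        # kernel[lnc] returns a launcher bound to the requested LNC degree.
        return ToyKernel(self.fn, lnc)

    def __call__(self, *args):
        configured = os.environ.get("NEURON_LOGICAL_NC_CONFIG")
        # Assumption: no check is performed when the variable is unset.
        if configured is not None and int(configured) != self.lnc:
            raise RuntimeError(
                f"kernel launched with LNC={self.lnc}, "
                f"but NEURON_LOGICAL_NC_CONFIG={configured}"
            )
        return self.fn(*args)


def toy_jit(fn):
    """Hypothetical decorator that wraps fn in a ToyKernel."""
    return ToyKernel(fn)


@toy_jit
def add_one(x):
    return x + 1


os.environ["NEURON_LOGICAL_NC_CONFIG"] = "1"
result_default = add_one(41)      # implicit LNC=1
result_explicit = add_one[1](41)  # explicit LNC=1

try:
    add_one[2](41)                # LNC=2 conflicts with the configured value
    mismatch_error = None
except RuntimeError as exc:
    mismatch_error = str(exc)
```

In this model the check happens at launch time, which mirrors the statement above that a mismatch produces a runtime error; where the real check occurs is not specified here.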

Returns a Kernel instance wrapping the decorated function.
