This document is relevant for: Trn2, Trn3
nki.jit
- nki.jit(fn=None, **kwargs)
Just-in-time compile a top-level NKI function to run on NeuronDevices.
The returned callable compiles the function as a custom operator for the current framework, which it detects by inspecting the call arguments:
- torch.Tensor: uses PyTorch integration.
- jax.Array: uses JAX integration.
- np.ndarray: compiles and executes a standalone kernel, without a framework.
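The dispatch rule above amounts to a lookup on the types of the call arguments. The sketch below is hypothetical: detect_framework is not part of nki, and it is not how nki.jit is implemented. It matches on the top-level module name of each argument's type so that it runs without torch, jax, or numpy installed.

```python
from typing import Any, Iterable


def detect_framework(args: Iterable[Any]) -> str:
    """Hypothetical sketch: pick an integration from the first recognized argument type."""
    for arg in args:
        # e.g. "torch" for torch.Tensor, "numpy" for np.ndarray
        root = type(arg).__module__.split(".")[0]
        if root == "torch":
            return "pytorch"
        if root in ("jax", "jaxlib"):  # concrete jax.Array objects are defined in jaxlib
            return "jax"
        if root == "numpy":
            return "standalone"
        # Scalars and other Python objects ("builtins", ...) are skipped.
    raise TypeError("no torch.Tensor, jax.Array, or np.ndarray argument found")
```

The sketch stops at the first recognized argument and does not diagnose calls that mix frameworks; the real callable may handle that case differently.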
You might need to set the target platform explicitly using the NEURON_PLATFORM_TARGET_OVERRIDE environment variable. Supported values:
- trn1 | inf2 | gen2
- trn2 | gen3
- trn3 | gen4
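The override is an ordinary environment variable, so it can be exported in the shell or set from Python before the kernel is compiled. The sketch below assumes the variable is read when compilation happens; set_platform_target is a hypothetical helper, not an nki function, and it only checks the value against the list above.

```python
import os

# Values listed above; each inner tuple groups the values that share a line.
SUPPORTED_TARGETS = (("trn1", "inf2", "gen2"), ("trn2", "gen3"), ("trn3", "gen4"))


def set_platform_target(value: str) -> None:
    """Hypothetical helper: validate and set NEURON_PLATFORM_TARGET_OVERRIDE for this process."""
    allowed = {v for group in SUPPORTED_TARGETS for v in group}
    if value not in allowed:
        raise ValueError(f"unsupported target {value!r}; expected one of {sorted(allowed)}")
    os.environ["NEURON_PLATFORM_TARGET_OVERRIDE"] = value


# Assumption: this must run before the nki.jit kernel is first compiled.
set_platform_target("trn2")
```

Setting the variable inside the process only affects that process and its children; exporting it in the shell applies it to every run launched from that shell.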
The LNC (Logical NeuronCore) degree can be set at the call site using bracket syntax: kernel[lnc](args). The default is LNC=1. The LNC value must match the NEURON_LOGICAL_NC_CONFIG environment variable set for the Neuron Runtime; a mismatch causes a runtime error. For example, if NEURON_LOGICAL_NC_CONFIG=1, the kernel must be launched with kernel[1](...) or kernel(...).
Returns: a Kernel instance wrapping the decorated function.
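The call-site rules for LNC can be modeled with a small wrapper class. This is a hypothetical sketch, not the real Kernel type: SketchKernel only mimics the bracket syntax, the LNC=1 default, and the mismatch check, and it calls the Python function directly instead of compiling anything. When NEURON_LOGICAL_NC_CONFIG is unset the sketch skips the check, because the runtime's own default is not modeled here.

```python
import os
from typing import Any, Callable


class SketchKernel:
    """Hypothetical stand-in for the Kernel object returned by nki.jit."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn

    def __getitem__(self, lnc: int) -> Callable[..., Any]:
        # kernel[lnc] returns a launcher bound to that LNC degree.
        def launch(*args: Any, **kwargs: Any) -> Any:
            configured = os.environ.get("NEURON_LOGICAL_NC_CONFIG")
            # Mismatch between call-site LNC and runtime config is an error.
            if configured is not None and int(configured) != lnc:
                raise RuntimeError(
                    f"kernel launched with LNC={lnc}, "
                    f"but NEURON_LOGICAL_NC_CONFIG={configured}"
                )
            return self.fn(*args, **kwargs)

        return launch

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Plain kernel(...) uses the default LNC=1.
        return self[1](*args, **kwargs)
```

With a real NKI kernel the call shapes are the same: kernel(...) or kernel[1](...) when NEURON_LOGICAL_NC_CONFIG=1, and kernel[2](...) when it is set to 2.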