How does the NKI Compiler work with the Neuron SDK?

This topic covers the NKI Compiler and how it integrates with the rest of the Neuron SDK, in particular the Neuron Graph Compiler.

The NKI Compiler

The NKI Compiler is responsible for compiling NKI kernel functions. A NKI kernel function is any function decorated with nki.jit, which identifies it as a kernel function. You can then call the kernel function from within your model.

For each kernel function, Neuron runs the NKI Compiler to produce an artifact for that function. This is analogous to compiling a single source file with a traditional compiler, such as a C++ compiler.

The Neuron SDK manages all kernel artifacts, so programmers do not need to handle these files themselves. As in prior versions of NKI, programmers mark kernel functions with nki.jit, and the NKI Compiler is invoked automatically when this decorator is encountered during compilation.

The Graph Compiler

The Neuron Graph Compiler (or simply the Neuron Compiler) handles the rest of the model, which we refer to as “the compute graph”. The framework, such as PyTorch or JAX, orchestrates building a compute graph from the model definition. When the model calls a NKI kernel function, the NKI Compiler inserts a reference to the compiled artifact into the graph. The Graph Compiler recognizes these references and assembles the final result that runs on Trainium hardware.
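This division of labor can be sketched as a toy Python model, assuming a compute graph is simply an ordered list of nodes. `OpNode`, `KernelRef`, `Graph`, `trace_model`, `graph_compile`, and the kernel name `rmsnorm_kernel` are all hypothetical and do not correspond to real Neuron SDK APIs. The traced graph holds only a reference to the kernel's artifact, never the kernel body, and the toy graph compiler lowers ordinary operators itself while recognizing kernel references.

```python
from dataclasses import dataclass, field
from typing import List, Union


@dataclass(frozen=True)
class OpNode:
    """An ordinary framework operator that the graph compiler lowers itself."""
    op: str


@dataclass(frozen=True)
class KernelRef:
    """Hypothetical placeholder naming a compiled NKI kernel artifact."""
    artifact_id: str


Node = Union[OpNode, KernelRef]


@dataclass
class Graph:
    nodes: List[Node] = field(default_factory=list)


def trace_model() -> Graph:
    """Toy tracing step: the framework records operators in order.

    The kernel call contributes only a reference, not the kernel body.
    """
    graph = Graph()
    graph.nodes.append(OpNode("matmul"))
    graph.nodes.append(KernelRef("rmsnorm_kernel"))  # hypothetical kernel name
    graph.nodes.append(OpNode("softmax"))
    return graph


def graph_compile(graph: Graph) -> List[str]:
    """Toy graph compiler: compiles operators and recognizes kernel references."""
    plan: List[str] = []
    for node in graph.nodes:
        if isinstance(node, KernelRef):
            plan.append(f"kernel:{node.artifact_id}")  # use the prebuilt artifact
        else:
            plan.append(f"compiled:{node.op}")  # lowered by the graph compiler
    return plan


print(graph_compile(trace_model()))
# ['compiled:matmul', 'kernel:rmsnorm_kernel', 'compiled:softmax']
```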

Integration

As described above, both the NKI Compiler and the Neuron Compiler are used to construct the final artifact that runs on Trainium hardware. The NKI Compiler compiles each NKI kernel function in turn, and the Neuron Compiler compiles the rest of the model and inserts the NKI kernels based on the references generated by the NKI Compiler.

NKI kernels are inserted into the graph very late in the compilation process. This differs from prior versions of NKI, which integrated kernels earlier in the compile process. Inserting kernels late allows the NKI Compiler to provide custom, NKI-specific behavior and gives users more predictable and performant results.
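A toy Python model can illustrate what late insertion implies, under an assumption not stated in this topic: that graph-level optimization passes treat kernel references as opaque. `Op`, `KernelRef`, `fuse_elementwise`, `link`, and the artifact contents are hypothetical. A fusion pass rewrites framework operators but leaves the kernel reference untouched, and the precompiled artifact is substituted only in the final linking step, so the kernel that runs is exactly the one the NKI Compiler produced.

```python
from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass(frozen=True)
class Op:
    """A framework operator the graph compiler may freely rewrite."""
    name: str


@dataclass(frozen=True)
class KernelRef:
    """Hypothetical opaque reference to a precompiled NKI kernel artifact."""
    artifact_id: str


Node = Union[Op, KernelRef]


def fuse_elementwise(nodes: List[Node]) -> List[Node]:
    """Toy graph pass: merge runs of adjacent framework ops into one fused op.

    Kernel references are never inspected or rewritten, so they act as
    fusion barriers (an assumption of this sketch).
    """
    out: List[Node] = []
    for node in nodes:
        if isinstance(node, Op) and out and isinstance(out[-1], Op):
            out[-1] = Op(f"{out[-1].name}+{node.name}")
        else:
            out.append(node)
    return out


def link(nodes: List[Node], artifacts: Dict[str, bytes]) -> List[Union[str, bytes]]:
    """Final step: only now are kernel references replaced by their artifacts."""
    return [artifacts[n.artifact_id] if isinstance(n, KernelRef) else n.name
            for n in nodes]


artifacts = {"my_kernel": b"<precompiled NKI kernel>"}  # hypothetical artifact
graph = [Op("add"), Op("relu"), KernelRef("my_kernel"), Op("mul"), Op("exp")]
optimized = fuse_elementwise(graph)
print(link(optimized, artifacts))
# ['add+relu', b'<precompiled NKI kernel>', 'mul+exp']
```

In this model, `mul` and `exp` are fused with each other but never with the kernel, and the kernel's bytes pass through the optimization pass unchanged. Whether the real Neuron Compiler draws its boundaries in exactly this way is not specified here.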

Further reading