This document is relevant for: Inf2, Trn1, Trn1n

Getting Started with NKI

In this guide, we will implement a simple “Hello World” style NKI kernel and run it on a NeuronDevice (a Trainium or Inferentia2 device, or newer). We will showcase how to invoke a NKI kernel standalone through NKI baremetal mode, as well as through ML frameworks (PyTorch and JAX). Before diving into the kernel implementation, let’s make sure you have the correct environment setup for running NKI kernels.

Environment Setup

You need a Trn1 or Inf2 instance set up on AWS to run NKI kernels on a NeuronDevice. Once logged into the instance, follow the steps below to ensure you have all the required packages installed in your Python environment.

NKI is shipped as part of the Neuron compiler package. To make sure you have the latest compiler package, follow the installation instructions in the Setup Guide.

You can verify that NKI is available in your compiler installation by running the following command:

python -c 'import neuronxcc.nki'

This attempts to import the NKI package. It will error out if NKI is not included in your Neuron compiler version or if the Neuron compiler is not installed. The import might take about a minute the first time you run it. Whenever possible, we recommend using local instance NVMe volumes instead of EBS for executable code.
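
You can also print the installed Neuron compiler version to confirm which release you are running. This assumes the neuronxcc package exposes a standard __version__ attribute:

python -c 'import neuronxcc; print(neuronxcc.__version__)'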

If you intend to run NKI kernels without any ML framework for quick prototyping, you will also need NumPy installed.

To call NKI kernels from PyTorch, you also need to have torch_neuronx installed. For an installation guide, see PyTorch Neuron Setup. You can verify that you have torch_neuronx installed by running the following command:

python -c 'import torch_neuronx'

To call NKI kernels from JAX, you need to have jax_neuronx installed. For an installation guide, see JAX Neuron Setup. You can verify that you have jax_neuronx installed by running the following command:

python -c 'import jax_neuronx'

Implementing your first NKI kernel

In the current NKI release, all input and output tensors must be passed into the kernel as device memory (HBM) tensors on a NeuronDevice. The body of the kernel typically consists of three main phases:

  1. Load the inputs from device memory to on-chip memory (SBUF).

  2. Perform the desired computation.

  3. Store the outputs from on-chip memory to device memory.

For more details on the above terms, see NKI Programming Model.

Below is a small NKI kernel example. In this example, we take two tensors and add them element-wise to produce an output tensor of the same shape.

import neuronxcc.nki.language as nl

def nki_tensor_add_kernel(a_input, b_input, c_output):
    """NKI kernel to compute element-wise addition of two input tensors
    """

    # Check all input/output tensor shapes are the same for element-wise operation
    assert a_input.shape == b_input.shape == c_output.shape

    # Check size of the first dimension does not exceed on-chip memory tile size limit,
    # so that we don't need to tile the input to keep this example simple
    assert a_input.shape[0] <= nl.tile_size.pmax

    # Load the inputs from device memory to on-chip memory
    a_tile = nl.load(a_input)
    b_tile = nl.load(b_input)

    # Specify the computation (in our case: a + b)
    c_tile = nl.add(a_tile, b_tile)

    # Store the result to c_output from on-chip memory to device memory
    nl.store(c_output, value=c_tile)

Now let us walk through the above code:

Importing NKI

We start by importing neuronxcc.nki.language, which implements the NKI language. We will go into more detail regarding the NKI language in NKI Programming Model, but for now you can think of it as a tile-level domain-specific language.

import neuronxcc.nki.language as nl

Defining a kernel

Next we define the nki_tensor_add_kernel Python function, which contains the NKI kernel code. Note that this kernel function must be decorated appropriately so that the Neuron compiler can recognize it as NKI kernel code and trace it correctly. We will cover which decorator to use in the Running the kernel section below.

def nki_tensor_add_kernel(a_input, b_input, c_output):

Note that all NKI kernel inputs and outputs are passed by reference into the function, so there is no explicit return value from the function.

Checking input shapes

To keep this getting started guide simple, this kernel example expects all input and output tensors to have the same shape for an element-wise addition operation. We further restrict the first dimension of the input/output tensors to not exceed nl.tile_size.pmax == 128. A more detailed discussion of tile size limits is available in NKI Programming Model. Note that all of these restrictions can be lifted with tensor broadcasting/reshape and tensor tiling with loops in NKI (see the tiling sketch after the code excerpt below). For more kernel examples, check out NKI tutorials.

    # Check all input/output tensor shapes are the same for element-wise operation
    assert a_input.shape == b_input.shape == c_output.shape

    # Check size of the first dimension does not exceed on-chip memory tile size limit,
    # so that we don't need to tile the input to keep this example simple
    assert a_input.shape[0] <= nl.tile_size.pmax
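
To illustrate how the tile-size restriction could be lifted, below is a rough, hypothetical sketch (nki_tensor_add_tiled is our own name, not part of the NKI API) that tiles the first dimension with a loop. It assumes 2D inputs whose first dimension is an exact multiple of nl.tile_size.pmax; a real kernel would also need masking to handle a ragged final tile, as shown in NKI tutorials:

import neuronxcc.nki.language as nl

def nki_tensor_add_tiled(a_input, b_input, c_output):
    """Hypothetical sketch: element-wise addition, tiled over the first
    dimension in chunks of nl.tile_size.pmax."""
    P = nl.tile_size.pmax  # maximum partition-dimension tile size (128)
    assert a_input.shape[0] % P == 0  # simplifying assumption: no ragged final tile

    for i in nl.affine_range(a_input.shape[0] // P):
        # Load one tile of each input from device memory (HBM) into SBUF
        a_tile = nl.load(a_input[i * P:(i + 1) * P, :])
        b_tile = nl.load(b_input[i * P:(i + 1) * P, :])

        # Compute and store the corresponding slice of the output
        nl.store(c_output[i * P:(i + 1) * P, :], value=nl.add(a_tile, b_tile))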

Loading inputs

Most NKI kernels start by loading inputs from device memory to on-chip memory. This is necessary because computation can only be performed on data that resides in on-chip memory.

    a_tile = nl.load(a_input)
    b_tile = nl.load(b_input)

Defining the desired computation

After loading the two input tiles, it is time to define the desired computation. In this case, we perform a simple element-wise addition between two tiles:

    c_tile = nl.add(a_tile, b_tile)

Note that c_tile = a_tile + b_tile will also work, as NKI overloads simple Python operators such as +, -, *, and /. For a complete set of available NKI APIs, refer to NKI API Reference Manual.
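
For example, the nl.add call above could equivalently be written as:

    c_tile = a_tile + b_tile  # same as nl.add(a_tile, b_tile)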

Storing outputs

Every NKI kernel ends by storing its output tiles from the on-chip memory to device memory, where the host can access them:

    nl.store(c_output, value=c_tile)

Running the kernel

Next, we will cover three different ways to run the above NKI kernel:

  1. NKI baremetal mode: run NKI kernel with no ML framework involvement

  2. PyTorch: run NKI kernel as a PyTorch operator

  3. JAX: run NKI kernel as a JAX operator

Note

NKI baremetal mode is the most convenient way to prototype and performance-tune a NKI kernel on its own. For production ML workloads, we highly recommend invoking NKI kernels through an ML framework (PyTorch or JAX). This allows you to integrate NKI kernels into your regular compute graph to accelerate certain operators (see NKI Kernel as a Framework Custom Operator for details) and to leverage the more optimized host-to-device data transfer handling available in ML frameworks.

NKI baremetal

To run the above nki_tensor_add_kernel kernel in baremetal mode, we can decorate the function with @baremetal as follows:

from neuronxcc.nki import baremetal

@baremetal
def nki_tensor_add_kernel(a_input, b_input, c_output):

See the nki.baremetal API doc for the available input arguments to the decorator. nki.baremetal expects the input and output tensors of the NKI kernel to be NumPy arrays. To invoke the kernel, we first initialize the two input tensors a and b and the output tensor c as NumPy arrays. In this scenario, it is not necessary to zero out the output tensor, as it will be completely overwritten by the result of the addition. However, in some cases a kernel might overwrite only part of the output tensor, and you might want to reset it beforehand to avoid garbage data. Finally, we call the NKI kernel just like any other Python function:

if __name__ == "__main__":
    import numpy as np

    a = np.ones((4, 3), dtype=np.float16)
    b = np.ones((4, 3), dtype=np.float16)
    c = np.zeros((4, 3), dtype=np.float16)

    # Run NKI kernel on a NeuronDevice
    nki_tensor_add_kernel(a, b, c)

    print(c)
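
Since both inputs are all ones, every element of c should equal 2.0 after the kernel runs. As a suggested sanity check (not part of the original example), you can compare the result against a NumPy reference:

    print(np.allclose(c, a + b))  # expect: True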

In the current NKI release, an output tensor cannot also be an input tensor; in-out parameters are not supported.
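
For example, rather than reusing an input as the output, allocate a distinct output buffer (illustrative snippet):

    out = np.empty_like(a)            # separate buffer; no need to zero it here
    nki_tensor_add_kernel(a, b, out)  # valid: no tensor is both input and output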

PyTorch

To run the above nki_tensor_add_kernel kernel using PyTorch, we can decorate the function with @nki_jit as follows:

from torch_neuronx import nki_jit

@nki_jit
def nki_tensor_add_kernel(a_input, b_input, c_output):

The kernel caller code is very similar to that of NKI baremetal mode, except that the input and output tensors must now be initialized as PyTorch device tensors instead.

if __name__ == "__main__":
    import torch
    from torch_xla.core import xla_model as xm

    device = xm.xla_device()

    a = torch.ones((4, 3), dtype=torch.float16).to(device=device)
    b = torch.ones((4, 3), dtype=torch.float16).to(device=device)
    c = torch.zeros((4, 3), dtype=torch.float16).to(device=device)

    nki_tensor_add_kernel(a, b, c)

    print(c)  # an implicit XLA barrier/mark-step (triggers XLA compilation)

Running the above code for the first time will trigger compilation of the NKI kernel, which might take a few minutes before printing any output. The printed output should be as follows:

tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]], device='xla:1', dtype=torch.float16)
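
As a suggested check (not part of the original example), you can also verify the device result against a CPU reference computed by PyTorch:

    print(torch.allclose(c.cpu(), a.cpu() + b.cpu()))  # expect: True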

JAX

To run the above nki_tensor_add_kernel kernel using JAX, we can initialize the input/output tensors as JAX tensors and call the kernel directly using nki_call (imported from jax_neuronx):

if __name__ == "__main__":
    import jax
    import jax.numpy as jnp
    from jax_neuronx import nki_call

    a = jnp.ones((4, 3), dtype=jnp.float16)
    b = jnp.ones((4, 3), dtype=jnp.float16)

    c = nki_call(nki_tensor_add_kernel, a, b,
                 out_shape=jax.ShapeDtypeStruct(a.shape, dtype=a.dtype))

    print(c)
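
As in the PyTorch example, the first run triggers compilation of the NKI kernel. The printed output should look similar to the following (standard array formatting for a 4x3 float16 tensor of twos):

[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]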