This document is relevant for: Inf2, Trn1, Trn2

Getting Started with NKI#

In this guide, we will implement a simple “Hello World” style NKI kernel and run it on a NeuronDevice (a Trainium or Inferentia2 device, or later). We will show how to invoke a NKI kernel standalone through NKI baremetal mode, as well as through ML frameworks (PyTorch and JAX). Before diving into kernel implementation, let’s make sure you have the correct environment set up for running NKI kernels.

Environment Setup#

You need a Trn1 or Inf2 instance set up on AWS to run NKI kernels on a NeuronDevice. Once logged into the instance, follow the steps below to ensure you have all the required packages installed in your Python environment.

NKI is shipped as part of the Neuron compiler package. To make sure you have the latest compiler package, see Setup Guide for an installation guide.
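For reference, if your environment already points at the Neuron pip repository as described in the Setup Guide, upgrading the compiler package typically looks like the command below. The neuronx-cc package name and index URL assume the standard Neuron repository configuration; defer to the Setup Guide for the exact instructions for your release.

python -m pip install --upgrade neuronx-cc --extra-index-url=https://pip.repos.neuron.amazonaws.com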

You can verify that NKI is available in your compiler installation by running the following command:

python -c 'import neuronxcc.nki'

This attempts to import the NKI package. It will error out if NKI is not included in your Neuron compiler version or if the Neuron compiler is not installed. The import might take about a minute the first time you run it. Whenever possible, we recommend using local instance NVMe volumes instead of EBS for executable code.

If you intend to run NKI kernels without any ML framework for quick prototyping, you will also need NumPy installed.
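You can verify that NumPy is available using the same kind of import check as above:

python -c 'import numpy'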

To call NKI kernels from PyTorch, you also need to have torch_neuronx installed. For an installation guide, see PyTorch Neuron Setup. You can verify that you have torch_neuronx installed by running the following command:

python -c 'import torch_neuronx'

To call NKI kernels from JAX, you need to have jax_neuronx installed. For an installation guide, see JAX Neuron Setup. You can verify that you have jax_neuronx installed by running the following command:

python -c 'import jax_neuronx'

Implementing your first NKI kernel#

In the current NKI release, all input tensors must be passed into the kernel as device memory (HBM) tensors on a NeuronDevice. Similarly, output tensors returned from the kernel must also reside in device memory. The body of the kernel typically consists of three main phases:

  1. Load the inputs from device memory to on-chip memory (SBUF).

  2. Perform the desired computation.

  3. Store the outputs from on-chip memory to device memory.

For more details on the above terms, see NKI Programming Model.

Below is a small NKI kernel example. In this example, we take two tensors and add them element-wise to produce an output tensor of the same shape.

from neuronxcc import nki
import neuronxcc.nki.language as nl


@nki.jit
def nki_tensor_add_kernel(a_input, b_input):
    """NKI kernel to compute element-wise addition of two input tensors."""

    # Check all input/output tensor shapes are the same for element-wise operation
    assert a_input.shape == b_input.shape

    # Check size of the first dimension does not exceed on-chip memory tile size limit,
    # so that we don't need to tile the input to keep this example simple
    assert a_input.shape[0] <= nl.tile_size.pmax

    # Load the inputs from device memory to on-chip memory
    a_tile = nl.load(a_input)
    b_tile = nl.load(b_input)

    # Specify the computation (in our case: a + b)
    c_tile = nl.add(a_tile, b_tile)

    # Create a HBM tensor as the kernel output
    c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)

    # Store the result to c_output from on-chip memory to device memory
    nl.store(c_output, value=c_tile)

    # Return kernel output as function output
    return c_output

Now let us walk through the above code:

Importing NKI#

We start by importing neuronxcc.nki, which provides the function decorators used to compile NKI kernels, and neuronxcc.nki.language, which implements the NKI language. We will go into more detail regarding the NKI language in NKI Programming Model, but for now you can think of it as a tile-level domain-specific language.

from neuronxcc import nki
import neuronxcc.nki.language as nl

Defining a kernel#

Next we define the nki_tensor_add_kernel Python function, which contains the NKI kernel code. The kernel is decorated with nki.jit, which allows the Neuron compiler to recognize this as NKI kernel code and trace it correctly. Input tensors (a_input and b_input) are passed by reference into the kernel, just like any other Python function parameters.

@nki.jit
def nki_tensor_add_kernel(a_input, b_input):

Checking input shapes#

To keep this getting started guide simple, this kernel example expects all input and output tensors to have the same shape, as required for an element-wise addition. We further restrict the first dimension of the input/output tensors to not exceed nl.tile_size.pmax == 128. A more detailed discussion of tile size limitations is available in NKI Programming Model. Note that all of these restrictions can be lifted with tensor broadcasting/reshape and tensor tiling with loops in NKI. For more kernel examples, check out NKI tutorials.

# Check all input/output tensor shapes are the same for element-wise operation
assert a_input.shape == b_input.shape

# Check size of the first dimension does not exceed on-chip memory tile size limit,
# so that we don't need to tile the input to keep this example simple
assert a_input.shape[0] <= nl.tile_size.pmax

Loading inputs#

Most NKI kernels start by loading inputs from device memory to on-chip memory. This is necessary because computation can only be performed on data that resides in on-chip memory.

a_tile = nl.load(a_input)
b_tile = nl.load(b_input)

Defining the desired computation#

After loading the two input tiles, it is time to define the desired computation. In this case, we perform a simple element-wise addition between two tiles:

c_tile = nl.add(a_tile, b_tile)

Note that c_tile = a_tile + b_tile will also work, as NKI overloads simple Python operators such as +, -, *, and /. For a complete set of available NKI APIs, refer to NKI API Reference Manual.
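For example, the addition above can be written equivalently with the overloaded operator:

# Equivalent to nl.add(a_tile, b_tile), using NKI's overloaded + operator
c_tile = a_tile + b_tile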

Storing and returning outputs#

To return the output tensor of the kernel, we first declare a NKI tensor c_output in device memory (HBM) and then store the output tile c_tile from on-chip memory to c_output using nl.store. We end the kernel execution by returning c_output using a standard Python return statement. This allows the host to access the output tensor.

# Create a HBM tensor as the kernel output
c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)

# Store the result to c_output from on-chip memory to device memory
nl.store(c_output, value=c_tile)

# Return kernel output as function output
return c_output

Running the kernel#

Next, we will cover three unique ways to run the above NKI kernel on a NeuronDevice:

  1. NKI baremetal: run NKI kernel with no ML framework involvement

  2. PyTorch: run NKI kernel as a PyTorch operator

  3. JAX: run NKI kernel as a JAX operator

All three run modes can call the same kernel function decorated with the nki.jit decorator as discussed above:

@nki.jit
def nki_tensor_add_kernel(a_input, b_input):

The nki.jit decorator automatically chooses the correct run mode by checking the incoming tensor type:

  1. NumPy arrays as input: run in NKI baremetal mode

  2. PyTorch tensors as input: run in PyTorch mode

  3. JAX tensors as input: run in JAX mode

See nki.jit API doc for more details.

Note

NKI baremetal mode is the most convenient way to prototype and optimize the performance of a NKI kernel on its own. For production ML workloads, we highly recommend invoking NKI kernels through an ML framework (PyTorch or JAX). This allows you to integrate NKI kernels into your regular compute graph to accelerate certain operators (see NKI Kernel as a Framework Custom Operator for details) and leverage the more optimized host-to-device data transfer handling available in ML frameworks.

NKI baremetal#

Baremetal mode expects the input tensors of the NKI kernel to be NumPy arrays. The kernel also converts its NKI output tensors to NumPy arrays. To invoke the kernel, we first initialize the two input tensors a and b as NumPy arrays. Then we call the NKI kernel just like any other Python function:

import numpy as np

a = np.ones((4, 3), dtype=np.float16)
b = np.ones((4, 3), dtype=np.float16)

# Run NKI kernel on a NeuronDevice
c = nki_tensor_add_kernel(a, b)

print(c)

Note

Alternatively, we can decorate the kernel with nki.baremetal or pass the mode parameter to the nki.jit decorator, @nki.jit(mode='baremetal'), to bypass the dynamic mode detection. See nki.baremetal API doc for more available input arguments for the baremetal mode.
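For example, the same kernel definition can pin the execution mode explicitly; this is a minimal sketch using the mode parameter mentioned above, with the kernel body unchanged and elided here:

from neuronxcc import nki

# Bypass dynamic mode detection and always run this kernel in baremetal mode
@nki.jit(mode='baremetal')
def nki_tensor_add_kernel(a_input, b_input):
    ...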

PyTorch#

To run the above nki_tensor_add_kernel kernel using PyTorch, we initialize the input tensors as PyTorch device tensors instead.

import torch
from torch_xla.core import xla_model as xm

device = xm.xla_device()

# Initialize the input tensors on the XLA device backed by the NeuronDevice
a = torch.ones((4, 3), dtype=torch.float16).to(device=device)
b = torch.ones((4, 3), dtype=torch.float16).to(device=device)

# Run NKI kernel on a NeuronDevice
c = nki_tensor_add_kernel(a, b)

print(c)  # an implicit XLA barrier/mark-step (triggers XLA compilation)

Running the above code for the first time will trigger compilation of the NKI kernel, which might take a few minutes before printing any output. The printed output should be as follows:

tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]], device='xla:1', dtype=torch.float16)

Note

Alternatively, we can pass the mode='torchxla' parameter into the nki.jit decorator to bypass the dynamic mode detection.

JAX#

To run the above nki_tensor_add_kernel kernel using JAX, we initialize the input tensors as JAX tensors:

import jax.numpy as jnp

a = jnp.ones((4, 3), dtype=jnp.float16)
b = jnp.ones((4, 3), dtype=jnp.float16)

# Run NKI kernel on a NeuronDevice
c = nki_tensor_add_kernel(a, b)

print(c)

Note

Alternatively, we can pass the mode='jax' parameter into the nki.jit decorator to bypass the dynamic mode detection.