NKI Kernel as a Framework Custom Operator

This document demonstrates how to insert a NKI kernel as a custom operator into a PyTorch or JAX model using simple code examples.

Using NKI kernels

To register a NKI kernel as a custom operator, you only need to call a function decorated with @nki.jit from your framework code.

Let’s walk through an example in which we randomly initialize two inputs, add them together, and then multiply the result element-wise by both input tensors. This computes a * b * (a + b).

We first define a NKI kernel for element-wise addition. It is a tiled variation of the addition kernel from Quickstart: Build and Run a Kernel.

import nki
import nki.language as nl
import nki.isa as nisa

@nki.jit(platform_target="trn1")
def nki_tensor_add(a_input, b_input):
  """NKI kernel to compute element-wise addition of two input tensors.

  This kernel assumes the input/output sizes can be uniformly tiled into [128, 512] tiles.

  Args:
      a_input: the first input tensor
      b_input: the second input tensor

  Returns:
      c_output: the output tensor
  """
  # Create the result tensor (uninitialized) in shared HBM, shared between
  # all SPMD instances.
  c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)

  # Extract the dimensions of the a_input shape.
  M, N = a_input.shape

  # Set the tile dimensions. While TILE_N is not, strictly speaking, limited to
  # 512 for the addition operation, we stick with this size for simplicity.
  TILE_M = 128
  TILE_N = 512

  # Check that the input shapes and dtypes match and satisfy the tiling constraint.
  assert a_input.shape == b_input.shape, \
    f"Expected shapes {a_input.shape} and {b_input.shape} to match"
  assert a_input.dtype == b_input.dtype, \
    f"Expected data types {a_input.dtype} and {b_input.dtype} to match"
  assert M % TILE_M == 0, \
    f"Expected partition dimension ({M}) to be divisible by {TILE_M}"
  assert N % TILE_N == 0, \
    f"Expected free dimension ({N}) to be divisible by {TILE_N}"

  # Loop over each tile: load the input tiles, add them, and store the result back to HBM.
  for m in nl.affine_range(M // TILE_M):
    for n in nl.affine_range(N // TILE_N):
      # Allocate space for a_tile and b_tile in SBUF (uninitialized).
      a_tile = nl.ndarray(shape=(TILE_M, TILE_N), dtype=a_input.dtype, buffer=nl.sbuf)
      b_tile = nl.ndarray(shape=(TILE_M, TILE_N), dtype=b_input.dtype, buffer=nl.sbuf)

      # Load a_tile and b_tile from HBM into SBUF.
      nisa.dma_copy(dst=a_tile,
                    src=a_input[m * TILE_M:(m + 1) * TILE_M,
                                n * TILE_N:(n + 1) * TILE_N])
      nisa.dma_copy(dst=b_tile,
                    src=b_input[m * TILE_M:(m + 1) * TILE_M,
                                n * TILE_N:(n + 1) * TILE_N])

      # Allocate space for c_tile in SBUF.
      c_tile = nl.ndarray(shape=(TILE_M, TILE_N), dtype=a_input.dtype, buffer=nl.sbuf)

      # Perform the addition using the element-wise tensor_tensor instruction.
      nisa.tensor_tensor(dst=c_tile, data1=a_tile, data2=b_tile, op=nl.add)

      # Copy the result tile to the output tensor in HBM.
      nisa.dma_copy(dst=c_output[m * TILE_M:(m + 1) * TILE_M,
                                 n * TILE_N:(n + 1) * TILE_N],
                    src=c_tile)

  # Transfer the ownership of `c_output` to the caller
  return c_output
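Before running on a Neuron device, it can help to have a CPU reference for the kernel. The sketch below assumes only NumPy and is not part of the NKI API; tensor_add_reference is a hypothetical helper name. It mirrors the kernel's tile loop so you can see which slice each (m, n) iteration covers, and lets you compare results against a plain a + b.

```python
import numpy as np

# Tile sizes copied from the NKI kernel above.
TILE_M = 128
TILE_N = 512


def tensor_add_reference(a_input: np.ndarray, b_input: np.ndarray) -> np.ndarray:
    """Hypothetical CPU reference that mirrors the tile loop of nki_tensor_add."""
    assert a_input.shape == b_input.shape, "input shapes must match"
    assert a_input.dtype == b_input.dtype, "input dtypes must match"
    M, N = a_input.shape
    assert M % TILE_M == 0 and N % TILE_N == 0, "shape must tile evenly"

    # Uninitialized output, analogous to nl.ndarray(..., buffer=nl.shared_hbm).
    c_output = np.empty_like(a_input)
    for m in range(M // TILE_M):
        for n in range(N // TILE_N):
            rows = slice(m * TILE_M, (m + 1) * TILE_M)
            cols = slice(n * TILE_N, (n + 1) * TILE_N)
            # "Load" one tile from each input, add them, and "store" the result tile.
            c_output[rows, cols] = a_input[rows, cols] + b_input[rows, cols]
    return c_output


# A 256 x 1024 input is covered by 2 x 2 = 4 tiles of 128 x 512.
rng = np.random.default_rng(0)
a = rng.standard_normal((256, 1024), dtype=np.float32)
b = rng.standard_normal((256, 1024), dtype=np.float32)
c = tensor_add_reference(a, b)
```

Because element-wise addition is independent per element, adding tile by tile gives exactly the same values as adding the whole arrays at once; the tiling only matters for how data moves between HBM and SBUF on the device.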

PyTorch

We can perform (a + b) * a * b using native PyTorch code.

import torch
from torch_xla.core import xla_model as xm

device = xm.xla_device()

a = torch.randn(256, 1024, dtype=torch.float32).to(device)
b = torch.randn(256, 1024, dtype=torch.float32).to(device)
c = a + b
out = a * b * c

print(out)

Now let’s replace the tensor addition (c = a + b) with a NKI kernel. To do this, we replace the + operator with a call to the NKI kernel (nki_tensor_add); everything else stays the same.

device = xm.xla_device()
a = torch.randn(256, 1024, dtype=torch.float32).to(device)
b = torch.randn(256, 1024, dtype=torch.float32).to(device)
c = nki_tensor_add(a, b) # calling a NKI kernel, instead of the built-in torch op
out = a * b * c
print(out)

To understand what happens under the hood when we compile the above code, we can print the HLO IR graph generated by XLA. To do so, set the NEURON_FRAMEWORK_DEBUG environment variable, which preserves the HLO in binary form, and the XLA_SAVE_TENSORS_FILE environment variable, which writes a textual representation of the HLO. For example, you can add the following lines to your code:

import os
os.environ['NEURON_FRAMEWORK_DEBUG'] = "1"
os.environ["XLA_SAVE_TENSORS_FILE"] = "example1.pbtxt"

An example1.pbtxt.0 file containing the corresponding human-readable HLO IR is then written to your run directory.

Let’s examine the XLA output of this example. In line #14, we can see that the tensor addition is now mapped to an HLO xla::_op_<locals>CallImpl instruction, which represents the custom call. The output of that instruction (%2) is then consumed as usual by the multiplication in line #16.

 1 [ScheduleSyncTensorsGraph]
 2 TensorsGraphInfo:
 3   _str_intern (/home/ec2-user/pytorch-klir/lib/python3.10/site-packages/torch/_tensor_str.py:462)
 4   _str (/home/ec2-user/pytorch-klir/lib/python3.10/site-packages/torch/_tensor_str.py:726)
 5   __repr__ (/home/ec2-user/pytorch-klir/lib/python3.10/site-packages/torch/_tensor.py:590)
 6   <module> (/home/ec2-user/private-aws-neuron-sdk-staging/nki/examples/tensor_addition/t2.py:14)
 7
 8 Root Hashes: (181deae9d76fbfbf2fe0e040179f9da8)
 9
10 ## BEGIN_GRAPH
11 IR {
12   %0 = f32[256,1024]{1,0} xla::device_data(), xla_shape=f32[256,1024]{1,0}
13   %1 = f32[256,1024]{1,0} xla::device_data(), xla_shape=f32[256,1024]{1,0}
14   %2 = (f32[256,1024]{1,0}) xla::_op_<locals>CallImpl(%1, %0), xla_shape=(f32[256,1024]{1,0})
15   %3 = f32[256,1024]{1,0} aten::mul(%1, %0), xla_shape=f32[256,1024]{1,0}
16   %4 = f32[256,1024]{1,0} aten::mul(%3, %2), xla_shape=f32[256,1024]{1,0}, ROOT=0
17 }
18
19 Graph Hash: f518d5bd723cb9d6f9482b42b33105e1
20
21 ## END_GRAPH
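If you want to trace this data flow in a larger dump, a small script can help. The sketch below is a hypothetical helper, not part of torch_xla; it assumes one instruction per line in the `%N = <type> <op>(<operands>)` form shown above and is only meant for reading debug output.

```python
import re

# Matches IR lines such as:
#   %4 = f32[256,1024]{1,0} aten::mul(%3, %2), xla_shape=f32[256,1024]{1,0}, ROOT=0
# Groups: value name, result type, op name, operand list.
IR_LINE = re.compile(r"^\s*(%\d+) = (.+?) ([\w:<>]+)\((.*?)\)")


def parse_ir(ir_text):
    """Return {value: (op_name, [operand values])} for every instruction line."""
    nodes = {}
    for line in ir_text.splitlines():
        m = IR_LINE.match(line)
        if m:
            value, _type, op, args = m.groups()
            nodes[value] = (op, re.findall(r"%\d+", args))
    return nodes


def consumers(nodes, value):
    """Return the instructions that take `value` as an operand."""
    return [v for v, (_, operands) in nodes.items() if value in operands]


# Instruction lines copied from the dump above.
ir_text = """
  %0 = f32[256,1024]{1,0} xla::device_data(), xla_shape=f32[256,1024]{1,0}
  %1 = f32[256,1024]{1,0} xla::device_data(), xla_shape=f32[256,1024]{1,0}
  %2 = (f32[256,1024]{1,0}) xla::_op_<locals>CallImpl(%1, %0), xla_shape=(f32[256,1024]{1,0})
  %3 = f32[256,1024]{1,0} aten::mul(%1, %0), xla_shape=f32[256,1024]{1,0}
  %4 = f32[256,1024]{1,0} aten::mul(%3, %2), xla_shape=f32[256,1024]{1,0}, ROOT=0
"""

nodes = parse_ir(ir_text)
custom_calls = [v for v, (op, _) in nodes.items() if op.endswith("CallImpl")]
print(custom_calls, consumers(nodes, custom_calls[0]))  # ['%2'] ['%4']
```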

The Neuron compiler replaces the above custom call with the corresponding NKI kernel implementation while optimizing the rest of the compute graph as usual. At the end of the compilation process, a single compiled binary NEFF file is generated representing the entire graph including the NKI kernel. For more information about NEFF files, see Neuron Compiler.

JAX

We can perform (a + b) * a * b using native JAX code.

import jax
import jax.numpy as jnp

@jax.jit
def jax_customop_tutorial(a, b):
   c = a + b
   out = a * b * c
   return out

seed = jax.random.PRNGKey(0)
seed_a, seed_b = jax.random.split(seed)
a = jax.random.normal(seed_a, (256, 1024), dtype=jnp.float32)
b = jax.random.normal(seed_b, (256, 1024), dtype=jnp.float32)

print(jax_customop_tutorial(a, b))

Similar to the PyTorch example above, let’s replace the tensor addition (c = a + b) with the addition NKI kernel. To do this, we replace the + operator with a call to the NKI kernel (nki_tensor_add); everything else stays the same.

import jax
import jax.numpy as jnp

@jax.jit
def jax_customop_tutorial(a, b):
   c = nki_tensor_add(a, b) # calling a NKI kernel, instead of the built-in jax op
   out = a * b * c
   return out

seed = jax.random.PRNGKey(0)
seed_a, seed_b = jax.random.split(seed)
a = jax.random.normal(seed_a, (256, 1024), dtype=jnp.float32)
b = jax.random.normal(seed_b, (256, 1024), dtype=jnp.float32)
print(jax_customop_tutorial(a, b))

To understand what happens under the hood when we compile the above code, we can print the HLO IR graph by adding the following snippet to your code:

print(jax.jit(jax_customop_tutorial)
   .lower(a, b)
   .compile()
   .runtime_executable()
   .hlo_modules()[0].to_string()
)

Let’s examine the XLA output of this example. In line #8 we can identify that the tensor addition is now mapped to an HLO custom-call instruction, similar to PyTorch. The output of that custom-call is then consumed by the next instruction in line #9 as usual.

 1HloModule jit_jax_customop_tutorial, entry_computation_layout={(f32[256,1024]{1,0}, f32[256,1024]{1,0})->(f32[256,1024]{1,0})}, allow_spmd_sharding_propagation_to_parameters={}, allow_spmd_sharding_propagation_to_output={true}
 2
 3ENTRY %main.12 (Arg_0.1: f32[256,1024], Arg_1.2: f32[256,1024]) -> (f32[256,1024]) {
 4  %Arg_0.1 = f32[256,1024]{1,0} parameter(0), metadata={op_name="a"}
 5  %Arg_1.2 = f32[256,1024]{1,0} parameter(1), metadata={op_name="b"}
 6  %multiply.0 = f32[256,1024]{1,0} multiply(%Arg_0.1, %Arg_1.2), metadata={op_name="jit(jax_customop_tutorial)/jit(main)/jit(jax_customop_tutorial)/mul" source_file="/home/ec2-user/private-aws-neuron-sdk-staging/nki/examples/tensor_addition/t4.py" source_line=9}
 7  %constant.0 = s8[128,128]{1,0} constant({...})
 8  %custom-call.0 = f32[256,1024]{1,0} custom-call(%Arg_0.1, %Arg_1.2, %constant.0), custom_call_target="AwsNeuronCustomNativeKernel", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(jax_customop_tutorial)/jit(main)/jit(jax_customop_tutorial)/nki_call" source_file="/home/ec2-user/jax-klir/lib/python3.10/site-packages/nki/_jax.py" source_line=64}, backend_config="eyJrZXJuZWxfdmVyc2lvbiI6IDEsICJrbGlyX2JpbmFyeSI6IHsiYmluYXJ5IjogIi90bXAvbmtpX3RlbnNvcl9hZGRncmV6aGF4eS5rbGlyIiwgImlucHV0X25hbWVzIjogWyJhX2lucHV0IiwgImJfaW5wdXQiLCAidG1wLjQiXSwgIm91dHB1dF9uYW1lcyI6IFsiY19vdXRwdXQuMzUiXX0sICJmdW5jX25hbWUiOiAibmtpX3RlbnNvcl9hZGQiLCAiZ3JpZCI6IFtdLCAiaGFzX2NvbGxlY3RpdmVzIjogZmFsc2V9"
 9  %multiply.1 = f32[256,1024]{1,0} multiply(%multiply.0, %custom-call.0), metadata={op_name="jit(jax_customop_tutorial)/jit(main)/jit(jax_customop_tutorial)/mul" source_file="/home/ec2-user/private-aws-neuron-sdk-staging/nki/examples/tensor_addition/t4.py" source_line=9}
10  ROOT %tuple.11 = (f32[256,1024]{1,0}) tuple(%multiply.1)
11}
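The backend_config attribute of the custom-call in line #8 appears to be base64-encoded JSON. The sketch below decodes it with the Python standard library; decode_backend_config is a hypothetical helper, and the field names it reveals (such as func_name) belong to an internal, undocumented format, so treat them as informational only.

```python
import base64
import json
import re


def decode_backend_config(hlo_line: str) -> dict:
    """Decode the base64-encoded backend_config attribute of one HLO line."""
    match = re.search(r'backend_config="([A-Za-z0-9+/=]+)"', hlo_line)
    if match is None:
        raise ValueError("no backend_config attribute found in line")
    # b64decode returns bytes; json.loads accepts bytes directly.
    return json.loads(base64.b64decode(match.group(1)))


# The backend_config value is copied verbatim from line #8 of the dump above;
# the other attributes of that line are abbreviated here.
line = (
    '%custom-call.0 = f32[256,1024]{1,0} custom-call(%Arg_0.1, %Arg_1.2, %constant.0), '
    'custom_call_target="AwsNeuronCustomNativeKernel", '
    'backend_config="eyJrZXJuZWxfdmVyc2lvbiI6IDEsICJrbGlyX2JpbmFyeSI6IHsiYmluYXJ5IjogIi90bXAvbmtpX3RlbnNvcl9hZGRncmV6aGF4eS5rbGlyIiwgImlucHV0X25hbWVzIjogWyJhX2lucHV0IiwgImJfaW5wdXQiLCAidG1wLjQiXSwgIm91dHB1dF9uYW1lcyI6IFsiY19vdXRwdXQuMzUiXX0sICJmdW5jX25hbWUiOiAibmtpX3RlbnNvcl9hZGQiLCAiZ3JpZCI6IFtdLCAiaGFzX2NvbGxlY3RpdmVzIjogZmFsc2V9"'
)
config = decode_backend_config(line)
print(config["func_name"])  # nki_tensor_add
```

Decoding this particular value shows the kernel name nki_tensor_add and the kernel argument names a_input and b_input, which is one way to confirm which NKI kernel a given custom-call refers to.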

As in the PyTorch case, the Neuron compiler replaces the above custom-call with the corresponding NKI kernel implementation while optimizing the rest of the compute graph as usual, and produces a single compiled NEFF file representing the entire graph, including the NKI kernel. For more information about NEFF files, see Neuron Compiler.

Using NKI in training graphs

If you use NKI to implement a new operator in a training graph, you might need to make the new operator work with the framework’s autograd engine. In PyTorch, you can subclass torch.autograd.Function and implement both the forward() and backward() methods; the autograd engine then uses backward() when performing automatic differentiation. See Extending torch.autograd in the PyTorch Docs for instructions. In JAX, you can create a custom_vjp rule (VJP stands for vector-Jacobian product), which binds the forward and backward functions together. See Autodiff Cookbook in the JAX Docs for instructions.

Let’s reuse the nki_tensor_add kernel from before and demonstrate how to train a simple compute graph, (a + b) * a * b, in both PyTorch and JAX.

PyTorch

We define a NkiAddFunc class, which calls the nki_tensor_add kernel in its forward() method. For y = a + b, the partial derivatives with respect to both inputs are 1, so backward() passes the incoming gradient dy through unchanged to both a and b.

import torch
import torch_xla.core.xla_model as xm
device = xm.xla_device()

class NkiAddFunc(torch.autograd.Function):
  @staticmethod
  def forward(ctx, a, b):
    return nki_tensor_add(a, b)

  @staticmethod
  def backward(ctx, dy, *args):
    # gradients for a and b
    return dy, dy

# now, let's define the compute graph
a = torch.randn(256, 1024, dtype=torch.float32).to(device).detach().requires_grad_()
b = torch.randn(256, 1024, dtype=torch.float32).to(device).detach().requires_grad_()
c = NkiAddFunc.apply(a, b)
out = a * b * c

# here we define a (dummy) loss-function, in prep for backward propagation
loss = out.sum()

# lastly, let's invoke the auto-grad engine
loss.backward()

xm.mark_step()
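To see why backward() can simply return dy for both inputs, here is a framework-free sketch in plain Python that applies the chain rule by hand for scalar inputs. The names add_forward, add_backward, and forward_and_grads are illustrative only and are not part of PyTorch. For out = a * b * (a + b), the closed-form gradients are 2ab + b^2 with respect to a and a^2 + 2ab with respect to b.

```python
def add_forward(a, b):
    # Plays the role of NkiAddFunc.forward (the NKI kernel) in this sketch.
    return a + b


def add_backward(dy):
    # y = a + b has dy/da = dy/db = 1, so the upstream gradient passes through unchanged.
    return dy, dy


def forward_and_grads(a, b):
    """Compute out = a * b * (a + b) and d(out)/da, d(out)/db for scalar a, b."""
    c = add_forward(a, b)
    out = a * b * c
    # Gradient of out with respect to c, propagated back through the add node.
    grad_c = a * b
    grad_a_via_c, grad_b_via_c = add_backward(grad_c)
    # Direct paths through the multiplications: d(out)/da = b * c, d(out)/db = a * c.
    grad_a = b * c + grad_a_via_c
    grad_b = a * c + grad_b_via_c
    return out, grad_a, grad_b


print(forward_and_grads(2.0, 3.0))  # (30.0, 21.0, 16.0)
```

With a = 2 and b = 3, the closed forms give 2ab + b^2 = 21 and a^2 + 2ab = 16, matching the hand-propagated values; this is the same bookkeeping the autograd engine performs when it calls NkiAddFunc.backward().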

JAX

We define a custom_vjp function, nki_add_func, using the @jax.custom_vjp decorator; it calls the nki_tensor_add kernel directly. We then define the forward (f_forward) and backward (f_backward) implementations and register them with defvjp(). As in the PyTorch example, the backward implementation simply passes the gradient through to both inputs. Finally, we compute the forward pass as nki_add_func(a, b) * a * b, and obtain the gradients by calling jax.grad on a loss function (the sum of the output).

@jax.custom_vjp
def nki_add_func(a, b):
   return nki_tensor_add(a, b)

def f_forward(a, b):
   # operator output and residual (same as input here)
   return nki_add_func(a, b), (a, b)

def f_backward(res, grad):
   # gradients for a and b
   return grad, grad

nki_add_func.defvjp(f_forward, f_backward)

@jax.jit
def jax_customop_tutorial_and_grad(a, b):
   out = nki_add_func(a, b) * a * b

   # use the same dummy loss function (output sum) as PyTorch example above
   grad = jax.grad(lambda x, y: (nki_add_func(x, y) * x * y).sum(), argnums=(0, 1))(a, b)
   return out, *grad

c, grad_a, grad_b = jax_customop_tutorial_and_grad(a, b)
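To check the custom_vjp wiring on a machine without a Neuron device, the sketch below replaces nki_tensor_add with plain a + b. That substitution is an assumption made purely for illustration; the names add_func, add_fwd, add_bwd, and loss are hypothetical. The resulting gradients should equal the closed forms 2ab + b^2 and a^2 + 2ab.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def add_func(a, b):
    # Stand-in for nki_tensor_add so the sketch runs on CPU.
    return a + b


def add_fwd(a, b):
    # No residuals are needed because the backward rule does not use the inputs.
    return add_func(a, b), None


def add_bwd(res, grad):
    # Both partial derivatives of a + b are 1, so pass the cotangent through.
    return grad, grad


add_func.defvjp(add_fwd, add_bwd)


def loss(a, b):
    # Same dummy loss as above: the sum of (a + b) * a * b.
    return (add_func(a, b) * a * b).sum()


a = jnp.array([2.0, 1.0])
b = jnp.array([3.0, -1.0])
grad_a, grad_b = jax.grad(loss, argnums=(0, 1))(a, b)
print(grad_a, grad_b)  # element-wise 2ab + b^2 and a^2 + 2ab
```

Swapping add_func back to a call to nki_tensor_add leaves the custom_vjp rule unchanged, since the backward rule depends only on the derivative of addition, not on how the forward pass is implemented.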