This document is relevant for: Inf2, Trn1, Trn2
NKI Kernel as a Framework Custom Operator#
This document demonstrates how to insert a NKI kernel as a custom operator into a PyTorch or JAX model using simple code examples.
Using NKI kernels#
To register a NKI kernel, you call the decorated NKI function from your framework code.
Let’s examine a guiding example below where we randomly initialize two inputs, add them together, and then multiply the result by the two input tensors element-wise. This effectively calculates: a * b * (a + b).
We define a common NKI kernel for addition. For more information on the kernel, see SPMD Tensor Addition.
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl


@nki.jit
def nki_tensor_add_kernel_(a_input, b_input):
    """NKI kernel to compute element-wise addition of two input tensors

    This kernel assumes strict input/output sizes can be uniformly tiled to [128,512]

    Args:
        a_input: a first input tensor
        b_input: a second input tensor

    Returns:
        c_output: an output tensor
    """
    # Create output tensor shared between all SPMD instances as result tensor
    c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)

    # Calculate tile offsets based on current 'program'
    offset_i_x = nl.program_id(0) * 128
    offset_i_y = nl.program_id(1) * 512

    # Generate tensor indices to index tensors a and b
    ix = offset_i_x + nl.arange(128)[:, None]
    iy = offset_i_y + nl.arange(512)[None, :]

    # Load input data from device memory (HBM) to on-chip memory (SBUF)
    # We refer to an indexed portion of a tensor as an intermediate tensor
    a_tile = nl.load(a_input[ix, iy])
    b_tile = nl.load(b_input[ix, iy])

    # compute a + b
    c_tile = a_tile + b_tile

    # store the addition results back to device memory (c_output)
    nl.store(c_output[ix, iy], value=c_tile)

    # Transfer the ownership of `c_output` to the caller
    return c_output
PyTorch#
We can perform (a + b) * a * b using native PyTorch code.
import torch
from torch_xla.core import xla_model as xm
device = xm.xla_device()
a = torch.randn(256, 1024, dtype=torch.float32).to(device)
b = torch.randn(256, 1024, dtype=torch.float32).to(device)
c = a + b
out = a * b * c
print(out)
Now let’s replace the tensor addition (c = a + b) with a NKI kernel. To do this, we replace the + operator with a call to the NKI kernel caller (nki_tensor_add), and everything else works as before.
def nki_tensor_add(a_input, b_input):
    """NKI kernel caller to compute element-wise addition of two input tensors

    This kernel caller lifts the tile-size restriction by applying the kernel to tiles of the inputs/outputs

    Args:
        a_input: a first input tensor, of shape [N*128, M*512]
        b_input: a second input tensor, of shape [N*128, M*512]

    Returns:
        a tensor of shape [N*128, M*512], the result of a_input + b_input
    """

    # The SPMD launch grid denotes the number of kernel instances.
    # In this case, we use a 2D grid where the size of each invocation is 128x512
    grid_x = a_input.shape[0] // 128
    grid_y = a_input.shape[1] // 512

    return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input)
device = xm.xla_device()
a = torch.randn(256, 1024, dtype=torch.float32).to(device)
b = torch.randn(256, 1024, dtype=torch.float32).to(device)
c = nki_tensor_add(a, b) # calling a NKI kernel, instead of the built-in torch op
out = a * b * c
print(out)
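Note that nki_tensor_add assumes each input dimension is an exact multiple of the tile size (128 along the first axis, 512 along the second); the integer division in the launch grid silently drops any remainder. If you want an explicit failure for unsupported shapes, a minimal guard like the hypothetical wrapper below (the name nki_tensor_add_checked is illustrative, not part of the tutorial) can be added:

def nki_tensor_add_checked(a_input, b_input):
    # Hypothetical wrapper: validate the [N*128, M*512] shape assumption
    # before dispatching to the SPMD kernel caller defined above.
    assert a_input.shape == b_input.shape, "inputs must have identical shapes"
    assert a_input.shape[0] % 128 == 0 and a_input.shape[1] % 512 == 0, \
        "input shape must be a multiple of the [128, 512] tile size"
    return nki_tensor_add(a_input, b_input)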
To understand what happens under the hood when we compile the above code, we can print the HLO IR graph generated by XLA by setting the NEURON_FRAMEWORK_DEBUG environment variable. For example, you may add the following lines to your code:
import os
os.environ['NEURON_FRAMEWORK_DEBUG'] = "1"
A .pbtxt file containing the corresponding human-readable HLO IR is then written to your run directory.
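If the run directory contains several dumps, a small helper like the sketch below (an illustrative snippet, not part of the Neuron tooling) lists the .pbtxt files with the most recent first:

import glob
import os

# Illustrative helper: list the .pbtxt files written to the current run
# directory, newest first, so you can open the latest HLO dump.
for path in sorted(glob.glob("*.pbtxt"), key=os.path.getmtime, reverse=True):
    print(path)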
Let’s examine the XLA output of this example. In line #5 we can identify that the tensor addition is now mapped to an HLO custom-call instruction, with AwsNeuronCustomNativeKernel as the custom_call_target. The output of that custom-call is then consumed by the next instruction in line #6 as usual.
1ENTRY %SyncTensorsGraph.22 (p0.2: f32[256,1024], p1.2: f32[256,1024]) -> (f32[256,1024]) {
2 %p1.2 = f32[256,1024]{1,0} parameter(1), frontend_attributes={neff_input_name="input1"}
3 %p0.2 = f32[256,1024]{1,0} parameter(0), frontend_attributes={neff_input_name="input0"}
4 %multiply = f32[256,1024]{1,0} multiply(f32[256,1024]{1,0} %p1.2, f32[256,1024]{1,0} %p0.2)
5 %custom-call.2 = f32[256,1024]{1,0} custom-call(f32[256,1024]{1,0} %p1.2, f32[256,1024]{1,0} %p0.2), custom_call_target="AwsNeuronCustomNativeKernel", api_version=API_VERSION_UNSPECIFIED, backend_config="...")
6 %multiply.1 = f32[256,1024]{1,0} multiply(f32[256,1024]{1,0} %multiply, f32[256,1024]{1,0} %custom-call.2)
7 ROOT %tuple = (f32[256,1024]{1,0}) tuple(f32[256,1024]{1,0} %multiply.1), frontend_attributes={neff_output_names="output0"}
8}
The Neuron compiler replaces the above custom-call with the corresponding NKI kernel implementation while optimizing the rest of the compute graph as usual. At the end of the compilation process, a single compiled binary NEFF file is generated representing the entire graph including the NKI kernel. For more information about NEFF files, see Neuron Compiler.
JAX#
We can perform (a + b) * a * b using native JAX code.
import jax
import jax.numpy as jnp

@jax.jit
def jax_customop_tutorial(a, b):
    c = a + b
    out = a * b * c
    return out

seed = jax.random.PRNGKey(0)
seed_a, seed_b = jax.random.split(seed)
a = jax.random.normal(seed_a, (256, 1024), dtype=jnp.float32)
b = jax.random.normal(seed_b, (256, 1024), dtype=jnp.float32)
print(jax_customop_tutorial(a, b))
Similar to the PyTorch example above, let’s replace the tensor addition (c = a + b) with the addition NKI kernel. To do this, we replace the + operator with a call to the NKI kernel caller (nki_tensor_add), and everything else works as before.
def nki_tensor_add(a_input, b_input):
    """NKI kernel caller to compute element-wise addition of two input tensors

    This kernel caller lifts the tile-size restriction by applying the kernel to tiles of the inputs/outputs

    Args:
        a_input: a first input tensor, of shape [N*128, M*512]
        b_input: a second input tensor, of shape [N*128, M*512]

    Returns:
        a tensor of shape [N*128, M*512], the result of a_input + b_input
    """

    # The SPMD launch grid denotes the number of kernel instances.
    # In this case, we use a 2D grid where the size of each invocation is 128x512
    grid_x = a_input.shape[0] // 128
    grid_y = a_input.shape[1] // 512

    return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input)
import jax
import jax.numpy as jnp

@jax.jit
def jax_customop_tutorial(a, b):
    c = nki_tensor_add(a, b)  # calling a NKI kernel, instead of the built-in jax op
    out = a * b * c
    return out

seed = jax.random.PRNGKey(0)
seed_a, seed_b = jax.random.split(seed)
a = jax.random.normal(seed_a, (256, 1024), dtype=jnp.float32)
b = jax.random.normal(seed_b, (256, 1024), dtype=jnp.float32)
print(jax_customop_tutorial(a, b))
To understand what happens under the hood when we compile the above code, we can print the HLO IR graph by adding the following snippet to your code:
print(jax.jit(jax_customop_tutorial)
      .lower(a, b)
      .compile()
      .runtime_executable()
      .hlo_modules()[0].to_string()
)
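If your installed JAX exposes the jax.stages text helpers, an equivalent way to dump the compiled module is shown below; this relies on an assumption about the JAX version rather than on the Neuron tutorial itself.

# Alternative sketch: on JAX versions whose jax.stages.Compiled provides
# as_text(), the compiled HLO module can be printed directly.
print(jax.jit(jax_customop_tutorial).lower(a, b).compile().as_text())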
Let’s examine the XLA output of this example. In line #7 we can identify that the tensor addition is now mapped to an HLO custom-call instruction, similar to PyTorch. The output of that custom-call is then consumed by the next instruction in line #8 as usual.
1HloModule jit_add, entry_computation_layout={(f32[256,1024]{1,0}, f32[256,1024]{1,0})->(f32[256,1024]{1,0})}, allow_spmd_sharding_propagation_to_output={true}
2
3ENTRY %main.11 (Arg_0.1: f32[256,1024], Arg_1.2: f32[256,1024]) -> (f32[256,1024]) {
4 %Arg_0.1 = f32[256,1024]{1,0} parameter(0), sharding={replicated}
5 %Arg_1.2 = f32[256,1024]{1,0} parameter(1), sharding={replicated}
6 %multiply.0 = f32[256,1024]{1,0} multiply(f32[256,1024]{1,0} %Arg_0.1, f32[256,1024]{1,0} %Arg_1.2), metadata={op_name="jit(add)/jit(main)/jit(jax_customop_tutorial)/mul" source_file="/tmp/ipykernel_3935360/2333914945.py" source_line=61}
7 %custom-call.0 = f32[256,1024]{1,0} custom-call(f32[256,1024]{1,0} %Arg_0.1, f32[256,1024]{1,0} %Arg_1.2), custom_call_target="AwsNeuronCustomNativeKernel", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(add)/jit(main)/jit(jax_customop_tutorial)/nki_call[func=<function nki_tensor_add_kernel_ at 0x7f6be28f6f80> grid=(2, 2) out_shape=(ShapeDtypeStruct(shape=(256, 1024), dtype=float32),)]" source_file="/home/ubuntu/nki/src/jax_neuronx/core.py" source_line=34}, backend_config="..."
8 %multiply.1 = f32[256,1024]{1,0} multiply(f32[256,1024]{1,0} %multiply.0, f32[256,1024]{1,0} %custom-call.0), metadata={op_name="jit(add)/jit(main)/jit(jax_customop_tutorial)/mul" source_file="/tmp/ipykernel_3935360/2333914945.py" source_line=61}
9 ROOT %tuple.10 = (f32[256,1024]{1,0}) tuple(f32[256,1024]{1,0} %multiply.1)
10}
The Neuron compiler replaces the above custom-call with the corresponding NKI kernel implementation while optimizing the rest of the compute graph as usual. At the end of the compilation process, a single compiled binary NEFF file is generated representing the entire graph including the NKI kernel. For more information about NEFF files, see Neuron Compiler.
Using NKI in training graphs#
If you are using NKI to implement a new operator in a training graph, you might need to make the new operator work with the autograd engine in the framework. To do this in PyTorch, you can subclass the framework’s base operator class and implement both the forward() and backward() methods. The autograd engine then uses the backward() method when performing auto-differentiation. See Extending torch.autograd in the PyTorch Docs for instructions on doing this in PyTorch. To do this in JAX, you can create a custom_vjp rule (vjp stands for vector-Jacobian product), which binds the forward() and backward() calls. See Autodiff Cookbook in the JAX Docs for instructions on doing this.
Let’s reuse the nki_tensor_add kernel from before and demonstrate how to train a simple compute graph (a+b)*a*b in both PyTorch and JAX.
PyTorch#
We define a NkiAddFunc class, which leverages the nki_tensor_add kernel in its forward() function. The gradients of both input tensors in y = a + b are ones, so the backward() function propagates the dy gradients from the previous backward function.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

class NkiAddFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        return nki_tensor_add(a, b)

    @staticmethod
    def backward(ctx, dy, *args):
        # gradients for a and b
        return dy, dy

# now, let's define the compute graph
a = torch.randn(256, 1024, dtype=torch.float32).to(device).detach().requires_grad_()
b = torch.randn(256, 1024, dtype=torch.float32).to(device).detach().requires_grad_()
c = NkiAddFunc.apply(a, b)
out = a * b * c

# here we define a (dummy) loss function, in prep for backward propagation
loss = out.sum()

# lastly, let's invoke the autograd engine
loss.backward()
xm.mark_step()
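Because c = a + b contributes a gradient of one to each input, autograd composes the pass-through backward() with the surrounding multiplications. As an optional, hedged sanity check (not part of the original tutorial), you can compare the accumulated gradients against the closed-form gradients of loss = (a * b * (a + b)).sum():

# Optional sanity check (illustrative, not part of the original tutorial):
# for loss = (a * b * (a + b)).sum(), the analytical gradients are
# d(loss)/da = b * (2a + b) and d(loss)/db = a * (a + 2b).
with torch.no_grad():
    expected_grad_a = b * (2 * a + b)
    expected_grad_b = a * (a + 2 * b)
    print(torch.allclose(a.grad, expected_grad_a))
    print(torch.allclose(b.grad, expected_grad_b))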
JAX#
We define a custom_vjp function nki_add_func by using the @jax.custom_vjp decorator, which directly calls the nki_tensor_add kernel. We then define and register the forward() and backward() implementations of the nki_add_func function via defvjp(). Just like the PyTorch example before, the backward() implementation simply passes the gradients through. Finally, to start training, we execute the forward pass by calling nki_add_func(a, b) * a * b. To get the gradients, we call jax.grad directly with a loss function.
@jax.custom_vjp
def nki_add_func(a, b):
    return nki_tensor_add(a, b)

def f_forward(a, b):
    # operator output and residual (same as input here)
    return nki_add_func(a, b), (a, b)

def f_backward(res, grad):
    # gradients for a and b
    return grad, grad

nki_add_func.defvjp(f_forward, f_backward)

@jax.jit
def jax_customop_tutorial_and_grad(a, b):
    out = nki_add_func(a, b) * a * b
    # use the same dummy loss function (output sum) as the PyTorch example above
    grad = jax.grad(lambda x, y: (nki_add_func(x, y) * x * y).sum(), argnums=(0, 1))(a, b)
    return out, *grad

c, grad_a, grad_b = jax_customop_tutorial_and_grad(a, b)
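As an optional, hedged sanity check (again, not part of the original tutorial), the gradients returned through the custom_vjp path can be compared against jax.grad applied to a pure jax.numpy version of the same expression:

# Illustrative check: gradients through the custom_vjp path should match
# gradients of a pure jax.numpy implementation of (a + b) * a * b.
def reference_loss(a, b):
    return ((a + b) * a * b).sum()

ref_grad_a, ref_grad_b = jax.grad(reference_loss, argnums=(0, 1))(a, b)
print(jnp.allclose(grad_a, ref_grad_a), jnp.allclose(grad_b, ref_grad_b))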
This document is relevant for: Inf2, Trn1, Trn2