This document is relevant for: Inf2, Trn1, Trn1n

Tensor addition#

In this tutorial we write a simple tensor addition kernel using NKI in PyTorch and JAX. In doing so, we learn about:

  • The NKI syntax and the SPMD programming model.

  • Best practices for validating and benchmarking your custom kernel against a reference native PyTorch or JAX implementation.

Note

This tutorial is written using the SPMD programming model in NKI. However, as discussed in the NKI programming guide, adopting the SPMD programming model has no impact on the performance of a NKI kernel, and it is therefore considered optional in the current NKI release.

PyTorch#

Compute kernel#

We start by defining the compute kernel, which operates on a tile of size [128, 512]. The partition-dimension tile size is chosen according to the tile-size restrictions (nki.language.tile_size.pmax), while the free-dimension tile size (512) is chosen arbitrarily.
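For reference, the partition-dimension limit is also exposed programmatically. A minimal check, using the same nl.tile_size.pmax constant referenced above:

import neuronxcc.nki.language as nl

# Maximum tile size along the partition dimension (128 on current hardware)
print(nl.tile_size.pmax)

The kernel itself follows.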

import neuronxcc.nki.language as nl


def nki_tensor_add_kernel_(a_input, b_input, c_output):
  """NKI kernel to compute element-wise addition of two input tensors

  This kernel assumes strict input/output tile sizes of up to [128,512]

  Args:
      a_input: a first input tensor, of shape [128,512]
      b_input: a second input tensor, of shape [128,512]
      c_output: an output tensor, of shape [128,512]
  """

  # Calculate tile offsets based on current 'program'
  offset_i_x = nl.program_id(0) * 128
  offset_i_y = nl.program_id(1) * 512

  # Generate tensor indices to index tensors a and b
  ix = offset_i_x + nl.arange(128)[:, None]
  iy = offset_i_y + nl.arange(512)[None, :]

  # Load input data from device memory (HBM) to on-chip memory (SBUF)
  # We refer to an indexed portion of a tensor as an intermediate tensor
  a_tile = nl.load(a_input[ix, iy])
  b_tile = nl.load(b_input[ix, iy])

  # Compute a + b
  c_tile = a_tile + b_tile

  # Store the addition results back to device memory (c_output)
  nl.store(c_output[ix, iy], value=c_tile)

In this example:

  1. We define the NKI kernel in nki_tensor_add_kernel_.

  2. Inside, we first define offsets into the tensors based on the ID of the worker executing the code (nl.program_id), and generate tile indices from these offsets with nl.arange. We use advanced indexing here to showcase how it works; basic indexing with slicing also works (see the sketch after this list). See NKI Programming Model for more information on the different tensor indexing modes.

  3. We use nl.program_id to enable SPMD execution (single-program, multiple-data, see SPMD: Launching Multiple Instances of a Kernel), where each worker only operates on a (sub-tensor) tile of the input/output tensors. By accessing its own program_id, each worker can calculate the offsets it needs to access the correct tiles.

  4. The first axis of the tensor (mapped to the partition-dimension) is tiled into blocks of 128, based on hardware restrictions (see Tile Size Considerations). The second axis (mapped to the free-dimension) is tiled into blocks of 512 (no tile-size constraint, except for on-chip memory capacity constraints).

  5. We then load sub-tensor data from tensors a_input and b_input using nl.load, placing the tiles a_tile and b_tile in the on-chip memory (SBUF).

  6. We sum them to compute c_tile and store it back to DRAM in the relevant portion of the c_output tensor using nl.store. Since both inputs and the output have the same shape, we can use the same set of indices to access all three tensors.
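As noted in step 2, the same tiles can also be addressed with basic slicing instead of advanced indexing. A minimal sketch of the equivalent loads, reusing the offsets computed in the kernel (this variant is not part of the tutorial code):

# Equivalent loads using basic slicing instead of nl.arange-based indexing
a_tile = nl.load(a_input[offset_i_x:offset_i_x + 128, offset_i_y:offset_i_y + 512])
b_tile = nl.load(b_input[offset_i_x:offset_i_x + 128, offset_i_y:offset_i_y + 512])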

SPMD execution#

We declare a helper function that allocates the result tensor and launches the compute kernel with the appropriate grid size to perform the computation. We wrap the nki_tensor_add_kernel_ kernel with the nki_jit decorator so that the kernel can be traced and called from PyTorch.

import torch
from torch_neuronx import nki_jit


def nki_tensor_add(a_input, b_input):
  """NKI kernel caller to compute element-wise addition of two input tensors

  This kernel caller lifts the tile-size restriction by applying the kernel on tiles of the inputs/outputs

  Args:
      a_input: a first input tensor, of shape [N*128, M*512]
      b_input: a second input tensor, of shape [N*128, M*512]

  Returns:
      a tensor of shape [N*128, M*512], the result of a_input + b_input
  """

  # The SPMD launch grid denotes the number of kernel instances.
  # In this case, we use a 2D grid where each invocation operates on a 128x512 tile
  grid_x = a_input.shape[0] // 128
  grid_y = a_input.shape[1] // 512
  # Allocate the output on the same device as the inputs
  c_output = torch.zeros(a_input.shape, dtype=a_input.dtype).to(device=a_input.device)

  # Decorate the NKI kernel for PyTorch tracing
  nki_tensor_add_kernel_torch = nki_jit(nki_tensor_add_kernel_)
  nki_tensor_add_kernel_torch[grid_x, grid_y](a_input, b_input, c_output)

  return c_output

We are using a two-dimensional grid, where the first dimension of the tensor is tiled along the X dimension of the grid and the second dimension is tiled along the Y dimension. In this scenario we assume that the tensor sizes are multiples of the maximum tile sizes, so we do not need to handle partial tiles; the sketch below illustrates the worker-to-tile mapping. We explicitly initialize c_output to zeros using torch.zeros. Note that we cannot use nl.zeros here, as this helper function is not traced as a NKI kernel with nki_jit.
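Here is a plain-Python illustration (ordinary Python for exposition, not NKI code) of how a 256x1024 input decomposes over the resulting 2x2 launch grid:

# Each (i, j) pair below corresponds to one kernel instance and its program_id values
shape = (256, 1024)
assert shape[0] % 128 == 0 and shape[1] % 512 == 0  # no partial tiles to handle
grid_x, grid_y = shape[0] // 128, shape[1] // 512   # 2 x 2 grid
for i in range(grid_x):
  for j in range(grid_y):
    # Worker (i, j) covers rows [i*128, (i+1)*128) and columns [j*512, (j+1)*512)
    print(f"program ({i}, {j}) -> rows {i*128}:{(i+1)*128}, cols {j*512}:{(j+1)*512}")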

Launching kernel and testing correctness#

To execute the kernel, we prepare tensors a and b and call the nki_tensor_add helper function. We also verify the correctness of the NKI kernel against native PyTorch by comparing the two outputs with torch.allclose:

import torch
from torch_xla.core import xla_model as xm

if __name__ == "__main__":
  device = xm.xla_device()

  a = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)
  b = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)

  output_nki = nki_tensor_add(a, b)
  print(f"output_nki={output_nki}")

  output_torch = a + b
  print(f"output_torch={output_torch}")

  allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")

  assert allclose

Output:

2023-12-29 15:18:00.000558:  14283  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2023-12-29 15:18:00.000559:  14283  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: ['neuronx-cc', '--target=trn1', 'compile', '--framework', 'XLA', '/tmp/neuroncc_compile_workdir/49f554a2-2c55-4a88-8054-cc9f20824a46/model.MODULE_5007921933048625946+d41d8cd9.hlo.pb', '--output', '/tmp/neuroncc_compile_workdir/49f554a2-2c55-4a88-8054-cc9f20824a46/model.MODULE_5007921933048625946+d41d8cd9.neff', '--verbose=35']
.
Compiler status PASS
output_nki=tensor([[0.9297, 0.8359, 1.1719,  ..., 0.4648, 0.2188, 0.9336],
        [0.3906, 1.3125, 0.8789,  ..., 1.6562, 1.7734, 0.9531],
        [0.6445, 1.1406, 1.3281,  ..., 0.9531, 0.8711, 0.9336],
        ...,
        [0.4023, 0.6406, 1.5312,  ..., 0.7617, 0.7734, 0.3359],
        [0.8125, 0.7422, 1.2109,  ..., 0.8516, 1.2031, 0.5430],
        [1.3281, 1.2812, 1.3984,  ..., 1.2344, 0.8711, 0.5664]],
       device='xla:1', dtype=torch.bfloat16)
2023-12-29 15:18:02.000219:  14463  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2023-12-29 15:18:02.000220:  14463  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: ['neuronx-cc', '--target=trn1', 'compile', '--framework', 'XLA', '/tmp/neuroncc_compile_workdir/2e135b73-1c3b-45e4-a6f0-2c4b105c20e5/model.MODULE_10032327759287407517+d41d8cd9.hlo.pb', '--output', '/tmp/neuroncc_compile_workdir/2e135b73-1c3b-45e4-a6f0-2c4b105c20e5/model.MODULE_10032327759287407517+d41d8cd9.neff', '--verbose=35']
.
Compiler status PASS
output_torch=tensor([[0.9297, 0.8359, 1.1719,  ..., 0.4648, 0.2188, 0.9336],
        [0.3906, 1.3125, 0.8789,  ..., 1.6562, 1.7734, 0.9531],
        [0.6445, 1.1406, 1.3281,  ..., 0.9531, 0.8711, 0.9336],
        ...,
        [0.4023, 0.6406, 1.5312,  ..., 0.7617, 0.7734, 0.3359],
        [0.8125, 0.7422, 1.2109,  ..., 0.8516, 1.2031, 0.5430],
        [1.3281, 1.2812, 1.3984,  ..., 1.2344, 0.8711, 0.5664]],
       device='xla:1', dtype=torch.bfloat16)
2023-12-29 15:18:03.000797:  14647  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2023-12-29 15:18:03.000798:  14647  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: ['neuronx-cc', '--target=trn1', 'compile', '--framework', 'XLA', '/tmp/neuroncc_compile_workdir/74f8b6ae-76d9-4dd8-af7f-e5e1c40a27a3/model.MODULE_5906037506311912405+d41d8cd9.hlo.pb', '--output', '/tmp/neuroncc_compile_workdir/74f8b6ae-76d9-4dd8-af7f-e5e1c40a27a3/model.MODULE_5906037506311912405+d41d8cd9.neff', '--verbose=35']
.
Compiler status PASS
NKI and Torch match

Note that the tensor values you see will differ from what’s printed above, since this example uses torch.rand to initialize the inputs.
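The introduction also mentions benchmarking against the native implementation. One rough way to measure wall-clock time on an XLA device is sketched below; the benchmark helper is hypothetical and not part of the tutorial scripts. Because XLA executes lazily, we must synchronize before reading the clock:

import time
from torch_xla.core import xla_model as xm

def benchmark(fn, a, b, iters=10):
  fn(a, b)                # Warm-up run, so compilation time is excluded
  xm.mark_step()
  xm.wait_device_ops()
  start = time.perf_counter()
  for _ in range(iters):
    fn(a, b)
  xm.mark_step()          # Flush the lazily-recorded graph to the device
  xm.wait_device_ops()    # Block until the device finishes executing
  return (time.perf_counter() - start) / iters

For example, compare benchmark(nki_tensor_add, a, b) against benchmark(torch.add, a, b).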

JAX#

Compute kernel#

We can reuse the same NKI compute kernel defined for PyTorch above.

import neuronxcc.nki.language as nl


def nki_tensor_add_kernel_(a_input, b_input, c_output):
  """NKI kernel to compute element-wise addition of two input tensors

  This kernel assumes strict input/output tile sizes of up to [128,512]

  Args:
      a_input: a first input tensor, of shape [128,512]
      b_input: a second input tensor, of shape [128,512]
      c_output: an output tensor, of shape [128,512]
  """

  # Calculate tile offsets based on current 'program'
  offset_i_x = nl.program_id(0) * 128
  offset_i_y = nl.program_id(1) * 512

  # Generate tensor indices to index tensors a and b
  ix = offset_i_x + nl.arange(128)[:, None]
  iy = offset_i_y + nl.arange(512)[None, :]

  # Load input data from device memory (HBM) to on-chip memory (SBUF)
  # We refer to an indexed portion of a tensor as an intermediate tensor
  a_tile = nl.load(a_input[ix, iy])
  b_tile = nl.load(b_input[ix, iy])

  # Compute a + b
  c_tile = a_tile + b_tile

  # Store the addition results back to device memory (c_output)
  nl.store(c_output[ix, iy], value=c_tile)

SPMD execution#

Now we declare a helper function that describes the result tensor and launches the compute kernel with the appropriate grid size to perform the computation:

import jax
from jax_neuronx import nki_call


def nki_tensor_add(a_input, b_input):
  """NKI kernel caller to compute element-wise addition of two input tensors

  This kernel caller lifts the tile-size restriction by applying the kernel on tiles of the inputs/outputs

  Args:
      a_input: a first input tensor, of shape [N*128, M*512]
      b_input: a second input tensor, of shape [N*128, M*512]

  Returns:
      a tensor of shape [N*128, M*512], the result of a_input + b_input
  """

  # The SPMD launch grid denotes the number of kernel instances.
  # In this case, we use a 2D grid where each invocation operates on a 128x512 tile
  grid_x = a_input.shape[0] // 128
  grid_y = a_input.shape[1] // 512

  # Describe the output's shape and dtype so nki_call can allocate it
  out_shape = jax.ShapeDtypeStruct((a_input.shape[0], a_input.shape[1]), dtype=a_input.dtype)

  return nki_call(
      nki_tensor_add_kernel_,
      a_input,
      b_input,
      grid=(grid_x, grid_y),
      out_shape=out_shape,
  )

We are using a two-dimensional grid, where the first dimension of the tensor is tiled along the X dimension of the grid and the second dimension is tiled along the Y dimension. In this scenario we assume that the tensor sizes are multiples of the maximum tile sizes, so we do not need to handle partial tiles. We initialize out_shape, which describes the shape and dtype of the output, using jax.ShapeDtypeStruct.
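jax.ShapeDtypeStruct records only the shape and dtype of the output, without allocating any memory. A quick standalone illustration:

import jax
import jax.numpy as jnp

# Describes a 256x1024 bfloat16 output without materializing it
out_shape = jax.ShapeDtypeStruct((256, 1024), dtype=jnp.bfloat16)
print(out_shape.shape, out_shape.dtype)  # (256, 1024) bfloat16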

Launching kernel and testing correctness#

To execute the kernel, we prepare arrays a and b and call the nki_tensor_add helper function. We also verify the correctness of the NKI kernel against native JAX by comparing the two outputs with jax.numpy.allclose:

import jax
import jax.numpy as jnp

if __name__ == "__main__":
  seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
  a = jax.random.uniform(seed_a, (256, 1024), dtype=jnp.bfloat16)
  b = jax.random.uniform(seed_b, (256, 1024), dtype=jnp.bfloat16)

  output_nki = nki_tensor_add(a, b)
  print(f"output_nki={output_nki}")

  output_jax = a + b
  print(f"output_jax={output_jax}")

  allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and JAX match")
  else:
    print("NKI and JAX differ")

  assert allclose

Output:

.
Compiler status PASS
.
Compiler status PASS
.
Compiler status PASS
output_nki=[[0.992188 1.27344 1.65625 ... 0.90625 1.34375 1.77344]
 [0 0.90625 1.34375 ... 0.390625 0.703125 0.914062]
 [0.5 0.390625 0.703125 ... 1.22656 1.15625 1.01562]
 ...
 [1.98438 1.98438 1.98438 ... 1.33594 1.64062 1.35938]
 [0.992188 1.33594 1.64062 ... 1.16406 1.67188 1.20312]
 [1.49219 1.16406 1.67188 ... 1.375 1 1.6875]]
.
Compiler status PASS
output_jax=[[0.992188 1.27344 1.65625 ... 0.90625 1.34375 1.77344]
 [0 0.90625 1.34375 ... 0.390625 0.703125 0.914062]
 [0.5 0.390625 0.703125 ... 1.22656 1.15625 1.01562]
 ...
 [1.98438 1.98438 1.98438 ... 1.33594 1.64062 1.35938]
 [0.992188 1.33594 1.64062 ... 1.16406 1.67188 1.20312]
 [1.49219 1.16406 1.67188 ... 1.375 1 1.6875]]
.
Compiler status PASS
NKI and JAX match

Note that the array values you see will differ from what’s printed above, since this example uses jax.random.uniform to initialize the inputs.
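As on the PyTorch side, a rough wall-clock benchmark is straightforward; JAX dispatches work asynchronously, so we block on the result before reading the clock. The benchmark helper below is a hypothetical sketch, not part of the tutorial scripts:

import time

def benchmark(fn, a, b, iters=10):
  fn(a, b).block_until_ready()  # Warm-up run, so compilation time is excluded
  start = time.perf_counter()
  for _ in range(iters):
    out = fn(a, b)
  out.block_until_ready()       # Wait for the device to finish
  return (time.perf_counter() - start) / iters

For example, compare benchmark(nki_tensor_add, a, b) against benchmark(jnp.add, a, b).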

Download All Source Code#

Click the links to download the source code of the kernels and the testing code discussed in this tutorial.

You can also view the source code in the GitHub repository nki_samples.

Example usage of the scripts#

Run NKI baremetal implementation:

python3 tensor_addition_nki_kernels.py

Run PyTorch implementation:

python3 tensor_addition_torch.py

Run JAX implementation:

python3 tensor_addition_jax.py

This document is relevant for: Inf2, Trn1, Trn1n