This document is relevant for: Inf2, Trn1, Trn1n
Tensor addition#
In this tutorial, we write a simple tensor addition kernel using NKI, in both PyTorch and JAX. In doing so, we learn about:
- The NKI syntax and the SPMD programming model.
- Best practices for validating and benchmarking your custom kernel against a reference native PyTorch or JAX implementation.
Note
This tutorial is written using the SPMD programming model in NKI. However, as discussed in the NKI programming guide, adopting the SPMD programming model has no impact on the performance of NKI kernels, and it is therefore considered optional in the current NKI release.
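For contrast, a non-SPMD kernel performs its own tile loop instead of relying on a launch grid. Below is a minimal sketch (our illustration, not part of the original tutorial) of an element-wise add written that way, assuming the tensor shapes are exact multiples of the [128, 512] tile and using nl.affine_range for the tile loops:

import neuronxcc.nki.language as nl


def nki_tensor_add_nospmd_(a_input, b_input, c_output):
    # A single kernel instance iterates over all [128, 512] tiles itself,
    # instead of launching one SPMD worker per tile
    for i in nl.affine_range(a_input.shape[0] // 128):
        for j in nl.affine_range(a_input.shape[1] // 512):
            ix = i * 128 + nl.arange(128)[:, None]
            iy = j * 512 + nl.arange(512)[None, :]
            a_tile = nl.load(a_input[ix, iy])
            b_tile = nl.load(b_input[ix, iy])
            nl.store(c_output[ix, iy], value=a_tile + b_tile)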
PyTorch#
Compute kernel#
We start by defining the compute kernel, which operates on a tile of size [128, 512]. The partition-dimension tile size is chosen according to the tile-size restrictions (nki.language.tile_size.pmax), while the free-dimension tile size is chosen arbitrarily (512).
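As an aside, the partition-dimension limit can be read programmatically instead of hard-coding 128. A small sketch (the 512 free-dimension size remains an arbitrary choice):

import neuronxcc.nki.language as nl

TILE_P = nl.tile_size.pmax  # maximum partition-dimension tile size (128)
TILE_F = 512                # free-dimension tile size, chosen arbitrarily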
import neuronxcc.nki.language as nl


def nki_tensor_add_kernel_(a_input, b_input, c_output):
    """NKI kernel to compute element-wise addition of two input tensors.

    This kernel assumes strict input/output tile sizes of up to [128, 512].

    Args:
        a_input: a first input tensor, of shape [128,512]
        b_input: a second input tensor, of shape [128,512]
        c_output: an output tensor, of shape [128,512]
    """

    # Calculate tile offsets based on current 'program'
    offset_i_x = nl.program_id(0) * 128
    offset_i_y = nl.program_id(1) * 512

    # Generate tensor indices to index tensors a and b
    ix = offset_i_x + nl.arange(128)[:, None]
    iy = offset_i_y + nl.arange(512)[None, :]

    # Load input data from device memory (HBM) to on-chip memory (SBUF)
    # We refer to an indexed portion of a tensor as an intermediate tensor
    a_tile = nl.load(a_input[ix, iy])
    b_tile = nl.load(b_input[ix, iy])

    # compute a + b
    c_tile = a_tile + b_tile

    # store the addition results back to device memory (c_output)
    nl.store(c_output[ix, iy], value=c_tile)
In this example:
- We define the NKI kernel in nki_tensor_add_kernel_.
- Inside, we first define offsets into the tensors based on the ID of the worker executing the code (nl.program_id), and generate tile indices from these offsets with nl.arange. We use advanced indexing here to showcase how it works; basic indexing with slicing can also work (see the sketch after this list). See NKI Programming Model for more information on the different tensor indexing modes.
- We use nl.program_id to enable SPMD execution (single-program, multiple-data; see SPMD: Launching Multiple Instances of a Kernel), where each worker operates only on a (sub-tensor) tile of the input/output tensors. By accessing its own program_id, each worker can calculate the offsets it needs to access the correct tiles.
- The first axis of the tensor (mapped to the partition dimension) is tiled into blocks of 128, based on hardware restrictions (see Tile Size Considerations). The second axis (mapped to the free dimension) is tiled into blocks of 512 (no tile-size constraint, except for on-chip memory capacity).
- We then load sub-tensor data from tensors a_input and b_input using nl.load, placing the tiles a_tile and b_tile in on-chip memory (SBUF).
- We sum them to compute c_tile, and store it back to DRAM in the relevant portion of the c_output tensor using nl.store. Since both inputs and the output have the same shape, we can use the same set of indices to access all three tensors.
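As noted in the indexing bullet above, the same tile access can also be written with basic slice indexing instead of nl.arange index tensors. A minimal sketch of the kernel body rewritten that way (our illustration, relying on the slicing support mentioned above):

offset_i_x = nl.program_id(0) * 128
offset_i_y = nl.program_id(1) * 512

# Basic indexing: slice out the [128, 512] tile directly
a_tile = nl.load(a_input[offset_i_x:offset_i_x + 128, offset_i_y:offset_i_y + 512])
b_tile = nl.load(b_input[offset_i_x:offset_i_x + 128, offset_i_y:offset_i_y + 512])
nl.store(c_output[offset_i_x:offset_i_x + 128, offset_i_y:offset_i_y + 512],
         value=a_tile + b_tile)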
SPMD execution#
We declare a helper function that allocates the result tensor and launches the compute kernel with the appropriate grid/block sizes to perform the computation. We wrap the nki_tensor_add_kernel_ kernel with the nki_jit decorator so that the kernel can be traced and called from PyTorch.
import torch
from torch_neuronx import nki_jit


def nki_tensor_add(a_input, b_input):
    """NKI kernel caller to compute element-wise addition of two input tensors.

    This kernel caller lifts the tile-size restriction by applying the kernel on tiles of the inputs/outputs.

    Args:
        a_input: a first input tensor, of shape [N*128, M*512]
        b_input: a second input tensor, of shape [N*128, M*512]

    Returns:
        a tensor of shape [N*128, M*512], the result of a_input + b_input
    """

    # The SPMD launch grid denotes the number of kernel instances.
    # In this case, we use a 2D grid where the size of each invocation is 128x512
    grid_x = a_input.shape[0] // 128
    grid_y = a_input.shape[1] // 512
    # Allocate the output on the same device as the inputs
    c_output = torch.zeros(a_input.shape, dtype=a_input.dtype, device=a_input.device)

    # Decorate the NKI kernel for PyTorch tracing
    nki_tensor_add_kernel_torch = nki_jit(nki_tensor_add_kernel_)
    nki_tensor_add_kernel_torch[grid_x, grid_y](a_input, b_input, c_output)

    return c_output
We are using a two-dimensional grid, where the first dimension of the tensor is tiled in the X dimension of the grid and the second dimension is tiled in the Y dimension of the grid. In this scenario we assume that the tensor sizes are a multiple of the maximum allowed tile sizes, so we do not need to handle partial tiles. We explicitly initialize c_output to zeros using torch.zeros. Note that we cannot use nl.zeros here, as this function is not traced as an NKI kernel with nki_jit.
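Since partial tiles are not handled, it can be useful to make that assumption explicit with a caller-side guard. A minimal sketch (this check is our addition, not part of the original helper):

def check_tile_multiple(a_input, b_input):
    # Hypothetical guard: shapes must match and be exact multiples of the tile shape
    assert a_input.shape == b_input.shape, "inputs must have the same shape"
    assert a_input.shape[0] % 128 == 0, "first dim must be a multiple of 128"
    assert a_input.shape[1] % 512 == 0, "second dim must be a multiple of 512"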
Launching kernel and testing correctness#
To execute the kernel, we prepare tensors a and b, and call the nki_tensor_add helper function. We also verify the correctness of the NKI kernel against PyTorch by comparing the outputs of both, using torch.allclose:
import torch
from torch_xla.core import xla_model as xm

if __name__ == "__main__":
    device = xm.xla_device()

    a = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)
    b = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)

    output_nki = nki_tensor_add(a, b)
    print(f"output_nki={output_nki}")

    output_torch = a + b
    print(f"output_torch={output_torch}")

    allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2)
    if allclose:
        print("NKI and Torch match")
    else:
        print("NKI and Torch differ")

    assert allclose
Output:
2023-12-29 15:18:00.000558: 14283 INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2023-12-29 15:18:00.000559: 14283 INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: ['neuronx-cc', '--target=trn1', 'compile', '--framework', 'XLA', '/tmp/neuroncc_compile_workdir/49f554a2-2c55-4a88-8054-cc9f20824a46/model.MODULE_5007921933048625946+d41d8cd9.hlo.pb', '--output', '/tmp/neuroncc_compile_workdir/49f554a2-2c55-4a88-8054-cc9f20824a46/model.MODULE_5007921933048625946+d41d8cd9.neff', '--verbose=35']
.
Compiler status PASS
output_nki=tensor([[0.9297, 0.8359, 1.1719, ..., 0.4648, 0.2188, 0.9336],
[0.3906, 1.3125, 0.8789, ..., 1.6562, 1.7734, 0.9531],
[0.6445, 1.1406, 1.3281, ..., 0.9531, 0.8711, 0.9336],
...,
[0.4023, 0.6406, 1.5312, ..., 0.7617, 0.7734, 0.3359],
[0.8125, 0.7422, 1.2109, ..., 0.8516, 1.2031, 0.5430],
[1.3281, 1.2812, 1.3984, ..., 1.2344, 0.8711, 0.5664]],
device='xla:1', dtype=torch.bfloat16)
2023-12-29 15:18:02.000219: 14463 INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2023-12-29 15:18:02.000220: 14463 INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: ['neuronx-cc', '--target=trn1', 'compile', '--framework', 'XLA', '/tmp/neuroncc_compile_workdir/2e135b73-1c3b-45e4-a6f0-2c4b105c20e5/model.MODULE_10032327759287407517+d41d8cd9.hlo.pb', '--output', '/tmp/neuroncc_compile_workdir/2e135b73-1c3b-45e4-a6f0-2c4b105c20e5/model.MODULE_10032327759287407517+d41d8cd9.neff', '--verbose=35']
.
Compiler status PASS
output_torch=tensor([[0.9297, 0.8359, 1.1719, ..., 0.4648, 0.2188, 0.9336],
[0.3906, 1.3125, 0.8789, ..., 1.6562, 1.7734, 0.9531],
[0.6445, 1.1406, 1.3281, ..., 0.9531, 0.8711, 0.9336],
...,
[0.4023, 0.6406, 1.5312, ..., 0.7617, 0.7734, 0.3359],
[0.8125, 0.7422, 1.2109, ..., 0.8516, 1.2031, 0.5430],
[1.3281, 1.2812, 1.3984, ..., 1.2344, 0.8711, 0.5664]],
device='xla:1', dtype=torch.bfloat16)
2023-12-29 15:18:03.000797: 14647 INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2023-12-29 15:18:03.000798: 14647 INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: ['neuronx-cc', '--target=trn1', 'compile', '--framework', 'XLA', '/tmp/neuroncc_compile_workdir/74f8b6ae-76d9-4dd8-af7f-e5e1c40a27a3/model.MODULE_5906037506311912405+d41d8cd9.hlo.pb', '--output', '/tmp/neuroncc_compile_workdir/74f8b6ae-76d9-4dd8-af7f-e5e1c40a27a3/model.MODULE_5906037506311912405+d41d8cd9.neff', '--verbose=35']
.
Compiler status PASS
NKI and Torch match
Note that the tensor values you see will differ from what’s printed above, since this example uses torch.rand to initialize the inputs.
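The tutorial's goals also include benchmarking against the native implementation. A minimal host-side timing sketch (our addition, assuming the setup from the test above; xm.mark_step() and xm.wait_device_ops() flush pending XLA work and wait for the device):

import time

# Warm-up: trigger compilation before timing (assumes `a`, `b`, and the
# XLA device from the test above)
out = nki_tensor_add(a, b)
xm.mark_step()
xm.wait_device_ops()

start = time.perf_counter()
for _ in range(100):
    out = nki_tensor_add(a, b)
xm.mark_step()
xm.wait_device_ops()
elapsed = time.perf_counter() - start
print(f"average latency over 100 iterations: {elapsed / 100 * 1e6:.1f} us")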
JAX#
Compute kernel#
We can reuse the same NKI compute kernel defined for PyTorch above.
import neuronxcc.nki.language as nl


def nki_tensor_add_kernel_(a_input, b_input, c_output):
    """NKI kernel to compute element-wise addition of two input tensors.

    This kernel assumes strict input/output tile sizes of up to [128, 512].

    Args:
        a_input: a first input tensor, of shape [128,512]
        b_input: a second input tensor, of shape [128,512]
        c_output: an output tensor, of shape [128,512]
    """

    # Calculate tile offsets based on current 'program'
    offset_i_x = nl.program_id(0) * 128
    offset_i_y = nl.program_id(1) * 512

    # Generate tensor indices to index tensors a and b
    ix = offset_i_x + nl.arange(128)[:, None]
    iy = offset_i_y + nl.arange(512)[None, :]

    # Load input data from device memory (HBM) to on-chip memory (SBUF)
    # We refer to an indexed portion of a tensor as an intermediate tensor
    a_tile = nl.load(a_input[ix, iy])
    b_tile = nl.load(b_input[ix, iy])

    # compute a + b
    c_tile = a_tile + b_tile

    # store the addition results back to device memory (c_output)
    nl.store(c_output[ix, iy], value=c_tile)
SPMD execution#
Now we can declare a similar helper function that allocates the result array and launches the compute kernel with the appropriate grid/block sizes to perform the computation:
import jax
from jax_neuronx import nki_call


def nki_tensor_add(a_input, b_input):
    """NKI kernel caller to compute element-wise addition of two input tensors.

    This kernel caller lifts the tile-size restriction by applying the kernel on tiles of the inputs/outputs.

    Args:
        a_input: a first input tensor, of shape [N*128, M*512]
        b_input: a second input tensor, of shape [N*128, M*512]

    Returns:
        a tensor of shape [N*128, M*512], the result of a_input + b_input
    """

    # The SPMD launch grid denotes the number of kernel instances.
    # In this case, we use a 2D grid where the size of each invocation is 128x512
    grid_x = a_input.shape[0] // 128
    grid_y = a_input.shape[1] // 512

    out_shape = jax.ShapeDtypeStruct((a_input.shape[0], a_input.shape[1]), dtype=a_input.dtype)

    return nki_call(
        nki_tensor_add_kernel_,
        a_input,
        b_input,
        grid=(grid_x, grid_y),
        out_shape=out_shape,
    )
We are using a two-dimensional grid, where the first dimension of the tensor is tiled in the X dimension of the grid and the second dimension is tiled in the Y dimension of the grid. In this scenario we assume that the tensor sizes are a multiple of the maximum allowed tile sizes, so we do not need to handle partial tiles. We describe the shape and dtype of the output with out_shape, built using jax.ShapeDtypeStruct.
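jax.ShapeDtypeStruct is a lightweight placeholder that carries only a shape and a dtype, with no data allocated. A quick standalone illustration:

import jax
import jax.numpy as jnp

out_spec = jax.ShapeDtypeStruct((256, 1024), jnp.bfloat16)
print(out_spec.shape, out_spec.dtype)  # (256, 1024) bfloat16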
Launching kernel and testing correctness#
To execute the kernel, we prepare arrays a and b, and call the nki_tensor_add helper function. We also verify the correctness of the NKI kernel against JAX by comparing the outputs of both, using jax.numpy.allclose:
import jax
import jax.numpy as jnp

if __name__ == "__main__":

    seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
    a = jax.random.uniform(seed_a, (256, 1024), dtype=jnp.bfloat16)
    b = jax.random.uniform(seed_b, (256, 1024), dtype=jnp.bfloat16)

    output_nki = nki_tensor_add(a, b)
    print(f"output_nki={output_nki}")

    output_jax = a + b
    print(f"output_jax={output_jax}")

    allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2)
    if allclose:
        print("NKI and JAX match")
    else:
        print("NKI and JAX differ")

    assert allclose
Output:
.
Compiler status PASS
.
Compiler status PASS
.
Compiler status PASS
output_nki=[[0.992188 1.27344 1.65625 ... 0.90625 1.34375 1.77344]
[0 0.90625 1.34375 ... 0.390625 0.703125 0.914062]
[0.5 0.390625 0.703125 ... 1.22656 1.15625 1.01562]
...
[1.98438 1.98438 1.98438 ... 1.33594 1.64062 1.35938]
[0.992188 1.33594 1.64062 ... 1.16406 1.67188 1.20312]
[1.49219 1.16406 1.67188 ... 1.375 1 1.6875]]
.
Compiler status PASS
output_jax=[[0.992188 1.27344 1.65625 ... 0.90625 1.34375 1.77344]
[0 0.90625 1.34375 ... 0.390625 0.703125 0.914062]
[0.5 0.390625 0.703125 ... 1.22656 1.15625 1.01562]
...
[1.98438 1.98438 1.98438 ... 1.33594 1.64062 1.35938]
[0.992188 1.33594 1.64062 ... 1.16406 1.67188 1.20312]
[1.49219 1.16406 1.67188 ... 1.375 1 1.6875]]
.
Compiler status PASS
NKI and JAX match
Note that the array values you see will differ from what’s printed above, since this example uses jax.random.uniform to initialize the inputs.
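As with the PyTorch variant, a minimal host-side timing sketch (our addition; block_until_ready() waits for JAX's asynchronous computation to finish):

import time

# Warm-up: trigger compilation before timing (assumes `a` and `b` from above)
nki_tensor_add(a, b).block_until_ready()

start = time.perf_counter()
for _ in range(100):
    out = nki_tensor_add(a, b)
out.block_until_ready()
elapsed = time.perf_counter() - start
print(f"average latency over 100 iterations: {elapsed / 100 * 1e6:.1f} us")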
Download All Source Code#
Click the links to download the source code of the kernels and the testing code discussed in this tutorial.
- NKI baremetal implementation: tensor_addition_nki_kernels.py
- PyTorch implementation: tensor_addition_torch.py
  You must also download tensor_addition_nki_kernels.py into the same folder to run this PyTorch script.
- JAX implementation: tensor_addition_jax.py
  You must also download tensor_addition_nki_kernels.py into the same folder to run this JAX script.
You can also view the source code in the GitHub repository nki_samples.
Example usage of the scripts:#
Run NKI baremetal implementation:
python3 tensor_addition_nki_kernels.py
Run PyTorch implementation:
python3 tensor_addition_torch.py
Run JAX implementation:
python3 tensor_addition_jax.py
This document is relevant for: Inf2, Trn1, Trn1n