This document is relevant for: Inf2, Trn1, Trn1n


In this tutorial, we examine a case of dimensionality reduction. We implement a 2D AveragePool operation, which is used in many vision neural networks. In doing so, we learn about:

  • The NKI syntax and programming model.

  • Multi-dimensional memory access patterns in NKI.

The 2D AveragePool operation takes C x [H,W] matrices and reduces each matrix along the H and W axes. To leverage free-dimension flexible indexing, we map the C (parallel) axis to the P (partition) dimension and the H and W (reduced) axes to the F (free) dimension. Each pooling window spans pool_size consecutive rows and pool_size consecutive columns, so we split H into (H/pool_size, pool_size) and W into (W/pool_size, pool_size). The result is a 4D memory access pattern in the F dimension, and we reduce along its two inner axes (the positions inside a window). The figure below illustrates the input and output tensor layouts.


Fig. 72 2D-Pooling Operation (reducing on axes F2 and F4)
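To make the 4D access pattern concrete, here is a plain-Python sketch (not NKI code) that computes the same average pool as an explicit loop nest. The loop variable names y_outer, y_inner, x_outer and x_inner follow the kernel's comments. Mapping them to F1 through F4 in that order is an assumption, chosen to match the reduced axes F2 and F4 in the figure caption.

```python
# Plain-Python sketch of 2D average pooling as a loop nest over the P axis (C)
# plus four F axes: F1 = y_outer, F2 = y_inner, F3 = x_outer, F4 = x_inner.
# F2 and F4 (positions inside a pooling window) are the reduced axes.
def avgpool2d_loops(x, pool):
    """x: nested lists of shape C x H x W; H and W assumed divisible by pool."""
    C, H, W = len(x), len(x[0]), len(x[0][0])
    out = []
    for c in range(C):                          # P axis: channels are independent
        plane = []
        for y_outer in range(H // pool):        # F1
            row = []
            for x_outer in range(W // pool):    # F3
                total = 0.0
                for y_inner in range(pool):     # F2 (reduced)
                    for x_inner in range(pool): # F4 (reduced)
                        total += x[c][pool * y_outer + y_inner][pool * x_outer + x_inner]
                row.append(total / (pool * pool))
            plane.append(row)
        out.append(plane)
    return out


# Usage: a C=2, H=W=6 input filled with 0, 1, 2, ... as in the launch scripts
C, H, W, POOL = 2, 6, 6, 2
x = [[[c * H * W + h * W + w for w in range(W)] for h in range(H)] for c in range(C)]
y = avgpool2d_loops(x, POOL)
print(y[0][0])  # [3.5, 5.5, 7.5]
```

In the kernel, the four nested loops become four broadcast index tensors, and the two innermost loops become the reduction axes passed to nl.sum.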


Compute kernel

import neuronxcc.nki.language as nl

def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size):
  """NKI kernel to compute a 2D avg-pool operation

  Args:
      in_tensor: an input tensor, of shape C x H x W
      pool_size: an integer representing a (square) pool-window size
      out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size)
  """

  # Get input/output dimensions
  sz_cin, sz_hin, sz_win = in_tensor.shape
  sz_cout, sz_hout, sz_wout = out_tensor.shape
  assert sz_cin == sz_cout

  # Set relevant sizes
  sz_p = sz_cin
  sz_pool = pool_size

  # Generate tensor h/w index patterns
  # 3D indexing according to [C, H, W]
  i_p = nl.arange(sz_p)[:, None, None] # partition (C) index
  i_win = nl.arange(sz_win)[None, None, :]
  i_hin = nl.arange(sz_hin)[None, :, None]

  i_wout = nl.arange(sz_wout)[None, None, :]
  i_hout = nl.arange(sz_hout)[None, :, None]

  # Generate pool index patterns (requires two extra dimensions, for the pool window)
  i_0 = nl.arange(sz_p)[:, None, None, None, None] # partition (C) index
  i_1 = nl.arange(sz_hin//sz_pool)[None, :, None, None, None] # y_outer
  i_2 = nl.arange(sz_pool)[None, None, :, None, None] # y_inner
  i_3 = nl.arange(sz_win//sz_pool)[None, None, None, :, None] # x_outer
  i_4 = nl.arange(sz_pool)[None, None, None, None, :] # x_inner

  # Load input data from external memory to on-chip memory
  # Declare ndarray to force a 3D tensor (temporary requirement)
  in_tile = nl.ndarray([sz_p, sz_hin, sz_win], dtype=in_tensor.dtype)
  in_tile[:,:,:] = nl.load(in_tensor[i_p, i_hin, i_win])

  # Perform the pooling operation:
  # We use NumPy-style advanced indexing to extend in_tile to 5D, sum over two dimensions,
  # and divide by the window size (pool_size*pool_size) to get the average.
  # axis[0] is the index for p_dim, and thus doesn't participate in the reduction operation.
  # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides
  # (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2].
  # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4].
  out_tile = nl.sum(in_tile[i_0, sz_pool*i_1+i_2, sz_pool*i_3+i_4], axis=[2,4]) / (pool_size*pool_size)

  # Store the results back to external memory
  nl.store(out_tensor[i_p, i_hout, i_wout], value=out_tile)
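The advanced-indexing expression passed to nl.sum is the least obvious step in the kernel. The following host-side sketch replays it with NumPy. It assumes, as the kernel's comments suggest, that nl.arange index tensors broadcast like NumPy arrays. It does not run on a NeuronCore and only shows the shapes involved.

```python
import numpy as np

# NumPy stand-in for the kernel's index tensors: np.arange(n)[...] plays the
# role of nl.arange(n)[...] (assumed to broadcast the same way).
C, H, W, pool = 2, 6, 6, 2
in_tile = np.arange(C * H * W, dtype=np.float32).reshape(C, H, W)

i_0 = np.arange(C)[:, None, None, None, None]            # partition (C) index
i_1 = np.arange(H // pool)[None, :, None, None, None]    # y_outer
i_2 = np.arange(pool)[None, None, :, None, None]         # y_inner
i_3 = np.arange(W // pool)[None, None, None, :, None]    # x_outer
i_4 = np.arange(pool)[None, None, None, None, :]         # x_inner

# The three index expressions broadcast to one 5D shape:
# (C, H/pool, pool, W/pool, pool)
windows = in_tile[i_0, pool * i_1 + i_2, pool * i_3 + i_4]
print(windows.shape)  # (2, 3, 2, 3, 2)

# Summing the two inner (within-window) axes and dividing gives the average.
out_tile = windows.sum(axis=(2, 4)) / (pool * pool)
print(out_tile.shape)  # (2, 3, 3)
print(out_tile[0, 0])  # [3.5 5.5 7.5]
```

Because i_0 varies only along the first axis, each partition (channel) reduces its own windows independently. This is why C can be mapped to the P dimension.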

Launching kernel and testing correctness (PyTorch)

To execute the kernel, we prepare an input tensor in_tensor and a zero-initialized output tensor out_nki on the XLA device, wrap tensor_avgpool_kernel_ with nki_jit, and call it:

import torch
from torch_neuronx import nki_jit
from torch_xla.core import xla_model as xm

if __name__ == "__main__":
  device = xm.xla_device()

  # Now let's run the kernel
  POOL_SIZE = 2
  C, HIN, WIN = 2, 6, 6
  HOUT, WOUT = HIN // POOL_SIZE, WIN // POOL_SIZE

  in_tensor = torch.arange(C * HIN * WIN, dtype=torch.bfloat16).reshape(C, HIN, WIN).to(device=device)
  out_nki = torch.zeros((C, HOUT, WOUT), dtype=torch.bfloat16).to(device=device)

  # tensor_avgpool_kernel_ is the NKI kernel defined above (same module)
  tensor_avgpool_kernel_torch = nki_jit(tensor_avgpool_kernel_)
  tensor_avgpool_kernel_torch(in_tensor, out_nki, POOL_SIZE)

  out_torch = torch.nn.functional.avg_pool2d(in_tensor, POOL_SIZE, POOL_SIZE)

  print(in_tensor, out_nki, out_torch) # an implicit XLA barrier/mark-step

  if (out_nki == out_torch).all():
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")
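The launch script computes HOUT and WOUT by floor division. The kernel's index patterns (sz_hin//sz_pool, sz_win//sz_pool) do the same, so trailing rows or columns that do not fill a whole window would be silently dropped. A small guard makes that assumption explicit before tensors are allocated. The helper avgpool_out_shape below is hypothetical and is not part of the tutorial or of any Neuron API.

```python
# Hypothetical helper (not part of the tutorial or any Neuron API): derives the
# output shape used to allocate out_nki and rejects inputs whose H or W is not
# a multiple of pool_size, instead of silently dropping trailing rows/columns.
def avgpool_out_shape(c, h_in, w_in, pool_size):
    if pool_size <= 0:
        raise ValueError("pool_size must be positive")
    if h_in % pool_size or w_in % pool_size:
        raise ValueError(
            f"H={h_in} and W={w_in} must be multiples of pool_size={pool_size}"
        )
    return c, h_in // pool_size, w_in // pool_size


C, HOUT, WOUT = avgpool_out_shape(2, 6, 6, 2)
print(C, HOUT, WOUT)  # 2 3 3
```

Raising an error rather than flooring is a design choice. torch.nn.functional.avg_pool2d with its default ceil_mode=False floors the output size, so the two shapes would still agree for non-divisible inputs.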


Compute kernel (JAX)

For JAX, we reuse the same NKI kernel, tensor_avgpool_kernel_, that was defined above for PyTorch.

We define tensor_avgpool_kernel as a caller to the NKI kernel. We create a partial function, partial(tensor_avgpool_kernel_, pool_size=pool_size), so that the plain Python integer pool_size reaches the kernel as a fixed argument while JAX passes only the array arguments.

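from functools import partial
import jax
from jax_neuronx import nki_call

def tensor_avgpool_kernel(in_array, pool_size):
  # Derive the output shape from the input rather than from globals
  c, h_in, w_in = in_array.shape
  return nki_call(
    # tensor_avgpool_kernel_ is the NKI kernel defined above
    partial(tensor_avgpool_kernel_, pool_size=pool_size),
    in_array,
    out_shape=jax.ShapeDtypeStruct((c, h_in // pool_size, w_in // pool_size), dtype=in_array.dtype),
  )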
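The role of partial can be seen without NKI or JAX. In this sketch, fake_kernel mirrors the signature of tensor_avgpool_kernel_. The function launch_with_tensors is a hypothetical stand-in for a launcher that supplies only tensor arguments positionally. Treating nki_call this way is an assumption based on how the caller above is written, not a description of its internals.

```python
from functools import partial

# Stand-in with the same signature as tensor_avgpool_kernel_; it only reports
# what it receives (no NKI or JAX involved).
def fake_kernel(in_tensor, out_tensor, pool_size):
    return f"in={in_tensor}, out={out_tensor}, pool_size={pool_size}"

# Hypothetical launcher that, like the caller above assumes of nki_call,
# passes only tensor arguments positionally.
def launch_with_tensors(kernel, *tensors):
    return kernel(*tensors)

# Bind the non-tensor argument up front; the launcher never sees it.
bound = partial(fake_kernel, pool_size=2)
print(launch_with_tensors(bound, "in_array", "out_buffer"))
# in=in_array, out=out_buffer, pool_size=2
```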

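We write a reference JAX implementation of AveragePool2D, because JAX does not provide a dedicated primitive for it. The reference reshapes each H x W matrix into pool_size x pool_size windows and averages over the two window axes.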

import jax.numpy as jnp

# Reference JAX implementation
def jax_average_pool_2D(in_tensor, pool_size):
  c, h_in, w_in = in_tensor.shape
  reshaped = in_tensor.reshape(c, h_in // pool_size, pool_size, w_in // pool_size, pool_size)
  return jnp.nanmean(reshaped, axis=(2, 4))
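The reference implementation relies on one fact: a row-major reshape of H into (H/pool_size, pool_size) and of W into (W/pool_size, pool_size) produces the same 5D layout that the kernel builds with its index patterns. The following check verifies this element by element for the tutorial's 2 x 6 x 6 input. It is a host-side sketch that uses NumPy instead of JAX only so that it runs without a JAX installation.

```python
import numpy as np

# Check that the reshape in the reference matches the kernel's indexing:
#   reshaped[c, i1, i2, i3, i4] == x[c, pool*i1 + i2, pool*i3 + i4]
C, H, W, pool = 2, 6, 6, 2
x = np.arange(C * H * W, dtype=np.float32).reshape(C, H, W)
reshaped = x.reshape(C, H // pool, pool, W // pool, pool)

ok = all(
    reshaped[c, i1, i2, i3, i4] == x[c, pool * i1 + i2, pool * i3 + i4]
    for c in range(C)
    for i1 in range(H // pool)
    for i2 in range(pool)
    for i3 in range(W // pool)
    for i4 in range(pool)
)
print(ok)  # True

# Averaging the two inner axes matches the kernel's sum / (pool * pool).
print(np.mean(reshaped, axis=(2, 4))[1, 2, 2])  # 67.5
```

The test input contains no NaNs, so jnp.nanmean in the reference behaves like a plain mean here.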

Launching kernel and testing correctness (JAX)

To execute the kernel, we prepare an input array in_array and invoke the kernel caller function tensor_avgpool_kernel:

if __name__ == "__main__":
  POOL_SIZE = 2
  C, HIN, WIN = 2, 6, 6
  HOUT, WOUT = HIN // POOL_SIZE, WIN // POOL_SIZE

  in_array = jnp.arange(C * HIN * WIN, dtype=jnp.float32).reshape(C, HIN, WIN)

  out_nki = tensor_avgpool_kernel(in_array, pool_size=POOL_SIZE)
  out_jax = jax_average_pool_2D(in_array, pool_size=POOL_SIZE)

  print(in_array, out_nki, out_jax)

  if jnp.allclose(out_nki, out_jax):
    print("NKI and JAX match")
  else:
    print("NKI and JAX differ")
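The JAX test compares results with jnp.allclose rather than exact equality. The PyTorch test uses ==, which passes only if the NKI and PyTorch results agree bit for bit. For each element, allclose checks |a - b| <= atol + rtol * |b|. The sketch below spells out that criterion in plain Python for flat sequences of equal length (no broadcasting), using the NumPy/JAX default tolerances.

```python
# Plain-Python sketch of the elementwise allclose criterion:
#   |a - b| <= atol + rtol * |b|
# Defaults match NumPy/JAX (rtol=1e-05, atol=1e-08); no broadcasting, and
# NaNs never compare close (like the default equal_nan=False).
def allclose(a, b, rtol=1e-05, atol=1e-08):
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


print(allclose([3.5, 5.5], [3.5, 5.5000001]))  # True
print(allclose([3.5], [3.515625]))             # False
```

Adjacent bfloat16 values near 3.5 are 1/64 = 0.015625 apart, so these default tolerances would reject even a one-step bfloat16 difference. A bfloat16 comparison generally needs a larger rtol.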

Download All Source Code

Click the links to download the source code of the kernels and the testing code discussed in this tutorial.

You can also view the source code in the GitHub repository nki_samples.

Example usage of the scripts:

Run NKI baremetal implementation:


Run PyTorch implementation:


Run JAX implementation:

