This document is relevant for: Inf2, Trn1, Trn2

AveragePool2D

In this tutorial, we examine a case of dimensionality reduction. We implement a 2D AveragePool operation, which is used in many vision neural networks. In doing so, we learn about:

  • NKI syntax and programming model.

  • Multi-dimensional memory access patterns in NKI.

The 2D AveragePool operation takes C x [H,W] matrices and reduces each matrix along the H and W axes. To leverage free-dimension flexible indexing, we can map the C (parallel) axis to the P dimension and the H/W (contraction) axes to the F dimension. Performing such a 2D pooling operation requires a 4D memory access pattern in the F dimension, with reduction along two axes. The figure below illustrates the input and output tensor layouts; a plain-NumPy sketch of the same access pattern follows the figure.

Fig. 76 2D-Pooling Operation (reducing on axes F2 and F4)
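To build intuition for this access pattern before writing the kernel, here is a minimal plain-NumPy sketch (an illustration only, not NKI code; all names are local to the snippet). It constructs the five index grids, uses advanced indexing to view the input as 5D, and averages over the two pool-window axes; the reshape-based variant at the end confirms the two forms agree.

import numpy as np

C, H, W, P = 2, 6, 6, 2   # example sizes; P is the pool-window size
x = np.arange(C * H * W, dtype=np.float32).reshape(C, H, W)

# Five index grids: channel, outer/inner row index, outer/inner column index
i0, i1, i2, i3, i4 = np.ogrid[0:C, 0:H//P, 0:P, 0:W//P, 0:P]

# Advanced indexing expands x to 5D; averaging over axes 2 and 4 pools each PxP window
pooled = x[i0, P*i1 + i2, P*i3 + i4].mean(axis=(2, 4))

# The same result via a plain reshape, for comparison
assert np.allclose(pooled, x.reshape(C, H//P, P, W//P, P).mean(axis=(2, 4)))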

PyTorch

Compute kernel

import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
from neuronxcc.nki.typing import tensor

@nki.jit
def tensor_avgpool_kernel(in_tensor, pool_size):
  """NKI kernel to compute a 2D avg-pool operation

  Args:
      in_tensor: an input tensor, of shape C x H x W
      pool_size: an integer representing a (square) pool-window size

  Return:
      out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size)
  """

  # Get input/output dimensions
  sz_cin, sz_hin, sz_win = in_tensor.shape
  sz_hout = sz_hin // pool_size
  sz_wout = sz_win // pool_size
  # Create output tensor shared between all SPMD instances as result tensor
  out_tensor = nl.ndarray((sz_cin, sz_hout, sz_wout), dtype=in_tensor.dtype,
                          buffer=nl.shared_hbm)

  # Set relevant sizes
  sz_p = sz_cin
  sz_pool = pool_size

  # Generate pool index patterns (requires two extra dimensions, for the pool window)
  i0, i1, i2, i3, i4 = nl.mgrid[0:sz_p, 0:sz_hin//sz_pool, 0:sz_pool, 0:sz_win//sz_pool, 0:sz_pool]

  # Load input data from external memory to on-chip memory
  in_tile: tensor[sz_p, sz_hin, sz_win] = nl.load(in_tensor)

  # Perform the pooling operation:
  # We use NumPy-style advanced indexing to extend in_tile to 5D, and then reduce-average over two dimensions.
  # axis[0] is the index for p_dim, and thus doesn't participate in the reduction operation.
  # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides
  # (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2].
  # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4].
  out_tile : tensor[sz_p, sz_hout, sz_wout] = nl.sum(in_tile[i0, sz_pool*i1+i2, sz_pool*i3+i4],
                                                     axis=[2,4]) / (pool_size*pool_size)

  # Store the results back to HBM
  nl.store(out_tensor, value=out_tile)

  # Transfer the ownership of `out_tensor` to the caller
  return out_tensor

Launching kernel and testing correctness

To execute the kernel, we prepare the input tensor in_tensor and call tensor_avgpool_kernel:

import torch
from torch_xla.core import xla_model as xm

if __name__ == "__main__":
  device = xm.xla_device()

  # Now let's run the kernel
  POOL_SIZE = 2
  C, HIN, WIN = 2, 6, 6
  HOUT, WOUT = HIN//POOL_SIZE, WIN//POOL_SIZE

  in_tensor = torch.arange(C * HIN * WIN, dtype=torch.bfloat16).reshape(C, HIN, WIN).to(device=device)
  out_nki = torch.zeros((C, HOUT, WOUT), dtype=torch.bfloat16).to(device=device)

  out_nki = tensor_avgpool_kernel(in_tensor, POOL_SIZE)

  out_torch = torch.nn.functional.avg_pool2d(in_tensor, POOL_SIZE, POOL_SIZE)

  print(in_tensor, out_nki, out_torch) # an implicit XLA barrier/mark-step

  if (out_nki == out_torch).all():
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")
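As a quick sanity check on the arithmetic (not part of the tutorial scripts): with the arange input above, channel 0's top-left 2 x 2 pooling window holds the values 0, 1, 6, and 7, so the first output element should be their average, 3.5.

# Hand check of the first pooled element (channel 0, window [0:2, 0:2])
POOL_SIZE = 2
window = [0, 1, 6, 7]
print(sum(window) / (POOL_SIZE * POOL_SIZE))   # 3.5 -> expected value of out_nki[0, 0, 0]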

JAX

Compute kernel

Let’s reuse the same NKI kernel implementation defined for PyTorch above:

import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
from neuronxcc.nki.typing import tensor

@nki.jit
def tensor_avgpool_kernel(in_tensor, pool_size):
  """NKI kernel to compute a 2D avg-pool operation

  Args:
      in_tensor: an input tensor, of shape C x H x W
      pool_size: an integer representing a (square) pool-window size

  Return:
      out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size)
  """

  # Get input/output dimensions
  sz_cin, sz_hin, sz_win = in_tensor.shape
  sz_hout = sz_hin // pool_size
  sz_wout = sz_win // pool_size
  # Create output tensor shared between all SPMD instances as result tensor
  out_tensor = nl.ndarray((sz_cin, sz_hout, sz_wout), dtype=in_tensor.dtype,
                          buffer=nl.shared_hbm)

  # Set relevant sizes
  sz_p = sz_cin
  sz_pool = pool_size

  # Generate pool index patterns (requires two extra dimensions, for the pool window)
  i0, i1, i2, i3, i4 = nl.mgrid[0:sz_p, 0:sz_hin//sz_pool, 0:sz_pool, 0:sz_win//sz_pool, 0:sz_pool]

  # Load input data from external memory to on-chip memory
  in_tile: tensor[sz_p, sz_hin, sz_win] = nl.load(in_tensor)

  # Perform the pooling operation:
  # We use NumPy-style advanced indexing to extend in_tile to 5D, and then reduce-average over two dimensions.
  # axis[0] is the index for p_dim, and thus doesn't participate in the reduction operation.
  # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides
  # (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2].
  # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4].
  out_tile : tensor[sz_p, sz_hout, sz_wout] = nl.sum(in_tile[i0, sz_pool*i1+i2, sz_pool*i3+i4],
                                                     axis=[2,4]) / (pool_size*pool_size)

  # Store the results back to HBM
  nl.store(out_tensor, value=out_tile)

  # Transfer the ownership of `out_tensor` to the caller
  return out_tensor

To make pool_size a compile-time constant, we pass it as a keyword argument:

out_nki = tensor_avgpool_kernel(in_array, pool_size=POOL_SIZE)

Since JAX does not provide an AveragePool2D primitive, we write a reference implementation:

import jax.numpy as jnp

# Reference JAX implementation
def jax_average_pool_2D(in_tensor, pool_size):
  c, h_in, w_in = in_tensor.shape
  reshaped = in_tensor.reshape(c, h_in // pool_size, pool_size, w_in // pool_size, pool_size)
  return jnp.nanmean(reshaped, axis=(2, 4))
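As a quick usage example of the reference function (not part of the tutorial scripts), pooling a tiny 1 x 4 x 4 arange array yields the expected per-window averages:

x = jnp.arange(16, dtype=jnp.float32).reshape(1, 4, 4)
print(jax_average_pool_2D(x, pool_size=2))
# [[[ 2.5  4.5]
#   [10.5 12.5]]]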

Launching kernel and testing correctness

To execute the kernel, we prepare the input array in_array and invoke tensor_avgpool_kernel:

if __name__ == "__main__":
  POOL_SIZE = 2
  C, HIN, WIN = 2, 6, 6
  HOUT, WOUT = HIN//POOL_SIZE, WIN//POOL_SIZE

  in_array = jnp.arange(C * HIN * WIN, dtype=jnp.float32).reshape(C, HIN, WIN)

  out_nki = tensor_avgpool_kernel(in_array, pool_size=POOL_SIZE)
  out_jax = jax_average_pool_2D(in_array, pool_size=POOL_SIZE)

  print(in_array, out_nki, out_jax)

  if jnp.allclose(out_nki, out_jax):
    print("NKI and JAX match")
  else:
    print("NKI and JAX differ")

Download All Source Code

Click the links to download the source code of the kernels and the testing code discussed in this tutorial.

You can also view the source code in the nki_samples GitHub repository.

Example usage of the scripts:

Run the NKI baremetal implementation:

python3 average_pool2d_nki_kernels.py

Run the PyTorch implementation:

python3 average_pool2d_torch.py

Run the JAX implementation:

python3 average_pool2d_jax.py
