This document is relevant for: Inf2, Trn1, Trn2
AveragePool2D
In this tutorial, we examine a case of dimensionality reduction by implementing a 2D AveragePool operation, which is used in many vision neural networks. In doing so, we learn about:
- The NKI syntax and programming model.
- Multi-dimensional memory access patterns in NKI.
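The 2D AveragePool operation takes C x [H,W] matrices and reduces each matrix along the H and W axes. Every non-overlapping pool_size x pool_size window is replaced by the mean of its elements, which gives a C x [H/pool_size, W/pool_size] output. To leverage free-dimension flexible indexing, we map the C (parallel) axis to the P (partition) dimension and the H and W (contraction) axes to the F (free) dimension.

Performing such a 2D pooling operation requires a 4D memory access pattern in the F dimension. This is because each of H and W is split into an outer index (which window) and an inner index (position within the window). The reduction runs along the two inner axes.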
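The index arithmetic behind this 4D access pattern can be checked on the host with plain NumPy. The sketch below is illustrative only. It uses `np.mgrid` as a stand-in for `nl.mgrid`, assuming both produce the same index grids, and uses toy sizes chosen for readability. It shows which input row and column each `(i1, i2, i3, i4)` index addresses. It also confirms that the windows tile each channel's plane exactly once.

```python
import numpy as np

# Toy sizes: C channels on the partition axis, an H x W plane in the free
# dimension, and a square pool window of size P
C, H, W, P = 2, 4, 6, 2

# Same five index arrays as the kernel: i0 walks channels (P dimension);
# (i1, i2) split rows into window / in-window, (i3, i4) do the same for columns
i0, i1, i2, i3, i4 = np.mgrid[0:C, 0:H // P, 0:P, 0:W // P, 0:P]

# Input row and column addressed by each 5D index
rows = P * i1 + i2
cols = P * i3 + i4

# Window feeding output element (c=0, h=1, w=2): fix i1=1, i3=2, vary i2 and i4
window = sorted(
    (int(r), int(c))
    for r, c in zip(rows[0, 1, :, 2, :].ravel(), cols[0, 1, :, 2, :].ravel())
)
print(window)  # [(2, 4), (2, 5), (3, 4), (3, 5)]

# Each (row, col) of one channel's plane is addressed exactly once, so the
# windows tile the plane without overlap when H and W are divisible by P
flat = (rows[0] * W + cols[0]).ravel()
counts = np.bincount(flat, minlength=H * W)
print(counts.min(), counts.max())  # 1 1
```

Splitting rows as `P * i1 + i2` places the window index and the in-window offset on separate axes. That layout is what lets a single sum over the two inner axes perform the pooling.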
The figure below illustrates the input and output tensor layouts.
PyTorch
Compute kernel
```python
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
from neuronxcc.nki.typing import tensor

@nki.jit
def tensor_avgpool_kernel(in_tensor, pool_size):
    """NKI kernel to compute a 2D avg-pool operation

    Args:
        in_tensor: an input tensor, of shape C x H x W
        pool_size: an integer representing a (square) pool-window size

    Return:
        out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size)
    """

    # Get input/output dimensions
    sz_cin, sz_hin, sz_win = in_tensor.shape
    sz_hout = sz_hin // pool_size
    sz_wout = sz_win // pool_size
    # Create output tensor shared between all SPMD instances as result tensor
    out_tensor = nl.ndarray((sz_cin, sz_hout, sz_wout), dtype=in_tensor.dtype,
                            buffer=nl.shared_hbm)

    # Set relevant sizes
    sz_p = sz_cin
    sz_pool = pool_size

    # Generate pool index patterns (requires two extra dimensions, for the pool window)
    i0, i1, i2, i3, i4 = nl.mgrid[0:sz_p, 0:sz_hin//sz_pool, 0:sz_pool, 0:sz_win//sz_pool, 0:sz_pool]

    # Load input data from external memory to on-chip memory
    in_tile: tensor[sz_p, sz_hin, sz_win] = nl.load(in_tensor)

    # Perform the pooling operation:
    # We use numpy's advanced indexing to extend in_tile to 5D, and then reduce-average two dimensions.
    # axis[0] is the index for p_dim, and thus doesn't participate in the reduction operation.
    # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides
    # (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2].
    # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4].
    out_tile : tensor[sz_p, sz_hout, sz_wout] = nl.sum(in_tile[i0, sz_pool*i1+i2, sz_pool*i3+i4],
                                                       axis=[2,4]) / (pool_size*pool_size)

    # Store the results back to hbm
    nl.store(out_tensor, value=out_tile)

    # Transfer the ownership of `out_tensor` to the caller
    return out_tensor
```
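The kernel's index logic can be mirrored in NumPy and checked without a Neuron device. `avgpool_reference` below is a hypothetical helper, not part of NKI. It drops `nl.load`/`nl.store` because there is no on-chip memory to manage on the host, and it uses `np.mgrid` in place of `nl.mgrid`.

The sketch also makes one behavior of the kernel explicit. Because `sz_hout` and `sz_wout` use floor division, trailing rows and columns that do not fill a whole window are never read. This matches the default `ceil_mode=False` behavior of `torch.nn.functional.avg_pool2d`.

```python
import numpy as np

def avgpool_reference(in_tensor, pool_size):
    """Host-side NumPy mirror of tensor_avgpool_kernel (illustration only)."""
    sz_p, sz_hin, sz_win = in_tensor.shape
    sz_pool = pool_size
    # Same 5D index pattern as the kernel, with np.mgrid standing in for nl.mgrid
    i0, i1, i2, i3, i4 = np.mgrid[0:sz_p, 0:sz_hin // sz_pool, 0:sz_pool,
                                  0:sz_win // sz_pool, 0:sz_pool]
    # Advanced indexing gathers a C x Hout x pool x Wout x pool array
    windows = in_tensor[i0, sz_pool * i1 + i2, sz_pool * i3 + i4]
    # Sum over the two in-window axes, then divide by the window area
    return windows.sum(axis=(2, 4)) / (pool_size * pool_size)

# 5 x 5 input with pool_size 2: floor division gives a 2 x 2 output
x = np.arange(25, dtype=np.float32).reshape(1, 5, 5)
out = avgpool_reference(x, 2)
# Expected window means: 3 and 5 in the top row, 13 and 15 in the bottom row
print(out[0])

# The last row and column are never read, so changing them has no effect
x_edge = x.copy()
x_edge[:, 4, :] = 1000.0
x_edge[:, :, 4] = 1000.0
print(np.array_equal(avgpool_reference(x_edge, 2), out))  # True
```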
Launching kernel and testing correctness
To execute the kernel, we prepare the tensor in_tensor and call tensor_avgpool_kernel:
```python
import torch
from torch_xla.core import xla_model as xm

# tensor_avgpool_kernel comes from average_pool2d_nki_kernels.py (placed in the same folder)
from average_pool2d_nki_kernels import tensor_avgpool_kernel

if __name__ == "__main__":
    device = xm.xla_device()

    # Now let's run the kernel
    POOL_SIZE = 2
    C, HIN, WIN = 2, 6, 6
    HOUT, WOUT = HIN//POOL_SIZE, WIN//POOL_SIZE

    in_tensor = torch.arange(C * HIN * WIN, dtype=torch.bfloat16).reshape(C, HIN, WIN).to(device=device)
    out_nki = torch.zeros((C, HOUT, WOUT), dtype=torch.bfloat16).to(device=device)

    out_nki = tensor_avgpool_kernel(in_tensor, POOL_SIZE)

    # Reference result; avg_pool2d accepts an unbatched C x H x W input
    out_torch = torch.nn.functional.avg_pool2d(in_tensor, POOL_SIZE, POOL_SIZE)

    print(in_tensor, out_nki, out_torch) # an implicit XLA barrier/mark-step

    if (out_nki == out_torch).all():
        print("NKI and Torch match")
    else:
        print("NKI and Torch differ")
```
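The test above compares results with exact equality (`==`) rather than a tolerance. For this particular input, that is reasonable. The arange values 0 to 71 and every window mean (a multiple of 0.5 below 128) are exactly representable in bfloat16.

The NumPy sketch below runs without torch or a Neuron device. It works out the expected output in closed form: for window `(c, h, w)` the mean is `c*HIN*WIN + POOL_SIZE*h*WIN + POOL_SIZE*w + (POOL_SIZE-1)*(WIN+1)/2`. It then checks representability by testing whether the lower 16 bits of each float32 are zero. Core NumPy has no bfloat16 dtype, so this bit test is an emulation, and `exactly_bf16` is a hypothetical helper. The argument assumes the kernel rounds each result correctly. For general inputs, a tolerance check such as `torch.allclose` is the safer choice.

```python
import numpy as np

POOL_SIZE = 2
C, HIN, WIN = 2, 6, 6
HOUT, WOUT = HIN // POOL_SIZE, WIN // POOL_SIZE

# Same input as the PyTorch test: 0 .. C*HIN*WIN-1 laid out as C x HIN x WIN
x = np.arange(C * HIN * WIN, dtype=np.float32).reshape(C, HIN, WIN)

# Window means via reshape (kernel size == stride == POOL_SIZE, no padding)
pooled = x.reshape(C, HOUT, POOL_SIZE, WOUT, POOL_SIZE).mean(axis=(2, 4))

# Closed form for an arange input: the window's top-left value plus the mean
# in-window offset, which is (POOL_SIZE - 1) / 2 * (WIN + 1)
c, h, w = np.meshgrid(np.arange(C), np.arange(HOUT), np.arange(WOUT), indexing="ij")
top_left = c * HIN * WIN + POOL_SIZE * h * WIN + POOL_SIZE * w
closed_form = top_left + (POOL_SIZE - 1) / 2 * (WIN + 1)
assert np.array_equal(pooled, closed_form)

def exactly_bf16(a):
    # bfloat16 keeps the upper 16 bits of a float32, so a float32 value is
    # exactly representable in bfloat16 iff its lower 16 bits are all zero
    bits = np.asarray(a, dtype=np.float32).view(np.uint32)
    return bool(np.all((bits & 0xFFFF) == 0))

print(pooled[0])
print(exactly_bf16(x), exactly_bf16(pooled))  # True True
```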
JAX
Compute kernel
We reuse the same NKI kernel implementation, tensor_avgpool_kernel, defined in the PyTorch section above.
To pass pool_size as a compile-time constant, we pass it as a keyword argument:
```python
out_nki = tensor_avgpool_kernel(in_array, pool_size=POOL_SIZE)
```
Since JAX does not provide a dedicated AveragePool2D primitive, we write a reference JAX implementation:
```python
import jax.numpy as jnp

# Reference JAX implementation
def jax_average_pool_2D(in_tensor, pool_size):
    c, h_in, w_in = in_tensor.shape
    reshaped = in_tensor.reshape(c, h_in // pool_size, pool_size, w_in // pool_size, pool_size)
    return jnp.nanmean(reshaped, axis=(2, 4))
```
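The reference uses `jnp.nanmean`, which ignores NaN entries, whereas the NKI kernel divides a plain sum by the window area. On NaN-free inputs such as the arange test data, the two agree. They can diverge if an input window contains NaN, assuming the kernel's sum propagates NaN the way IEEE arithmetic usually does.

The sketch below uses NumPy's `np.mean` and `np.nanmean`, which have the same NaN semantics as `jnp.mean` and `jnp.nanmean`, so it runs without JAX installed. `pool_mean` is a hypothetical helper for illustration.

```python
import numpy as np

def pool_mean(x, pool_size, reducer):
    # Same reshape as jax_average_pool_2D; reducer is np.mean or np.nanmean
    c, h_in, w_in = x.shape
    reshaped = x.reshape(c, h_in // pool_size, pool_size, w_in // pool_size, pool_size)
    return reducer(reshaped, axis=(2, 4))

x = np.arange(16, dtype=np.float32).reshape(1, 4, 4)

# NaN-free input: both reducers give identical results
print(np.array_equal(pool_mean(x, 2, np.mean), pool_mean(x, 2, np.nanmean)))  # True

# Put a NaN in the top-left window, whose other values are 0, 4 and 5
x_nan = x.copy()
x_nan[0, 0, 1] = np.nan
print(pool_mean(x_nan, 2, np.mean)[0, 0, 0])     # nan
print(pool_mean(x_nan, 2, np.nanmean)[0, 0, 0])  # 3.0
```

If the reference should surface NaNs rather than hide them, swapping `jnp.nanmean` for `jnp.mean` is one option.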
Launching kernel and testing correctness
To execute the kernel, we prepare the array in_array and invoke the kernel caller function tensor_avgpool_kernel:
```python
import jax.numpy as jnp

# tensor_avgpool_kernel comes from average_pool2d_nki_kernels.py (placed in the same folder);
# jax_average_pool_2D is the reference implementation defined above
from average_pool2d_nki_kernels import tensor_avgpool_kernel

if __name__ == "__main__":
    POOL_SIZE = 2
    C, HIN, WIN = 2, 6, 6
    HOUT, WOUT = HIN//POOL_SIZE, WIN//POOL_SIZE

    in_array = jnp.arange(C * HIN * WIN, dtype=jnp.float32).reshape(C, HIN, WIN)

    out_nki = tensor_avgpool_kernel(in_array, pool_size=POOL_SIZE)
    out_jax = jax_average_pool_2D(in_array, pool_size=POOL_SIZE)

    print(in_array, out_nki, out_jax)

    if jnp.allclose(out_nki, out_jax):
        print("NKI and JAX match")
    else:
        print("NKI and JAX differ")
```
Download All Source Code
Download the source code of the kernels and the testing code discussed in this tutorial:
- NKI baremetal implementation: average_pool2d_nki_kernels.py
- PyTorch implementation: average_pool2d_torch.py. You must also download average_pool2d_nki_kernels.py into the same folder to run this PyTorch script.
- JAX implementation: average_pool2d_jax.py. You must also download average_pool2d_nki_kernels.py into the same folder to run this JAX script.

You can also view the source code in the GitHub repository nki_samples.
Example usage of the scripts:
```shell
# Run NKI baremetal implementation
python3 average_pool2d_nki_kernels.py
# Run PyTorch implementation
python3 average_pool2d_torch.py
# Run JAX implementation
python3 average_pool2d_jax.py
```