This document is relevant for: Inf2, Trn1, Trn1n
AveragePool2D#
In this tutorial, we examine a case of dimensionality reduction. We implement a 2D AveragePool operation, which is used in many vision neural networks. In doing so, we learn about:
- NKI syntax and programming model.
- Multi-dimensional memory access patterns in NKI.
The 2D AveragePool operation takes C x [H,W] matrices and reduces each matrix along the H and W axes. To leverage free-dimension flexible indexing, we can map the C (parallel) axis to the P dimension and the H/W (contraction) axes to the F dimension. Performing such a 2D pooling operation requires a 4D memory access pattern in the F dimension, with reduction along two axes. The figure below illustrates the input and output tensor layouts.
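Before diving into the NKI kernel, it may help to see the same computation in plain NumPy. This is a minimal sketch (not NKI code): the reshape splits each of H and W into an (outer, inner) index pair, and averaging over the two inner axes is exactly the two-axis reduction the kernel performs below.

import numpy as np

def avgpool2d_reference(x, pool_size):
  # x has shape [C, H, W]; H and W are assumed divisible by pool_size
  c, h, w = x.shape
  # Split H and W into (outer, inner) pairs: [C, H/p, p, W/p, p]
  windows = x.reshape(c, h // pool_size, pool_size, w // pool_size, pool_size)
  # Average over the two inner (in-window) axes -> [C, H/p, W/p]
  return windows.mean(axis=(2, 4))

x = np.arange(2 * 6 * 6, dtype=np.float32).reshape(2, 6, 6)
print(avgpool2d_reference(x, 2).shape)  # (2, 3, 3)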
PyTorch#
Compute kernel#
import neuronxcc.nki.language as nl

def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size):
  """NKI kernel to compute a 2D avg-pool operation

  Args:
      in_tensor: an input tensor, of shape C x H x W
      pool_size: an integer representing a (square) pool-window size
      out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size)
  """

  # Get input/output dimensions
  sz_cin, sz_hin, sz_win = in_tensor.shape
  sz_cout, sz_hout, sz_wout = out_tensor.shape
  assert sz_cin == sz_cout

  # Set relevant sizes
  sz_p = sz_cin
  sz_pool = pool_size

  # Generate tensor h/w index patterns
  # 3D indexing according to [C, H, W]
  i_p = nl.arange(sz_p)[:, None, None]  # 3D index for p_dim (C axis)
  i_win = nl.arange(sz_win)[None, None, :]
  i_hin = nl.arange(sz_hin)[None, :, None]

  i_wout = nl.arange(sz_wout)[None, None, :]
  i_hout = nl.arange(sz_hout)[None, :, None]

  # Generate pool index patterns (requires two extra dimensions, for the pool window)
  i_0 = nl.arange(sz_p)[:, None, None, None, None]  # p_dim (C axis)
  i_1 = nl.arange(sz_hin//sz_pool)[None, :, None, None, None]  # y_outer
  i_2 = nl.arange(sz_pool)[None, None, :, None, None]  # y_inner
  i_3 = nl.arange(sz_win//sz_pool)[None, None, None, :, None]  # x_outer
  i_4 = nl.arange(sz_pool)[None, None, None, None, :]  # x_inner

  # Load input data from external memory to on-chip memory
  # Declare ndarray to force a 3D tensor (temporary requirement)
  in_tile = nl.ndarray([sz_p, sz_hin, sz_win], dtype=in_tensor.dtype)
  in_tile[:,:,:] = nl.load(in_tensor[i_p, i_hin, i_win])

  # Perform the pooling operation:
  # We use numpy-style advanced indexing to extend in_tile to 5D, and then reduce-average over two dimensions.
  # axis[0] is the index for p_dim, and thus doesn't participate in the reduction operation.
  # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides
  # (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2].
  # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4].
  out_tile = nl.sum(in_tile[i_0, sz_pool*i_1+i_2, sz_pool*i_3+i_4], axis=[2,4]) / (pool_size*pool_size)

  # Store the results back to external memory
  nl.store(out_tensor[i_p, i_hout, i_wout], value=out_tile)
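As a sanity check of the pool index arithmetic, the following standalone NumPy snippet (illustrative only, not NKI code) shows how broadcasting sz_pool*i_1 + i_2 enumerates every input row exactly once, grouped by pooling window:

import numpy as np

sz_pool, sz_hin = 2, 6
i_1 = np.arange(sz_hin // sz_pool)[:, None]  # y_outer: which pooling window
i_2 = np.arange(sz_pool)[None, :]            # y_inner: offset inside the window
print(sz_pool * i_1 + i_2)
# [[0 1]
#  [2 3]
#  [4 5]]  -> rows {0,1}, {2,3} and {4,5} each form one pooling window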
Launching kernel and testing correctness#
To execute the kernel, we prepare an input tensor in_tensor and an empty output tensor out_nki, and call tensor_avgpool_kernel_:
import torch
from torch_neuronx import nki_jit
from torch_xla.core import xla_model as xm

if __name__ == "__main__":
  device = xm.xla_device()

  # Now let's run the kernel
  POOL_SIZE = 2
  C, HIN, WIN = 2, 6, 6
  HOUT, WOUT = HIN//POOL_SIZE, WIN//POOL_SIZE

  in_tensor = torch.arange(C * HIN * WIN, dtype=torch.bfloat16).reshape(C, HIN, WIN).to(device=device)
  out_nki = torch.zeros((C, HOUT, WOUT), dtype=torch.bfloat16).to(device=device)

  tensor_avgpool_kernel_torch = nki_jit(tensor_avgpool_kernel_)
  tensor_avgpool_kernel_torch(in_tensor, out_nki, POOL_SIZE)

  out_torch = torch.nn.functional.avg_pool2d(in_tensor, POOL_SIZE, POOL_SIZE)

  print(in_tensor, out_nki, out_torch)  # an implicit XLA barrier/mark-step

  if (out_nki == out_torch).all():
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")
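The exact-equality check above holds for this small example, but bfloat16 carries limited precision, and the NKI kernel and avg_pool2d may round intermediate sums differently on larger inputs. If you hit spurious mismatches, a tolerance-based comparison is a reasonable fallback (a sketch; the tolerances here are arbitrary):

# Hypothetical fallback: compare in float32 with loose tolerances
if torch.allclose(out_nki.float(), out_torch.float(), rtol=1e-2, atol=1e-2):
  print("NKI and Torch match (within tolerance)")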
JAX#
Compute kernel#
Let’s reuse the same NKI kernel implementation defined for PyTorch above:
import neuronxcc.nki.language as nl

def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size):
  """NKI kernel to compute a 2D avg-pool operation

  Args:
      in_tensor: an input tensor, of shape C x H x W
      pool_size: an integer representing a (square) pool-window size
      out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size)
  """

  # Get input/output dimensions
  sz_cin, sz_hin, sz_win = in_tensor.shape
  sz_cout, sz_hout, sz_wout = out_tensor.shape
  assert sz_cin == sz_cout

  # Set relevant sizes
  sz_p = sz_cin
  sz_pool = pool_size

  # Generate tensor h/w index patterns
  # 3D indexing according to [C, H, W]
  i_p = nl.arange(sz_p)[:, None, None]  # 3D index for p_dim (C axis)
  i_win = nl.arange(sz_win)[None, None, :]
  i_hin = nl.arange(sz_hin)[None, :, None]

  i_wout = nl.arange(sz_wout)[None, None, :]
  i_hout = nl.arange(sz_hout)[None, :, None]

  # Generate pool index patterns (requires two extra dimensions, for the pool window)
  i_0 = nl.arange(sz_p)[:, None, None, None, None]  # p_dim (C axis)
  i_1 = nl.arange(sz_hin//sz_pool)[None, :, None, None, None]  # y_outer
  i_2 = nl.arange(sz_pool)[None, None, :, None, None]  # y_inner
  i_3 = nl.arange(sz_win//sz_pool)[None, None, None, :, None]  # x_outer
  i_4 = nl.arange(sz_pool)[None, None, None, None, :]  # x_inner

  # Load input data from external memory to on-chip memory
  # Declare ndarray to force a 3D tensor (temporary requirement)
  in_tile = nl.ndarray([sz_p, sz_hin, sz_win], dtype=in_tensor.dtype)
  in_tile[:,:,:] = nl.load(in_tensor[i_p, i_hin, i_win])

  # Perform the pooling operation:
  # We use numpy-style advanced indexing to extend in_tile to 5D, and then reduce-average over two dimensions.
  # axis[0] is the index for p_dim, and thus doesn't participate in the reduction operation.
  # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides
  # (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2].
  # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4].
  out_tile = nl.sum(in_tile[i_0, sz_pool*i_1+i_2, sz_pool*i_3+i_4], axis=[2,4]) / (pool_size*pool_size)

  # Store the results back to external memory
  nl.store(out_tensor[i_p, i_hout, i_wout], value=out_tile)
We define tensor_avgpool_kernel as a caller to the NKI kernel. We create a partial function partial(tensor_avgpool_kernel_, pool_size=pool_size) so that JAX can pass pool_size, a Python object, to the kernel function.
from functools import partial
import jax
from jax_neuronx import nki_call  # nki_call is provided by the jax-neuronx package

def tensor_avgpool_kernel(in_array, pool_size):
  return nki_call(
    partial(tensor_avgpool_kernel_, pool_size=pool_size),
    in_array,
    out_shape=jax.ShapeDtypeStruct((C, HOUT, WOUT), dtype=in_array.dtype),
  )
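The out_shape argument (a jax.ShapeDtypeStruct) declares the kernel output's shape and dtype up front, since JAX traces the call abstractly and cannot infer them by running the kernel. Note that this snippet reads the C, HOUT and WOUT constants defined in the test driver below.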
We write a reference JAX implementation of AveragePool2D, as JAX does not have a primitive for it.
import jax.numpy as jnp

# Reference JAX implementation
def jax_average_pool_2D(in_tensor, pool_size):
  c, h_in, w_in = in_tensor.shape
  reshaped = in_tensor.reshape(c, h_in // pool_size, pool_size, w_in // pool_size, pool_size)
  return jnp.nanmean(reshaped, axis=(2, 4))
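Like the kernel, this reference assumes h_in and w_in are divisible by pool_size, so the reshape is exact. Since the input contains no NaNs, jnp.nanmean behaves identically to jnp.mean here.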
Launching kernel and testing correctness#
To execute the kernel, we prepare the array in_array and invoke the kernel caller function tensor_avgpool_kernel:
if __name__ == "__main__":
  POOL_SIZE = 2
  C, HIN, WIN = 2, 6, 6
  HOUT, WOUT = HIN//POOL_SIZE, WIN//POOL_SIZE

  in_array = jnp.arange(C * HIN * WIN, dtype=jnp.float32).reshape(C, HIN, WIN)

  out_nki = tensor_avgpool_kernel(in_array, pool_size=POOL_SIZE)
  out_jax = jax_average_pool_2D(in_array, pool_size=POOL_SIZE)

  print(in_array, out_nki, out_jax)

  if jnp.allclose(out_nki, out_jax):
    print("NKI and JAX match")
  else:
    print("NKI and JAX differ")
Download All Source Code#
Click the links to download the source code of the kernels and the testing code discussed in this tutorial.
- NKI baremetal implementation: average_pool2d_nki_kernels.py
- PyTorch implementation: average_pool2d_torch.py (you must also download average_pool2d_nki_kernels.py into the same folder to run this PyTorch script)
- JAX implementation: average_pool2d_jax.py (you must also download average_pool2d_nki_kernels.py into the same folder to run this JAX script)
You can also view the source code in the GitHub repository nki_samples.
Example usage of the scripts:#
Run NKI baremetal implementation:
python3 average_pool2d_nki_kernels.py
Run PyTorch implementation:
python3 average_pool2d_torch.py
Run JAX implementation:
python3 average_pool2d_jax.py
This document is relevant for: Inf2, Trn1, Trn1n