This document is relevant for: Inf2, Trn1, Trn1n
Transpose2D#
In this tutorial, we transpose a tensor along two of its axes using NKI. In doing so, we learn about:
The NKI syntax and programming model.
Multi-dimensional memory address patterns in NKI.
As background, there are two main types of transposition in NKI:

- Transposition between the partition-dimension axis and one of the free-dimension axes, which is achieved via the nki.isa.nc_transpose instruction (a minimal sketch follows this list).
- Transposition between two axes on the free dimension, which is achieved via a nki.language.copy instruction, with indexing manipulation in the free axis to re-arrange the data.
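For context only, a minimal sketch of the first case might look like the kernel below. It is not part of this tutorial's kernel, the kernel name is hypothetical, and it assumes the input tile fits the size limits of nki.isa.nc_transpose.

import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa


def pf_transpose_kernel_(in_tensor, out_tensor):
  # Load a [P, F] tile into SBUF, swap the partition and free axes with the
  # dedicated transpose instruction, then store the [F, P] result back out.
  in_tile = nl.load(in_tensor)
  out_tile = nisa.nc_transpose(in_tile)
  nl.store(out_tensor, value=out_tile)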
In this example, we’ll focus on the second case: consider a three-dimensional input tensor [P, F1, F2], where the P axis is mapped to the different SBUF partitions and the F1 and F2 axes are flattened and placed in each partition, with F1 being the major dimension. Our goal in this example is to transpose the F1 and F2 axes, with P as a parallel dimension, to re-arrange the data within each partition. The figure below illustrates the input and output tensor layouts.
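To make the target layout concrete before diving into NKI, here is a plain NumPy illustration (not NKI code) of the per-partition re-arrangement we are after; the sizes are arbitrary example values:

import numpy as np

P, F1, F2 = 5, 3, 4
in_np = np.arange(P * F1 * F2).reshape(P, F1 * F2)  # each row holds a flattened F1 x F2 matrix

# For every partition p: view the row as F1 x F2, transpose it to F2 x F1,
# and flatten it back into a row of length F2 * F1.
out_np = in_np.reshape(P, F1, F2).transpose(0, 2, 1).reshape(P, F2 * F1)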
PyTorch#
Compute kernel#
import neuronxcc.nki.language as nl


def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D):
  """
  NKI kernel to reorder the elements on axis[1] of the input tensor.

  Every row of the input tensor is a flattened row-major 2D matrix.
  The shape2D argument defines the dimensions of the flattened matrices (#rows,#cols).
  Our goal in this kernel is to transpose these flattened 2D matrices, i.e. make them (#cols,#rows).

  Example:
      in_tensor = [a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3]
      shape2D = (3,4)
      this means that in_tensor has 3 rows and 4 columns, i.e. can be represented as:
          [a0,a1,a2,a3]
          [b0,b1,b2,b3]
          [c0,c1,c2,c3]
      after transpose, we expect to get:
          [a0,b0,c0]
          [a1,b1,c1]
          [a2,b2,c2]
          [a3,b3,c3]
      Thus, out_tensor is expected to be [a0,b0,c0,a1,b1,c1,a2,b2,c2,a3,b3,c3]

  Args:
      in_tensor: an input tensor
      out_tensor: an output (transposed) tensor
      shape2D: tuple representing the dimensions to be transposed: (#rows, #cols)
  """
  # Gather input shapes
  sz_p, _ = in_tensor.shape

  # Load input data from external memory to on-chip memory
  in_tile = nl.load(in_tensor)

  # Performing f1/f2 transpose
  # ==========================
  # The desired transpose pattern is provided as an input:
  sz_f1, sz_f2 = shape2D

  # We're going to need 3 indices to perform f1:f2 transpose.
  # - i_p0 is the parallel index
  # - i_f1 and i_f2 are both free-dim indices, and will be used to transpose between the f1/f2 axes
  i_p0 = nl.arange(sz_p)[:, None, None]
  i_f1 = nl.arange(sz_f1)[None, :, None]
  i_f2 = nl.arange(sz_f2)[None, None, :]

  # Perform the transposition via a SBUF-to-SBUF copy, with access-pattern manipulation
  # Note that we have 2D tensors and 3 indices, since we need to represent a 2D access pattern *per partition*
  # RHS traverses an F1 x F2 matrix in a row major manner
  # LHS traverses an F2 x F1 (new) matrix in a row major manner
  out_tile = nl.ndarray(shape=(sz_p, sz_f2*sz_f1), dtype=out_tensor.dtype)
  out_tile[i_p0, i_f2*sz_f1+i_f1] = nl.copy(in_tile[i_p0, i_f1*sz_f2+i_f2])

  # Finally, we store out_tile to external memory
  nl.store(out_tensor, value=out_tile)
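The copy above leans entirely on broadcasting the three index arrays into one 3D access pattern. The following NumPy analogue (illustrative only; np.arange stands in for nl.arange) spells out the same address arithmetic and can be checked on the host:

import numpy as np

sz_p, sz_f1, sz_f2 = 2, 3, 4
in_np = np.arange(sz_p * sz_f1 * sz_f2).reshape(sz_p, sz_f1 * sz_f2)

i_p0 = np.arange(sz_p)[:, None, None]
i_f1 = np.arange(sz_f1)[None, :, None]
i_f2 = np.arange(sz_f2)[None, None, :]

out_np = np.empty((sz_p, sz_f2 * sz_f1), dtype=in_np.dtype)
# Same index expressions as the NKI copy: read row-major F1 x F2, write row-major F2 x F1.
out_np[i_p0, i_f2 * sz_f1 + i_f1] = in_np[i_p0, i_f1 * sz_f2 + i_f2]

# Cross-check against a plain reshape/transpose of the same data.
expected = in_np.reshape(sz_p, sz_f1, sz_f2).transpose(0, 2, 1).reshape(sz_p, sz_f2 * sz_f1)
assert np.array_equal(out_np, expected)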
Launching kernel and testing correctness#
To execute the kernel, we prepare tensor a and an empty tensor a_t_nki, and call tensor_transpose2D_kernel_:
import torch
from torch_xla.core import xla_model as xm
from torch_neuronx import nki_jit

if __name__ == "__main__":
  device = xm.xla_device()

  P, X, Y = 5, 3, 4
  a = torch.arange(P*X*Y, dtype=torch.int8).reshape((P, X*Y)).to(device=device)
  a_t_nki = torch.zeros((P, Y*X), dtype=torch.int8).to(device=device)

  tensor_transpose2D_kernel_torch = nki_jit(tensor_transpose2D_kernel_)
  tensor_transpose2D_kernel_torch(a, a_t_nki, (X, Y))

  a_t_torch = torch.transpose(a.reshape(P, X, Y), 1, 2).reshape(P, X * Y)

  print(a, a_t_nki, a_t_torch)

  allclose = torch.allclose(a_t_torch, a_t_nki)
  if allclose:
    print("NKI and PyTorch match")
  else:
    print("NKI and PyTorch differ")

  assert allclose
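If you run this launcher as the standalone transpose2d_torch.py script from the download section below, the kernel is not defined in the same file; in that case you would import it first. The import below is a sketch that assumes transpose2d_nki_kernels.py sits in the same folder, as the download instructions state:

# Kernel lives in transpose2d_nki_kernels.py when the scripts are run standalone.
from transpose2d_nki_kernels import tensor_transpose2D_kernel_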
JAX#
Compute kernel#
We can reuse the same NKI compute kernel defined for PyTorch above.
We also define transpose2D as a caller to the NKI kernel. We create a partial function, partial(tensor_transpose2D_kernel_, shape2D=shape2D), so that JAX can pass a Python object (the shape2D tuple) to the kernel function.
import jax
from functools import partial
from jax_neuronx import nki_call


def transpose2D(in_tensor, shape2D):
  return nki_call(
    partial(tensor_transpose2D_kernel_, shape2D=shape2D),
    in_tensor,
    out_shape=jax.ShapeDtypeStruct(in_tensor.shape, dtype=in_tensor.dtype)
  )
Launching kernel and testing correctness#
To execute the kernel, we prepare array a and call the caller function transpose2D:
import jax
import jax.numpy as jnp

if __name__ == "__main__":
  P, X, Y = 5, 37, 44
  a = jax.random.uniform(jax.random.PRNGKey(42), (P, X * Y))
  a_t_nki = transpose2D(a, (X, Y))

  a_t_jax = jnp.transpose(a.reshape(P, X, Y), axes=(0, 2, 1)).reshape(P, X * Y)
  print(a, a_t_nki, a_t_jax)

  allclose = jnp.allclose(a_t_jax, a_t_nki)
  if allclose:
    print("NKI and JAX match")
  else:
    print("NKI and JAX differ")

  assert allclose
Download All Source Code#
Click the links to download source code of the kernels and the testing code discussed in this tutorial.
- NKI baremetal implementation: transpose2d_nki_kernels.py
- PyTorch implementation: transpose2d_torch.py
  You must also download transpose2d_nki_kernels.py into the same folder to run this PyTorch script.
- JAX implementation: transpose2d_jax.py
  You must also download transpose2d_nki_kernels.py into the same folder to run this JAX script.

You can also view the source code in the GitHub repository nki_samples.
Example usage of the scripts:#
Run NKI baremetal implementation:
python3 transpose2d_nki_kernels.py
Run PyTorch implementation:
python3 transpose2d_torch.py
Run JAX implementation:
python3 transpose2d_jax.py
This document is relevant for: Inf2, Trn1, Trn1n