Matrix multiplication#

In this tutorial, we will start with a simple NKI matrix multiplication kernel and optimize it step by step. In doing so, we learn about:

  • The NKI syntax and programming model.

  • Layout, tiling, and memory management considerations when performing matrix multiplication in NKI.

Basic compute kernel#

../../_images/matrix-multiplication-views.png

Fig. 78 MxKxN Matrix Multiplication Visualization#

Fig. 78 illustrates how a simple matrix multiplication, lhs [M, K] * rhs [K, N] = output [M, N], is mapped from its original mathematical view onto the Tensor Engine (TensorE) and the on-chip SRAMs. Note that the PSUM partition dimension is drawn rotated 90 degrees from the SBUF partition dimension solely for visualization purposes. The copy from PSUM to SBUF preserves the output tile layout by copying data from each PSUM partition to the corresponding SBUF partition.

The NKI example below implements a compute kernel for a single-tile matrix multiplication. It computes a 64 (M) x 128 (K) x 512 (N) matrix multiplication operation.

# Standard NKI imports assumed by this tutorial
from neuronxcc import nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa


@nki.jit
def nki_matmul_basic_(lhsT, rhs):
  """NKI kernel to compute a 64x128x512 matrix multiplication operation

  Args:
      lhsT: an input tensor of shape [128,64], a left hand side argument of the
        matrix multiplication, delivered transposed for optimal performance
      rhs: an input tensor of shape [128,512], a right hand side argument of the
        matrix multiplication
  Returns:
      result: the resulting output tensor of shape [64,512]
  """
  # Verify that the lhsT and rhs are the expected sizes.
  K, M = lhsT.shape
  K_, N = rhs.shape

  # Check that the contraction dimension matches and all dimensions
  # are what were expected.
  assert K == K_, \
    f"Expected contraction dimension to match on both lhsT ({K}) and rhs ({K_})"
  assert K == 128, f"Expected contraction dimension to be 128, but got {K}"
  assert M == 64, f"Expected lhsT matrix to have dimension M of 64, but got {M}"
  assert N == 512, f"Expected rhs matrix to have dimension N of 512, but got {N}"

  # Create a tensor to write the result into (not initialized)
  result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

  # Create tensors in SBUF to load the inputs into (not initialized)
  lhs_tile = nl.ndarray(lhsT.shape, dtype=lhsT.dtype, buffer=nl.sbuf)
  rhs_tile = nl.ndarray(rhs.shape, dtype=rhs.dtype, buffer=nl.sbuf)

  # Load the inputs (HBM->SBUF)
  # Note: here we take the Tile dtype definition into account,
  # which forces the P-dim to be the left-most index
  nisa.dma_copy(dst=lhs_tile, src=lhsT)
  nisa.dma_copy(dst=rhs_tile, src=rhs)

  # Create a tensor in PSUM to accumulate the result in (uninitialized)
  result_psum = nl.ndarray(result.shape, dtype=nl.float32, buffer=nl.psum)

  # Perform the matrix-multiplication
  # Note: a NKI matmul instruction always writes to PSUM in float32 data-type
  nisa.nc_matmul(result_psum, lhs_tile, rhs_tile)

  # Create a tensor in SBUF, copy the result from PSUM back to SBUF,
  # and cast to the expected output data-type
  result_sbuf = nl.ndarray(result_psum.shape, dtype=result.dtype, buffer=nl.sbuf)
  nisa.tensor_copy(dst=result_sbuf, src=result_psum, dtype=result.dtype)

  # The result of the [64,128] x [128,512] matrix multiplication has a shape of [64,512].
  # This dictates which indices to use to address the result tile.
  nisa.dma_copy(dst=result, src=result_sbuf)

  return result

In this example, we define the NKI kernel as nki_matmul_basic_:

  1. We read the shapes of the LHS (transposed) and RHS input tensors and assert that they match the expected single-tile sizes.

  2. To adhere to NKI’s layout considerations (Layout Considerations), we map the contraction axis of both LHS and RHS to the P-dimension, which means we load LHS in transposed form.

  3. To adhere to NKI’s tile size considerations (Tile Size Considerations), we limit the matmul instruction arguments to tiles of up to [128,128] for LHS, and [128,512] for RHS.

  4. Using the nisa.dma_copy operation, we load the inputs from HBM tensors to SBUF tiles.

  5. We then use the nisa.nc_matmul operation to perform the matrix multiplication. Note that the LHS argument is passed in transposed form. Also note that the 64x128 LHS tile here actually under-utilizes the TensorE, but it helps to distinguish the M, K, and N dimensions for educational purposes in this first code example.

  6. nisa.nc_matmul always writes its result to PSUM, and since nisa.dma_copy transfers data between HBM and SBUF, we first copy the multiplication result from PSUM back to SBUF using nisa.tensor_copy before writing it out to HBM.
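
As a quick mental model, the following plain NumPy sketch (not NKI; shapes chosen to match the kernel above) shows the math that nki_matmul_basic_ implements, including the transposed LHS convention:

import numpy as np

# Hypothetical inputs matching the kernel above: M=64, K=128, N=512
lhsT = np.random.rand(128, 64).astype(np.float32)  # [K, M], i.e. lhs delivered transposed
rhs = np.random.rand(128, 512).astype(np.float32)  # [K, N]

# The kernel computes lhs @ rhs, where lhs = lhsT.T
result = lhsT.T @ rhs  # [M, N] == [64, 512]
assert result.shape == (64, 512)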

We can then execute the kernel and verify correctness against the torch implementation as follows. Note that we use torch.allclose to tolerate numerical error inherent to floating-point arithmetic.

# Imports assumed by this snippet
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
cpu = torch.device('cpu')

# Test the small workload with the basic kernel
lhs_small = torch.rand((64, 128), dtype=torch.bfloat16, device=device)
rhs_small = torch.rand((128, 512), dtype=torch.bfloat16, device=device)

# Run NKI kernel
output_small = nki_matmul_basic_(lhs_small.T, rhs_small)

# Run torch reference
output_small_torch = torch.matmul(lhs_small, rhs_small)

# Compare results
print("Checking correctness of nki_matmul_basic")
if torch.allclose(output_small_torch, output_small, atol=1e-4, rtol=1e-2):
  print("NKI and Torch match")
else:
  print("NKI and Torch differ")

Tiling matrix multiplications#

So far, we’ve limited our matrix multiplication to the tile sizes allowed by NKI’s tile size and layout constraints. Next, we’ll see how to handle larger matrix multiplications. Let’s start with pseudo-code for tiling an [M,K] @ [K,N] matrix multiplication. Note that we assume the left-hand-side matrix ([M,K]) is already transposed to LHS_T ([K,M]) for optimal performance of the underlying TensorE.

# LHS_T: left-hand-side matmul argument (shape [K,M])
# RHS: right-hand-side matmul argument (shape [K,N])
# RES: matmul result (shape [M,N])

# Tile the LHS_T free dimension (M)
for m in range(0, M, 128):
  # Tile the RHS free dimension (N)
  for n in range(0, N, 512):
    # Zero-out the accumulator buffer
    accum = zeros((128, 512))
    # Tile the contraction dimension (K)
    for k in range(0, K, 128):
      lhsT_tile = LHS_T[k : k+128, m : m+128]
      rhs_tile = RHS[k : k+128, n : n+512]
      accum += dot(lhsT_tile.T, rhs_tile)
    RES[m : m+128, n : n+512] = accum
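
As a sanity check, the same tiling scheme can be written as runnable NumPy (illustrative only; the names below are not NKI APIs):

import numpy as np

def tiled_matmul_reference(lhsT, rhs):
  # lhsT: [K, M], rhs: [K, N]; K and M multiples of 128, N a multiple of 512 assumed
  K, M = lhsT.shape
  _, N = rhs.shape
  res = np.zeros((M, N), dtype=np.float32)
  for m in range(0, M, 128):
    for n in range(0, N, 512):
      accum = np.zeros((128, 512), dtype=np.float32)
      for k in range(0, K, 128):
        lhsT_tile = lhsT[k:k + 128, m:m + 128]  # [128, 128]
        rhs_tile = rhs[k:k + 128, n:n + 512]    # [128, 512]
        accum += lhsT_tile.T @ rhs_tile
      res[m:m + 128, n:n + 512] = accum
  return res

# Matches the untiled product up to floating-point error
lhsT = np.random.rand(256, 128).astype(np.float32)
rhs = np.random.rand(256, 1024).astype(np.float32)
assert np.allclose(tiled_matmul_reference(lhsT, rhs), lhsT.T @ rhs, rtol=1e-3)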

This form of tiling can be achieved in NKI as follows:

@nki.jit
def nki_matmul_tiled_(lhsT, rhs):
  """NKI kernel to compute a matrix multiplication operation in a tiled manner

  Args:
      lhsT: an input tensor of shape [K,M], where both K and M are multiples of
        128.  It is the left-hand-side argument of the matrix multiplication,
        delivered transposed for optimal performance.
      rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
        is a multiple of 512.  It is the right-hand-side argument of the matrix
        multiplication.
  Returns:
      result: the resulting output tensor of shape [M,N]
  """

  # Verify that the lhsT and rhs have the same contraction dimension.
  K, M = lhsT.shape
  K_, N = rhs.shape
  assert K == K_, "lhsT and rhs must have the same contraction dimension"

  # Lookup the device matrix multiply dimensions.
  TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
  TILE_K = nl.tile_size.pmax  # 128
  TILE_N = nl.tile_size.gemm_moving_fmax  # 512

  # Verify that the input matrices are a multiple of the tile dimensions.
  assert M % TILE_M == 0, \
    f"Expected M, {M}, to be a multiple of stationary free-dimension max, {TILE_M}"
  assert N % TILE_N == 0, \
    f"Expected N, {N}, to be a multiple of moving free-dimension max, {TILE_N}"
  assert K % TILE_K == 0, \
    f"Expected K, {K}, to be a multiple of the partition dimension max, {TILE_K}"

  # Create a space for the result in HBM (not initialized)
  result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

  # Use affine_range to loop over tiles
  for m in nl.affine_range(M // TILE_M):
    for n in nl.affine_range(N // TILE_N):
      # Allocate a tensor in PSUM
      res_psum = nl.ndarray((TILE_M, TILE_N), dtype=nl.float32, buffer=nl.psum)

      for k in nl.affine_range(K // TILE_K):
        # Declare the tiles on SBUF
        lhsT_tile = nl.ndarray((TILE_K, TILE_M), dtype=lhsT.dtype, buffer=nl.sbuf)
        rhs_tile = nl.ndarray((TILE_K, TILE_N), dtype=rhs.dtype, buffer=nl.sbuf)

        # Load tiles from lhsT and rhs
        nisa.dma_copy(dst=lhsT_tile,
                      src=lhsT[k * TILE_K:(k + 1) * TILE_K,
                               m * TILE_M:(m + 1) * TILE_M])
        nisa.dma_copy(dst=rhs_tile,
                      src=rhs[k * TILE_K:(k + 1) * TILE_K,
                              n * TILE_N:(n + 1) * TILE_N])

        # Accumulate partial-sums into PSUM
        nisa.nc_matmul(dst=res_psum, stationary=lhsT_tile, moving=rhs_tile)

      # Copy the result from PSUM back to SBUF, and cast to the expected output data-type
      res_sb = nl.ndarray(res_psum.shape, dtype=result.dtype, buffer=nl.sbuf)
      nisa.tensor_copy(dst=res_sb, src=res_psum, dtype=result.dtype)

      # Copy the result from SBUF to HBM.
      nisa.dma_copy(dst=result[m * TILE_M:(m + 1) * TILE_M,
                               n * TILE_N:(n + 1) * TILE_N],
                    src=res_sb)

  return result

A few notes about the above code example, starting with the PSUM accumulation pattern at its core:

psum_buf = nl.ndarray(..., buffer=nl.psum)

# Accumulation loop over the contraction dimension: an affine_range loop
for i in nl.affine_range(N):
   # Matmul results from the TensorEngine accumulate into the same PSUM buffer
   nisa.nc_matmul(psum_buf, stationary_tile, moving_tile)  # or nl.matmul

The use of the PSUM accumulation architecture feature is critical to achieving good performance out of the TensorEngine when the contraction dimension of the matmul is greater than 128.

nl.affine_range is used to define loop-level iterators and is the recommended iterator type when the loop has no loop-carried dependencies (note that associative reductions are not considered loop-carried dependencies in this context). The first nisa.nc_matmul call overwrites the contents of psum_buf; subsequent nisa.nc_matmul calls accumulate their results into psum_buf.

There is an alternative way to implement this tiled matrix multiplication kernel using the SPMD programming model: launch (M/128) x (N/512) instances of the kernel, with each instance computing one output tile (the innermost loop over K). For more details, refer to the SPMD programming model.

Optimization 1: Removing Redundant Loads#

Currently, every nisa.nc_matmul in the inner loop is accompanied by two nisa.dma_copy calls, both of which move data from HBM to SBUF. Let’s introduce a metric, arithmetic intensity, to help understand why this is problematic. The arithmetic intensity of a workload is defined as the average number of computation operations performed per byte of data accessed from HBM. We do not count data accessed from SBUF in this metric because SBUF bandwidth (~20x higher than HBM) is high enough to sustain the peak computation throughput of TensorE.

../../_images/roofline.png

Fig. 79 Roofline Model: The Relationship Between Arithmetic Intensity and Performance#

Fig. 79 shows the roofline model, which models the relationship between the arithmetic intensity of a workload and its achievable performance on a given computing platform. To saturate TensorE in a NeuronCore-v2, a workload needs an arithmetic intensity of at least 222 Flops/Byte for the bfloat16 data type. Inside the inner loop of nki_matmul_tiled_, accessing lhsT_tile and rhs_tile requires 160 KB of data read from HBM, while the nisa.nc_matmul call performs about 16 MFlops. This leads to an arithmetic intensity of 102, significantly lower than the saturation threshold of 222. Therefore, nki_matmul_tiled_ operates in the memory-bound region of the roofline model and under-utilizes TensorE. To make the best use of TensorE, we need to improve the arithmetic intensity of the matmul kernel.
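
The numbers above can be reproduced with a few lines of arithmetic (a back-of-the-envelope check; BF16 is 2 bytes per element):

# Per inner-loop iteration of nki_matmul_tiled_
bytes_lhsT = 128 * 128 * 2           # 32 KB for the [128, 128] lhsT tile
bytes_rhs = 128 * 512 * 2            # 128 KB for the [128, 512] rhs tile
bytes_hbm = bytes_lhsT + bytes_rhs   # 160 KB read from HBM

flops = 2 * 128 * 128 * 512          # 16,777,216 Flops per nc_matmul (multiply + add)

print(flops / bytes_hbm)             # ~102 Flops/Byte, well below the 222 threshold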

With NKI, programmers control when and how data is loaded from HBM into SBUF and when computation is performed. In the upcoming steps, we demonstrate how to increase the arithmetic intensity of the matmul kernel using NKI, thereby maximizing the utilization of TensorE.

First, we notice that in nki_matmul_tiled_, the same tiles from the lhsT and rhs matrices are loaded more than once across different iterations of the inner loop. The following example removes these redundant loads by hoisting them out of the innermost loop.

../../_images/mm-memory-pattern-after-load-hoisting.png

Fig. 80 Memory Pattern After Hoisting Loads Out of the Innermost Loop#

@nki.jit
def nki_matmul_hoist_load_(lhsT, rhs):
  """NKI kernel to compute a matrix multiplication operation in a tiled manner
     while hoisting the loads of lhsT and rhs to outer loops.

  Args:
      lhsT: an input tensor of shape [K,M], where both K and M are multiples of
        128.  It is the left-hand-side argument of the matrix multiplication,
        delivered transposed for optimal performance.
      rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
        is a multiple of 512.  It is the right-hand-side argument of the matrix
        multiplication.
  Returns:
      result: the resulting output tensor of shape [M,N]
  """

  # Verify that the lhsT and rhs have the same contraction dimension.
  K, M = lhsT.shape
  K_, N = rhs.shape
  assert K == K_, "lhsT and rhs must have the same contraction dimension"

  # Lookup the device matrix multiply dimensions.
  TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
  TILE_K = nl.tile_size.pmax  # 128
  TILE_N = nl.tile_size.gemm_moving_fmax  # 512

  # Verify that the input matrices are a multiple of the tile dimensions.
  assert M % TILE_M == 0, \
    f"Expected M, {M}, to be a multiple of stationary free-dimension max, {TILE_M}"
  assert N % TILE_N == 0, \
    f"Expected N, {N}, to be a multiple of moving free-dimension max, {TILE_N}"
  assert K % TILE_K == 0, \
    f"Expected K, {K}, to be a multiple of the partition dimension max, {TILE_K}"

  # Create a space for the result in HBM (not initialized)
  result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

  # Use affine_range to loop over tiles
  for m in nl.affine_range(M // TILE_M):
    # Load a whole column of tiles from lhsT (K * TILE_M elements in total)
    # This corresponds to a whole row of tiles in the original lhs
    lhsT_tiles = []
    for k in nl.affine_range(K // TILE_K):
      # Allocate space in SBUF for the tile (uninitialized)
      lhsT_tile = nl.ndarray(shape=(TILE_K, TILE_M), dtype=lhsT.dtype, buffer=nl.sbuf)
      # Copy the tile from HBM to SBUF
      nisa.dma_copy(dst=lhsT_tile,
                    src=lhsT[k * TILE_K:(k + 1) * TILE_K,
                             m * TILE_M:(m + 1) * TILE_M])
      # Append the tile to the list of tiles.
      lhsT_tiles.append(lhsT_tile)

    for n in nl.affine_range(N // TILE_N):
      # Load a whole column of tiles from rhs (K * TILE_N elements in total)
      rhs_tiles = []
      for k in nl.affine_range(K // TILE_K):
        # Allocate space in SBUF for the tile (uninitialized)
        rhs_tile = nl.ndarray(shape=(TILE_K, TILE_N), dtype=rhs.dtype, buffer=nl.sbuf)
        # Copy the tile from HBM to SBUF
        nisa.dma_copy(dst=rhs_tile,
                      src=rhs[k * TILE_K:(k + 1) * TILE_K,
                              n * TILE_N:(n + 1) * TILE_N])
        # Append the tile to the list of tiles.
        rhs_tiles.append(rhs_tile)

      # Allocate a tile in PSUM for the result (uninitialized)
      res_psum = nl.ndarray(shape=(TILE_M, TILE_N), dtype=nl.float32, buffer=nl.psum)
      for k in nl.affine_range(K // TILE_K):
        # Accumulate partial-sums into PSUM
        nisa.nc_matmul(dst=res_psum, stationary=lhsT_tiles[k], moving=rhs_tiles[k])

      # Copy the result from PSUM back to SBUF, and cast to the expected output data-type
      res_sb = nl.ndarray(shape=(TILE_M, TILE_N), dtype=result.dtype, buffer=nl.sbuf)
      nisa.tensor_copy(dst=res_sb, src=res_psum, dtype=result.dtype)

      # Copy the result from SBUF to HBM.
      nisa.dma_copy(dst=result[m * TILE_M:(m + 1) * TILE_M,
                               n * TILE_N:(n + 1) * TILE_N],
                    src=res_sb)

  return result

Optimization 2: Reuse More Loads Through Blocking#

While hoisting the loads out of the innermost loop eliminates some redundant loads, we can push this further by reordering the computation and the associated memory accesses. The technique we are going to use is called blocking. Blocking improves temporal locality and thereby reduces memory accesses. In spirit, it is very similar to the tiling step we did earlier.

Note that we reserve the word “tile” for the granularity of computation and “tiling” for the previous optimization technique, which maps the high-level computation onto multiple matrix multiplication instructions executed on TensorE. TensorE processes a specific “tile size” in a single instruction, leveraging the inherent parallelism in matrix multiplication.

Here, we apply blocking by grouping the work associated with a set of tiles together at another loop-nest level. Blocking effectively interleaves a set of compute instructions and loading (DMA) instructions. This optimization does not bring additional parallelism to the computation; rather, it improves the arithmetic intensity, shifting a memory-bound matrix multiplication implementation towards a compute-bound one so that we can fully leverage the compute capabilities of TensorE.
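
To see why blocking helps, the small Python sketch below (not NKI; the counting functions are illustrative) tallies how many [128, x] tiles each kernel reads from HBM. Reusing one rhs block across TILES_IN_BLOCK_M output row tiles halves the rhs traffic for a 2-tile block:

def loads_hoisted(M, N, K):
  # nki_matmul_hoist_load_: lhsT column loaded once per m, rhs column once per (m, n)
  lhsT_loads = (M // 128) * (K // 128)
  rhs_loads = (M // 128) * (N // 512) * (K // 128)
  return lhsT_loads, rhs_loads

def loads_blocked_free_dims(M, N, K, tiles_in_block_m=2):
  # nki_matmul_block_free_dimension_: each rhs block is reused for tiles_in_block_m row tiles
  lhsT_loads = (M // 128) * (K // 128)
  rhs_loads = (M // (128 * tiles_in_block_m)) * (N // 512) * (K // 128)
  return lhsT_loads, rhs_loads

print(loads_hoisted(4096, 2048, 1024))            # (256, 1024) tiles
print(loads_blocked_free_dims(4096, 2048, 1024))  # (256, 512) tiles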

Fig. 81 below visualizes the memory pattern after blocking both free dimensions.

../../_images/mm-memory-pattern-after-blocking-free.png

Fig. 81 Memory Pattern After Blocking Free Dimensions#

@nki.jit
def nki_matmul_block_free_dimension_(lhsT, rhs):
  """NKI kernel to compute a matrix multiplication operation while blocking the
     free dimensions of the LHS and RHS to improve memory access pattern.

  Args:
      lhsT: an input tensor of shape [K,M], where both K and M are multiples of
        128.  It is the left-hand-side argument of the matrix multiplication,
        delivered transposed for optimal performance.
      rhs: an input tensor of shape [K,N], where K is a multiple of 128, and N
        is a multiple of 512.  It is the right-hand-side argument of the matrix
        multiplication.
  Returns:
      result: the resulting output tensor of shape [M,N]
  """

  # Verify that the lhsT and rhs have the same contraction dimension.
  K, M = lhsT.shape
  K_, N = rhs.shape
  assert K == K_, "lhsT and rhs must have the same contraction dimension"

  # Lookup the device matrix multiply dimensions.
  TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
  TILE_K = nl.tile_size.pmax  # 128
  TILE_N = nl.tile_size.gemm_moving_fmax  # 512

  # Configure the blocking size for the free dimensions
  TILES_IN_BLOCK_M = 2
  TILES_IN_BLOCK_N = 2

  BLOCK_M = TILE_M * TILES_IN_BLOCK_M  # 256
  BLOCK_N = TILE_N * TILES_IN_BLOCK_N  # 1024

  # The matrix sizes have to be multiples of the block sizes
  assert M % BLOCK_M == 0
  assert N % BLOCK_N == 0

  # Create a space for the result in HBM (not initialized)
  result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

  # Loop over blocks of the M dimension
  for m in nl.affine_range(M // BLOCK_M):
    # Load TILES_IN_BLOCK_M columns of tiles (K // TILE_K row tiles each) from lhsT
    lhsT_tiles = []
    for bm in nl.affine_range(TILES_IN_BLOCK_M):
      # Inner tile array.
      lhsT_tiles_internal = []
      for k in nl.affine_range(K // TILE_K):
        # Allocate space in SBUF for the tile (uninitialized)
        lhsT_tile = nl.ndarray(shape=(TILE_K, TILE_M),
                               dtype=lhsT.dtype,
                               buffer=nl.sbuf)
        # Copy the tile from HBM to SBUF
        nisa.dma_copy(dst=lhsT_tile,
                      src=lhsT[k * TILE_K:(k + 1) * TILE_K,
                               (m * TILES_IN_BLOCK_M + bm) *
                               TILE_M:((m * TILES_IN_BLOCK_M + bm) + 1) *
                               TILE_M])
        # Append the tile to the inner list of tiles.
        lhsT_tiles_internal.append(lhsT_tile)
      # Append the inner list of tiles into the outer list of tiles.
      lhsT_tiles.append(lhsT_tiles_internal)

    for n in nl.affine_range(N // BLOCK_N):
      # Load TILES_IN_BLOCK_N columns of tiles (K // TILE_K row tiles each) from rhs
      rhs_tiles = []
      for bn in nl.affine_range(TILES_IN_BLOCK_N):
        # Inner tile array.
        rhs_tiles_internal = []
        for k in nl.affine_range(K // TILE_K):
          # Allocate space in SBUF for the tile (uninitialized)
          rhs_tile = nl.ndarray(shape=(TILE_K, TILE_N),
                                dtype=rhs.dtype,
                                buffer=nl.sbuf)
          # Copy the tile from HBM to SBUF
          nisa.dma_copy(dst=rhs_tile,
                        src=rhs[k * TILE_K:(k + 1) * TILE_K,
                                (n * TILES_IN_BLOCK_N + bn) *
                                TILE_N:((n * TILES_IN_BLOCK_N + bn) + 1) *
                                TILE_N])
          # Append the tile to the inner list of tiles.
          rhs_tiles_internal.append(rhs_tile)
        # Append the inner list of tiles into the outer list of tiles.
        rhs_tiles.append(rhs_tiles_internal)

      for bm in nl.affine_range(TILES_IN_BLOCK_M):
        for bn in nl.affine_range(TILES_IN_BLOCK_N):
          # Allocate a tensor in PSUM
          result_tile = nl.ndarray(shape=(TILE_M, TILE_N),
                                   dtype=nl.float32,
                                   buffer=nl.psum)
          for k in nl.affine_range(K // TILE_K):
            # Accumulate partial-sums into PSUM
            nisa.nc_matmul(dst=result_tile,
                           stationary=lhsT_tiles[bm][k],
                           moving=rhs_tiles[bn][k])

          # Copy the result from PSUM back to SBUF, and cast to the expected
          # output data-type
          result_tmp = nl.ndarray(shape=result_tile.shape,
                                  dtype=result.dtype,
                                  buffer=nl.sbuf)
          nisa.tensor_copy(dst=result_tmp, src=result_tile)

          # Copy the result from SBUF to HBM.
          nisa.dma_copy(dst=result[(m * TILES_IN_BLOCK_M + bm) *
                                   TILE_M:((m * TILES_IN_BLOCK_M + bm) + 1) *
                                   TILE_M,
                                   (n * TILES_IN_BLOCK_N + bn) *
                                   TILE_N:((n * TILES_IN_BLOCK_N + bn) + 1) *
                                   TILE_N],
                        src=result_tmp)

  return result

Optimization 3: Further Blocking and DMA Efficiency Optimization#

Next, let’s also block the contraction dimension. Without blocking the contraction dimension, each block of computation directly produces the final result of an output block, since the input blocks in both lhsT and rhs cover the entire contraction dimension. After blocking the contraction dimension, the accumulation is split into separate groups, and we accumulate the partial sums from each computation block into an SBUF tensor that holds the final result. A small amount of HBM traffic might also be introduced if the partial sums cannot be kept in SBUF before being consumed. On the bright side, we can now increase the block sizes of the free dimensions, which further improves the arithmetic intensity.
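
The accumulation scheme can be sketched in plain NumPy (not NKI; an SBUF-resident result buffer is emulated by a regular array):

import numpy as np

K, M, N, BLOCK_K = 1024, 256, 512, 256
lhsT = np.random.rand(K, M).astype(np.float32)
rhs = np.random.rand(K, N).astype(np.float32)

result = np.zeros((M, N), dtype=np.float32)  # plays the role of the SBUF result tiles
for k in range(0, K, BLOCK_K):
  # Each K block produces a partial sum (the PSUM result in the NKI kernel below)
  partial = lhsT[k:k + BLOCK_K, :].T @ rhs[k:k + BLOCK_K, :]
  result += partial  # the kernel below uses nisa.tensor_tensor(..., op=nl.add) for this step

assert np.allclose(result, lhsT.T @ rhs, rtol=1e-3)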

../../_images/mm-memory-pattern-after-blocking-all.png

Fig. 82 Memory Pattern After Blocking All Dimensions#

One final step we can take with NKI is to optimize the layout of the loaded tiles to improve DMA efficiency. Instead of loading each tile into its own SBUF tensor, we allocate SBUF tensors that span a whole block along the free dimension (for example, [TILE_K, BLOCK_N]), so each DMA transfer moves a larger contiguous chunk while the partition dimension remains the first dimension of the nl.ndarray.

By putting all these optimizations together, we can use NKI to implement optimized matrix multiplication for different sizes. Note that different sizes of input matrices require different optimization plans. The following code is optimized for large matrix multiplications where, with the default meta-parameters, M is a multiple of 2048, N is a multiple of 1024, and K is a multiple of 1024.

With the blocking configuration in the code (16 tiles, or 2048 elements, in the M dimension; 2 tiles, or 1024 elements, in the N dimension; and 8 tiles, or 1024 elements, in the K dimension), this computation has an arithmetic intensity of 683 Flops/Byte (2048*1024*1024 / (2048*1024 + 1024*1024)), which is well above the threshold of 222.
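
The 683 Flops/Byte figure can be checked directly; note that in the shorthand above, the factor of 2 Flops per multiply-accumulate cancels against the 2 bytes per BF16 element:

# One (M, N, K) block of the configuration above
flops = 2 * 2048 * 1024 * 1024               # multiply + add per element triple
bytes_hbm = (2048 * 1024 + 1024 * 1024) * 2  # lhsT block + rhs block read from HBM (BF16)
print(flops / bytes_hbm)                     # ~683 Flops/Byte, above the 222 threshold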

At the same time, this blocking configuration keeps all tensors within the SBUF budget as much as possible. With all matrices in the BF16 data type, lhsT_tiles requires 4 MB and rhs_tiles requires 2 MB of SBUF memory. The result tiles (result_tmps in the code) require 4 * NUM_BLOCK_M MB of SBUF memory, where NUM_BLOCK_M is M // 2048. Thus, as long as M <= 8192, the required SBUF memory stays under the 24 MB budget (4 + 2 + 4 * (8192 // 2048) == 22 MB). When the M dimension grows beyond that, spilling and reloading of the result tiles will happen, but because the frequency is relatively low, the computation can still be efficient.
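
The SBUF footprint estimate can likewise be reproduced with simple arithmetic (BF16 is 2 bytes per element; sizes taken from the default meta-parameters):

MB = 1024 * 1024
TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K = 16, 2, 8
BLOCK_M, BLOCK_N = 128 * TILES_IN_BLOCK_M, 512 * TILES_IN_BLOCK_N

lhsT_tiles_bytes = TILES_IN_BLOCK_K * 128 * BLOCK_M * 2  # 4 MB
rhs_tiles_bytes = TILES_IN_BLOCK_K * 128 * BLOCK_N * 2   # 2 MB

M = 8192
NUM_BLOCK_M = M // BLOCK_M
# result_tmps holds one BF16 tile per (M block, bm, bn) combination
result_tiles_bytes = NUM_BLOCK_M * TILES_IN_BLOCK_M * TILES_IN_BLOCK_N * 128 * 512 * 2

print((lhsT_tiles_bytes + rhs_tiles_bytes + result_tiles_bytes) / MB)  # 22.0 MB, under 24 MB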

Since the K blocking loop is hand-optimized for our ideal data locality, we do not want the compiler to rewrite this loop during its vectorization and other loop-level optimization passes. To communicate this, we construct the K blocking loop with nl.sequential_range().

@nki.jit
def nki_matmul_fully_optimized_(
    lhsT,
    rhs,
    # Meta-parameters
    TILES_IN_BLOCK_M=16,
    TILES_IN_BLOCK_N=2,
    TILES_IN_BLOCK_K=8,
):
  """NKI kernel to compute a large matrix multiplication efficiently by
     blocking all dimensions and doing layout optimization.

  Args:
      lhsT: an input tensor of shape [K,M], where K is a multiple of 128 *
        TILES_IN_BLOCK_K and M is a multiple of 128 * TILES_IN_BLOCK_M.  It is the
        left-hand-side argument of the matrix multiplication, delivered transposed
        for optimal performance.
      rhs: an input tensor of shape [K,N], where K is a multiple of 128 *
        TILES_IN_BLOCK_K and N is a multiple of 512 * TILES_IN_BLOCK_N.  It is
        the right-hand-side argument of the matrix multiplication.
      TILES_IN_BLOCK_*: meta parameters to control blocking dimensions
  Returns:
      result: the resulting output tensor of shape [M,N]
  """

  # Verify that the lhsT and rhs have the same contraction dimension.
  K, M = lhsT.shape
  K_, N = rhs.shape
  assert K == K_, "lhsT and rhs must have the same contraction dimension"

  # Lookup the device matrix multiply dimensions.
  TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
  TILE_K = nl.tile_size.pmax  # 128
  TILE_N = nl.tile_size.gemm_moving_fmax  # 512

  # Compute the block dimensions.
  BLOCK_M = TILE_M * TILES_IN_BLOCK_M
  BLOCK_N = TILE_N * TILES_IN_BLOCK_N
  BLOCK_K = TILE_K * TILES_IN_BLOCK_K

  # Verify that each matrix size is a multiple of the corresponding block size
  assert M % BLOCK_M == 0, \
    f"Expected M {M} to be divisible by {BLOCK_M} when TILES_IN_BLOCK_M is {TILES_IN_BLOCK_M}"
  assert N % BLOCK_N == 0, \
    f"Expected N {N} to be divisible by {BLOCK_N} when TILES_IN_BLOCK_N is {TILES_IN_BLOCK_N}"
  assert K % BLOCK_K == 0, \
    f"Expected K {K} to be divisible by {BLOCK_K} when TILES_IN_BLOCK_K is {TILES_IN_BLOCK_K}"

  # Create a space for the result in HBM (not initialized)
  result = nl.ndarray((M, N), dtype=lhsT.dtype, buffer=nl.shared_hbm)

  # Compute the number of blocks in each dimension
  NUM_BLOCK_M = M // BLOCK_M
  NUM_BLOCK_N = N // BLOCK_N
  NUM_BLOCK_K = K // BLOCK_K

  # Blocking N dimension (the RHS free dimension)
  for n in nl.affine_range(NUM_BLOCK_N):
    # Create the initial result tiles in SBUF and initialize each tile to
    # 0.0, since the final results will be accumulated here.
    # The tiles are kept in a 3-d nested list.
    result_tmps = []
    for m_idx in range(NUM_BLOCK_M):
      block_m = []
      for bm_idx in range(TILES_IN_BLOCK_M):
        block_n = []
        for bn_idx in range(TILES_IN_BLOCK_N):
          # Create the result tile (uninitialized)
          tile = nl.ndarray(shape=(TILE_M, TILE_N), dtype=lhsT.dtype, buffer=nl.sbuf)
          # Initialize the tile to 0.0
          nisa.memset(dst=tile, value=0.0)
          # Append the tile to the block_n array.
          block_n.append(tile)
        # Append the block_n array to the block_m array.
        block_m.append(block_n)
      # Append the block_m array into result_tmps.
      result_tmps.append(block_m)

    # Blocking K dimension (the contraction dimension)
    # Use `sequential_range` because we do not want the compiler
    # to change this loop by, for example, vectorizing it
    for k in nl.sequential_range(NUM_BLOCK_K):
      # Loading tiles from rhs
      # setting the load tile to `TILE_K x BLOCK_N` to optimize DMA performance
      rhs_tiles = []
      for bk_r in range(TILES_IN_BLOCK_K):
        # Allocate the rhs_tile tensor, TILE_K x BLOCK_N
        rhs_tile = nl.ndarray(shape=(TILE_K, BLOCK_N),
                              dtype=rhs.dtype,
                              buffer=nl.sbuf)
        # Copy a block tile from rhs to rhs_tile.
        nisa.dma_copy(dst=rhs_tile[0:TILE_K, 0:BLOCK_N],
                      src=rhs[(TILES_IN_BLOCK_K * k + bk_r) *
                              TILE_K:(TILES_IN_BLOCK_K * k + bk_r + 1) * TILE_K,
                              BLOCK_N * n:BLOCK_N * (n + 1)])
        # Append rhs_tile to rhs_tiles.
        rhs_tiles.append(rhs_tile)

      # Blocking M dimension (the LHS free dimension)
      for m in nl.affine_range(NUM_BLOCK_M):
        # Loading tiles from lhsT
        lhsT_tiles = []
        for bk_l in nl.affine_range(TILES_IN_BLOCK_K):
          # Allocate lhsT_tile in SBUF (uninitialized)
          lhsT_tile = nl.ndarray(shape=(TILE_K, BLOCK_M),
                                 dtype=lhsT.dtype,
                                 buffer=nl.sbuf)
          # Copy a block tile from lhsT to lhsT_tile
          nisa.dma_copy(dst=lhsT_tile[0:TILE_K, 0:BLOCK_M],
                        src=lhsT[(TILES_IN_BLOCK_K * k + bk_l) *
                                 TILE_K:(TILES_IN_BLOCK_K * k + bk_l + 1) * TILE_K,
                                 BLOCK_M * m:BLOCK_M * (m + 1)])
          # Append to the list of lhsT tiles.
          lhsT_tiles.append(lhsT_tile)

        # Do matmul with all tiles in the blocks
        for bn in nl.affine_range(TILES_IN_BLOCK_N):
          for bm in nl.affine_range(TILES_IN_BLOCK_M):
            # Allocate result_tile in PSUM (uninitialized)
            result_tile = nl.ndarray(shape=(TILE_M, TILE_N),
                                     dtype=nl.float32,
                                     buffer=nl.psum)
            for bk in nl.affine_range(TILES_IN_BLOCK_K):
              # Perform matrix multiply on a tile.
              nisa.nc_matmul(
                dst=result_tile,
                stationary=lhsT_tiles[bk][0:TILE_K, bm * TILE_M:(bm + 1) * TILE_M],
                moving=rhs_tiles[bk][0:TILE_K, bn * TILE_N:(bn + 1) * TILE_N]
              )
            # Accumulate the result into the corresponding result_tmps tile.
            nisa.tensor_tensor(dst=result_tmps[m][bm][bn],
                               data1=result_tmps[m][bm][bn],
                               data2=result_tile,
                               op=nl.add)

    # Copying the result from SBUF to HBM
    for m in nl.affine_range(NUM_BLOCK_M):
      for bm in nl.affine_range(TILES_IN_BLOCK_M):
        # Coalesce result tiles for better DMA performance
        result_packed = nl.ndarray(shape=(TILE_M, BLOCK_N),
                                   dtype=nl.float32,
                                   buffer=nl.sbuf)
        for bn in nl.affine_range(TILES_IN_BLOCK_N):
          nisa.tensor_copy(
            dst=result_packed[0:TILE_M, bn * TILE_N:(bn + 1) * TILE_N],
            src=result_tmps[m][bm][bn][0:TILE_M, 0:TILE_N])

        # Copy the packed result from SBUF to HBM.
        nisa.dma_copy(dst=result[(TILES_IN_BLOCK_M * m + bm) *
                                 TILE_M:(TILES_IN_BLOCK_M * m + bm + 1) * TILE_M,
                                 BLOCK_N * n:BLOCK_N * (n + 1)],
                      src=result_packed[0:TILE_M, 0:BLOCK_N])

  return result

Testing Correctness and Benchmarking#

To test the correctness of the kernels, we compare their results against torch.matmul using torch.allclose.

# Test the large workload with tiled kernels
lhs = torch.rand((4096, 1024), dtype=torch.bfloat16, device=device)
rhs = torch.rand((1024, 2048), dtype=torch.bfloat16, device=device)

# Run torch reference
output_torch = torch.matmul(lhs, rhs).to(device=cpu)

def check_match(nki_func):
  output = nki_func(lhs.T, rhs)
  output_nki = output.to(device=cpu)
  if torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2):
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")

print("Checking correctness of nki_matmul_tiled")
check_match(nki_matmul_tiled_)

print("Checking correctness of nki_matmul_hoist_load")
check_match(nki_matmul_hoist_load_)

print("Checking correctness of nki_matmul_block_free_dimension")
check_match(nki_matmul_block_free_dimension_)

print("Checking correctness of nki_matmul_fully_optimized")
check_match(nki_matmul_fully_optimized_)

Output from the test:

Checking correctness of nki_matmul_tiled
NKI and Torch match
Checking correctness of nki_matmul_hoist_load
NKI and Torch match
Checking correctness of nki_matmul_block_free_dimension
NKI and Torch match
Checking correctness of nki_matmul_fully_optimized
NKI and Torch match

Download All Source Code#

Click the links to download source code of the kernels and the testing code discussed in this tutorial.

You can also view the source code in the GitHub repository nki_samples.

Example usage of the scripts:#

Run benchmarking of different NKI kernels:

python3 matrix_multiplication_nki_kernels.py

Run the PyTorch implementation to validate the NKI results against it:

python3 matrix_multiplication_torch.py