This document is relevant for: Inf2, Trn1, Trn2

NKI Block Dimension Migration Guide#

SBUF/PSUM tensors in NKI used to allow block dimensions in front of the partition dimension. Block dimension support has been removed for the following reasons.

  • Removing block dimensions does not hurt the expressivity of NKI.

  • Block dimension is a pure software concept and does not have direct hardware mapping.

  • The block dimension is unintuitive and causes confusion.

  • Using block dimensions has no inherent performance benefit; in particular, block dimensions have no relationship with memory throughput whatsoever.

  • Multi-buffering is implicit with block dimensions. Removing block dimensions makes multi-buffering more natural.

This document first explains the semantics of block dimensions in detail, then describes how to migrate existing code that uses block dimensions while maintaining functional correctness and performance.

What are block dimensions?#

Consider the following NKI tensor.

a = nl.ndarray((4, 8, nl.par_dim(128), 2, 512), buffer=nl.sbuf)

# - (4, 8): (B) block dimensions
# - 128: (P) partition dimension
# - (2, 512): (F) free dimensions

As explained in the Direct Allocation Guide, an NKI tensor has three types of dimensions: (B, P, F). The partition dimension maps to the partition dimension of the physical memory, and the free dimensions describe how data is organized within each SBUF/PSUM partition. The block dimensions describe how many physical (P, F) tiles the tensor has.
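As a quick sanity check (a back-of-the-envelope sketch only, not NKI code), the example tensor a above consists of 4 × 8 = 32 logical (P, F) tiles, each spanning 128 partitions, and, assuming float32 elements as in the kernel below, each occupying 4096 bytes per partition:

num_tiles = 4 * 8              # B = (4, 8): 32 logical (P, F) tiles
tile_partitions = 128          # P: each tile spans 128 partitions
tile_free_bytes = 2 * 512 * 4  # F = (2, 512) float32 elements: 4096 bytes per partition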

The block dimension of a tensor is a logical dimension and a pure software concept. The compiler analyzes the memory dependencies and allocates a physical address to each tile. This means that the physical tiles may not be alive in memory simultaneously, and in most cases they are not. Consider the following code snippet that accesses the tensor a.

@nki.jit
def exp_func(inp):
  output = nl.ndarray((4, 8, 128, 2, 512), dtype=nl.float32,
    buffer=nl.shared_hbm)
  a = nl.ndarray((4, 8, nl.par_dim(128), 2, 512), dtype=nl.float32, buffer=nl.sbuf)
  for i in range(4):
    for j in range(8):
      a[i, j] = nl.load(inp[i, j])
      a[i, j] = nl.exp(a[i, j])
      nl.store(output[i, j], value=a[i, j])
  return output

At the very minimum, only one physical tile of a needs to be alive at a time, in which case execution is completely serialized. Essentially, all physical tiles would share the exact same memory address.

Physical Address Map

a[0, 0] --> Partition 0 - 128, Free 0 - 4096B
a[0, 1] --> Partition 0 - 128, Free 0 - 4096B
...
Instead, the compiler could choose to allocate 2 physical tiles to a, so that the DMA copy from HBM to SBUF can overlap with the exponential operation. In other words, the block dimension allows the compiler to perform a space-time tradeoff at its discretion.

Physical Address Map

a[0, 0] --> Partition 0 - 128, Free 0    - 4096B
a[0, 1] --> Partition 0 - 128, Free 4096 - 8192B
a[0, 2] --> Partition 0 - 128, Free 0    - 4096B
a[0, 3] --> Partition 0 - 128, Free 4096 - 8192B
...

When performing the migration, it is important to understand the dependency relationship between blocks and choose the correct migration method accordingly.

Migration for SBUF tensors#

If blocks need to be alive at the same time, move the block dimension into the free dimension#

a = nl.ndarray((8, nl.par_dim(128), 512), buffer=nl.sbuf, dtype=nl.bfloat16)

# ----> Migrate to
a = nl.ndarray((128, 8, 512), buffer=nl.sbuf, dtype=nl.bfloat16)

As an example, all 8 blocks of add_buf need to be alive at the same time when the first for loop finishes. Therefore, the block dimension needs to be folded into the free dimension.

@nki.jit
def sb_blocks(inp):
    res = nl.ndarray(shape=(8, 128, 512), dtype=inp.dtype, buffer=nl.shared_hbm)
    add_buf = nl.ndarray(shape=(8, nl.par_dim(128), 512), dtype=inp.dtype, buffer=nl.sbuf)
    for i in range(8):
        add_buf[i] = nl.load(inp[i])
    for i in range(8):
        nl.store(res[i], add_buf[i])
    return res

# should migrate to
@nki.jit
def sb_blocks_migrated(inp):
    res = nl.ndarray(shape=(8, 128, 512), dtype=inp.dtype, buffer=nl.shared_hbm)
    add_buf = nl.ndarray(shape=(128, 8, 512), dtype=inp.dtype, buffer=nl.sbuf)
    for i in range(8):
        add_buf[0:128, i, 0:512] = nl.load(inp[i])
    for i in range(8):
        nl.store(res[i], add_buf[0:128, i, 0:512])
    return res
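The same fold applies when there is more than one block axis: each block axis becomes its own free dimension. Here is a minimal sketch, assuming a hypothetical HBM input inp of shape (4, 8, 128, 512).

# Original: a = nl.ndarray((4, 8, nl.par_dim(128), 512), dtype=inp.dtype, buffer=nl.sbuf)
a = nl.ndarray((128, 4, 8, 512), dtype=inp.dtype, buffer=nl.sbuf)
for i in range(4):
    for j in range(8):
        a[0:128, i, j, 0:512] = nl.load(inp[i, j])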

If blocks do not need to be alive at the same time, remove the block dimension and hoist the tensor down into the loop#

a = nl.ndarray((8, nl.par_dim(128), 256), buffer=nl.sbuf)
for i in nl.affine_range(8):
  <do something with a[i]>

# should be transformed to ....
for i in nl.affine_range(8):
  a = nl.ndarray((128, 256), buffer=nl.sbuf)
  <do something with a>

As an example, the 8 blocks of add_buf do not need to be alive at the same time. We can remove the block dimension and hoist the tensor declaration down into the loop.

@nki.jit
def sb_blocks(inp):
    res = nl.ndarray(shape=(8, 128, 512), dtype=inp.dtype, buffer=nl.shared_hbm)
    add_buf = nl.ndarray(shape=(8, nl.par_dim(128), 512), dtype=inp.dtype, buffer=nl.sbuf)
    for i in range(8):
        add_buf[i] = nl.load(inp[i])
        nl.store(res[i], add_buf[i])
    return res

# should migrate to
@nki.jit
def sb_blocks_migrated(inp):
    res = nl.ndarray(shape=(8, 128, 512), dtype=inp.dtype, buffer=nl.shared_hbm)
    for i in range(8):
        add_buf = nl.ndarray(shape=(128, 512), dtype=inp.dtype, buffer=nl.sbuf)
        add_buf[0:128, 0:512] = nl.load(inp[i])
        nl.store(res[i], add_buf[0:128, 0:512])
    return res

Warning

To preserve performance, it is important to hoist the tensor declaration down into the loop body.

Note that the dependency relationship between loop iterations is different in sb_blocks_migrated and in the following sb_blocks_migrated_incorrect.

@nki.jit
def sb_blocks_migrated_incorrect(inp):
    res = nl.ndarray(shape=(8, 128, 512), dtype=inp.dtype, buffer=nl.shared_hbm)
    add_buf = nl.ndarray(shape=(128, 512), dtype=inp.dtype, buffer=nl.sbuf)
    for i in range(8):
        add_buf[0:128, 0:512] = nl.load(inp[i])
        nl.store(res[i], add_buf[0:128, 0:512])
    return res

In sb_blocks_migrated, the compiler can unroll the loop and materialize multiple copies of the tensor add_buf. However, in sb_blocks_migrated_incorrect, execution is serialized because of the loop-carried dependency on add_buf.
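Conceptually, the unrolled schedule the compiler may produce for sb_blocks_migrated looks like the following sketch (illustrative only, not code you need to write): two independent copies of add_buf are materialized, so the load of one iteration can overlap with the store of the other.

for i in range(0, 8, 2):
    # two independent copies of add_buf for two consecutive iterations
    add_buf_even = nl.ndarray(shape=(128, 512), dtype=inp.dtype, buffer=nl.sbuf)
    add_buf_even[0:128, 0:512] = nl.load(inp[i])
    nl.store(res[i], add_buf_even[0:128, 0:512])

    add_buf_odd = nl.ndarray(shape=(128, 512), dtype=inp.dtype, buffer=nl.sbuf)
    add_buf_odd[0:128, 0:512] = nl.load(inp[i + 1])
    nl.store(res[i + 1], add_buf_odd[0:128, 0:512])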

Migration for PSUM tensors#

Note

To be filled in; backend support for removing block dimensions in PSUM tensors is still in progress.

Migration of direct allocation & multi-buffering#

Note

For more information on the direct allocation API, please refer to the Direct Allocation Guide.

With block dimensions, we allocate interleaved addresses for the blocks to achieve multi-buffering.

def interleave_alloc_func(idx, pdim_size, fdim_size):
  """
  This function assumes a 1-D block dimension, and allocates a unique
  address by modulo of 2.

  For a tensor of 4 blocks, blocks 0 and 2 will have the same address, while
  blocks 1 and 3 will share a different address from that of blocks 0 and 2.
  """
  # unpack the tuple
  idx, = idx

  # hard-code to partition 0, since each tile takes up 128 partitions
  start_partition = 0

  return (start_partition, (idx % 2) * fdim_size)

@nki.jit
def copy_func(inp):
  output = nl.ndarray((4, 128, 512), dtype=nl.float32, buffer=nl.shared_hbm)
  a = nl.ndarray((4, nl.par_dim(128), 512), dtype=nl.float32, buffer=ncc.sbuf.alloc(interleave_alloc_func))
  for i in range(4):
    a[i] = nl.load(inp[i])
    nl.store(output[i], value=a[i])
  return output

After removing the block dimension, we can write the following to implement the same multi-buffering, which is more natural and closer to how it would be written on a CPU.

def interleave_alloc_func(idx, pdim_size, fdim_size):
  """
  This function assumes the tensor has no block dimension: the whole tensor
  is allocated once, starting at free offset 0. Multi-buffering is achieved
  by indexing the extra free dimension of size 2 with `i % 2` in the kernel
  below, so consecutive iterations use alternating halves of the allocation.
  """
  # no block dimension, so the index tuple is empty
  assert idx == ()

  # hard-code to partition 0, since the tensor takes up 128 partitions
  start_partition = 0

  return (start_partition, 0)

@nki.compiler.skip_middle_end_transformations
@nki.jit
def exp_func(inp):
  output = nl.ndarray((4, 128, 512), dtype=nl.float32, buffer=nl.shared_hbm)
  a = nl.ndarray((128, 2, 512), dtype=nl.float32, buffer=ncc.sbuf.alloc(interleave_alloc_func))
  for i in range(4):
    a[0:128, i % 2, 0:512] = nl.load(inp[i])
    nl.store(output[i], value=a[0:128, i % 2, 0:512])
  return output
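Because the whole double-buffered tensor is allocated up front, even iterations use the first half of the extra free dimension and odd iterations use the second half, so the load for one iteration can overlap with the store of the previous one. The same pattern extends to deeper buffering; a minimal sketch under the same assumptions (same shapes and allocation function as above) simply sizes the extra free dimension to N and indexes it with i % N:

N = 3  # buffering depth
a = nl.ndarray((128, N, 512), dtype=nl.float32, buffer=ncc.sbuf.alloc(interleave_alloc_func))
for i in range(4):
  a[0:128, i % N, 0:512] = nl.load(inp[i])
  nl.store(output[i], value=a[0:128, i % N, 0:512])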

This document is relevant for: Inf2, Trn1, Trn2