This document is relevant for: Inf2, Trn1, Trn1n

NKI Programming Model

The NKI programming model enables developers to create custom kernels to program NeuronCores, where every kernel consists of three main stages:

  1. Loading of inputs from device memory (High Bandwidth Memory, or HBM) to the on-chip SRAM (State Buffer, or SBUF).

  2. Computation definition, to be executed on the NeuronCore compute engines.

  3. Storing of outputs from on-chip SRAM (SBUF) back to device memory (HBM).

Fig. 8 below is a simplified diagram of a NeuronCore along with its attached HBM device memory. NKI kernels currently target a single NeuronCore-v2.

As shown in Fig. 8, a single NeuronCore consists of two on-chip SRAMs (SBUF and PSUM) and four heterogeneous compute engines: the tensor engine, vector engine, scalar engine, and GPSIMD engine. For more information about the compute engine capabilities, see NeuronCore-v2 Architecture. Next, let's dive into the memory hierarchy design of NeuronCore-v2, which provides the architecture knowledge needed to understand the NKI programming model.

Fig. 8 NeuronCore Architecture (multiple NeuronCores available per NeuronDevice)

Memory hierarchy

Fig. 9 below shows the four-level memory hierarchy available to a single NeuronCore. The ranges provided in the figure are intended to calibrate the programmer’s mental model. See Neuron Architecture for the exact values.

As in the memory hierarchies of other devices, memories near the top of the hierarchy are closest to the compute engines, so they are designed to provide the highest bandwidth and lowest latency. However, these faster memories have smaller capacities than the memories near the bottom. Unlike the memory hierarchy of traditional processors (e.g., CPUs and GPUs), all the memories available to a NeuronCore are software-managed: they are managed either directly by the programmer or by the Neuron SDK. In other words, a NeuronCore has no hardware cache system that moves data across memories transparently to the program. Next, let's discuss the different memories from the bottom up.

Fig. 9 NeuronCore Memory Hierarchy with Capacity and Bandwidth Ranges

NeuronCore external memory

The two memories at the bottom of the hierarchy, host memory and device memory, are both considered external memory for a NeuronCore. Both are linear memories, so multi-dimensional tensors must be stored in them in a flattened layout.

The host memory is the CPU-attached DRAM, which is accessible by the host CPUs and all the NeuronCores attached to the instance. NKI kernels currently do not provide APIs to move data in and out of the host memory directly, but we can rely on ML frameworks such as PyTorch or JAX to send input data from host memory into NeuronDevice and vice versa. For an example of this, see Getting Started with NKI.

The device memory resides within a NeuronDevice and uses High Bandwidth Memory (HBM) technologies starting from NeuronDevice v2. This means that device memory and HBM refer to the same thing within NKI. Currently, the input and output parameters to NKI kernels must be HBM tensor references. Input tensors in HBM must be loaded into memory within a NeuronCore before any computation can take place.

NeuronCore internal memory

The two memories at the top of the hierarchy, SBUF and PSUM, are both considered internal, on-chip memory for a NeuronCore. Both are two-dimensional memories organized in 128 partitions. The partition size of PSUM is typically much smaller than that of SBUF, and the partition sizes of both vary across NeuronCore generations.

State Buffer (SBUF) memory is the main software-managed on-chip SRAM. The SBUF is accessible by all the compute engines within a NeuronCore. NKI kernel input tensors from HBM must be loaded into the SBUF for computation using nki.language.load, and computed output tensors of the kernel must be stored back into the HBM from SBUF using nki.language.store before the host can access them. In addition, SBUF is used for storing intermediate data generated by the compute engines within the kernel. Note that SBUF has roughly 20x higher bandwidth than HBM, so it should be managed carefully to minimize HBM accesses for better performance.

Lastly, Partial Sum Buffer (PSUM) memory is a small, dedicated memory designed for storing matrix multiplication (MatMult) results computed by the tensor engine. The tensor engine can read-add-write to every address in PSUM. Therefore, PSUM is useful for performing large MatMult calculations using multiple tiles, where multiple MatMult instructions need to accumulate into the same output tile. As shown in Fig. 8, PSUM memory can also be read and written by the vector and scalar engines. However, due to the limited capacity of PSUM, we recommend that you reserve PSUM space for the tensor engine to write MatMult outputs, and use the vector and scalar engines to evict MatMult results back to SBUF as soon as possible.
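To make the accumulation idea concrete, here is a plain-Python sketch. It is not NKI code and does not model PSUM capacity or timing. The contraction dimension K is split into tiles, and each per-tile partial product is added with += into one output buffer, which stands in for the PSUM output tile. The tile size of 128 and the name `matmul_accumulate` are assumptions made for illustration.

```python
# Plain-Python sketch (not NKI code) of the PSUM accumulation pattern: a matmul
# whose contraction dimension K spans several tiles is computed as a sum of
# per-tile partial products, all added into the same output tile.

K_TILE = 128  # assumed K-tile size for illustration


def matmul_accumulate(a, b, k_tile=K_TILE):
    """Compute a @ b for a: [M, K] and b: [K, N] by accumulating K-tiles."""
    m, k, n = len(a), len(b), len(b[0])
    assert all(len(row) == k for row in a), "inner dimensions must match"
    acc = [[0.0] * n for _ in range(m)]  # stands in for the PSUM output tile
    for k0 in range(0, k, k_tile):
        k1 = min(k0 + k_tile, k)
        # One partial matmul over the K-slice [k0, k1), accumulated with +=
        # (modeling the read-add-write the tensor engine performs on PSUM).
        for i in range(m):
            for j in range(n):
                acc[i][j] += sum(a[i][kk] * b[kk][j] for kk in range(k0, k1))
    return acc


if __name__ == "__main__":
    a = [[1.0] * 300 for _ in range(2)]        # [M=2, K=300]: three K-tiles
    b = [[float(j)] * 3 for j in range(300)]   # [K=300, N=3]
    print(matmul_accumulate(a, b)[0])          # [44850.0, 44850.0, 44850.0]
```

The result is the same as a single untiled product. Only the number of partial products that feed the shared accumulator changes.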

Note that to optimize kernel performance, it is a good practice for NKI programmers to be mindful of SBUF and PSUM usage through careful tiling and loop fusion. However, ultimately the Neuron compiler performs memory allocation for SBUF and PSUM and assigns memory addresses to kernel intermediate data. When the cumulative size of live data defined by the NKI kernel overflows the capacity of any on-chip memory, the Neuron compiler inserts the necessary spills or refills between that memory and the next-tier memory in the hierarchy.

Representing data in NKI

NKI represents data in NeuronCore’s memory hierarchy with built-in type Tensor and its subclasses.

A Tensor is a multi-dimensional array whose elements all have the same data type. Programmers can pass Tensors in and out of NKI kernels, and declare or initialize Tensors in any memory within the NeuronDevice (PSUM, SBUF, HBM) using APIs such as nki.language.ndarray, nki.language.zeros, and nki.language.full. Input and output tensors passed from ML frameworks to NKI kernels can be reinterpreted as NKI Tensors of the hbm buffer type, backed by the same underlying memory buffer.

Tensors in NeuronCore's internal memories (SBUF and PSUM) also have a dimension mapped to the partitions of those memories. We call this dimension the partition dimension. By default, NKI infers the first (that is, the leftmost) dimension as the partition dimension of a Tensor. Users can also explicitly annotate the partition dimension with par_dim from nki.language. For example:

import neuronxcc.nki.language as nl

# NKI infers the leftmost dimension as the partition dimension (size 128 below)
x = nl.ndarray((128, 32, 512), dtype=nl.float32, buffer=nl.sbuf)

# Same as above but more verbose
y = nl.ndarray((nl.par_dim(128), 32, 512), dtype=nl.float32, buffer=nl.sbuf)

# We can also explicitly annotate the partition dimension if we want it to be
# a dimension other than the leftmost one. The following code creates a tensor
# whose partition dimension is the second dimension from the left
z = nl.ndarray((128, nl.par_dim(32), 512), dtype=nl.float32, buffer=nl.sbuf)

There is a special subclass of Tensor called Index. An Index represents the result of an affine expression over variables produced by index-generating APIs, such as loop variables, nki.language.program_id, nki.language.affine_range, and nki.language.arange.

A Tensor whose partition dimension is the first dimension is also called a Tile in NKI. In the code example above, x and y are Tiles, but z is not. All NKI APIs take Tiles as input and return Tiles as output. We explain this further in Tile-based operations.

Tile-based operations

All NKI APIs operate on Tiles, which aligns with the NeuronCore instruction set architecture (NeuronCore ISA).

x = nl.ndarray((128, 32, 512), dtype=nl.float32, buffer=nl.sbuf)
xx = nl.exp(x) # works

z = nl.ndarray((128, nl.par_dim(32), 512), dtype=nl.float32, buffer=nl.sbuf)
zz = nl.exp(z) # not supported

To call NKI APIs to process data in a Tensor whose partition dimension is not the first dimension, users need to generate Tiles from the Tensor. This can be done by indexing the Tensor with a tuple of Index, following standard Python syntax Tensor[Index, Index, ...]. For example:

z = nl.ndarray((128, nl.par_dim(32), 512), dtype=nl.float32, buffer=nl.sbuf)
for i in range(128):
  zz = nl.exp(z[i, :, :]) # works

Tensor Indexing discusses indexing in more detail. Next, let's discuss two important considerations when working with tile-based operations in NKI: data layout and tile-size constraints.

Layout considerations

When working with multi-dimensional arrays on any platform, it is important to consider the physical memory layout of the arrays, that is, how the data is stored in memory. For example, in 1D linear memory, we can store a 2D array in a row-major layout or a column-major layout. Row-major layouts place the elements of each row in contiguous memory, and column-major layouts place the elements of each column in contiguous memory.
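The two layouts can be compared in a few lines of plain Python. This illustrates only the general concept and is not NKI code. The helper names are hypothetical.

```python
# Plain-Python illustration (not NKI code) of row-major vs column-major layouts
# of a 2D array in 1D linear memory.

def row_major_offset(row, col, n_rows, n_cols):
    """Elements of each row are contiguous."""
    return row * n_cols + col


def column_major_offset(row, col, n_rows, n_cols):
    """Elements of each column are contiguous."""
    return col * n_rows + row


def flatten(matrix, order):
    """Return the 1D memory image of `matrix`; order is 'row' or 'col'."""
    n_rows, n_cols = len(matrix), len(matrix[0])
    offset = row_major_offset if order == "row" else column_major_offset
    memory = [None] * (n_rows * n_cols)
    for r in range(n_rows):
        for c in range(n_cols):
            memory[offset(r, c, n_rows, n_cols)] = matrix[r][c]
    return memory


if __name__ == "__main__":
    m = [["a0", "a1", "a2"],
         ["b0", "b1", "b2"]]
    print(flatten(m, "row"))  # ['a0', 'a1', 'a2', 'b0', 'b1', 'b2']
    print(flatten(m, "col"))  # ['a0', 'b0', 'a1', 'b1', 'a2', 'b2']
```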

As discussed in the Memory hierarchy section, the on-chip memories, SBUF and PSUM, are arranged as 2D memory arrays. The first dimension is the partition dimension P with 128 memory partitions that can be read and written in parallel by compute engines. The second dimension is the free dimension F where elements are read and written sequentially. A tensor is placed in SBUF and PSUM across both P and F, with the same start offset across all P partitions used by the tensor. Fig. 10 below illustrates a default tensor layout. Note that a tile in NKI must map shape[0] to the partition dimension.

Fig. 10 Tensor mapped to partition and free dimensions of SBUF and PSUM
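The placement described above can be modeled in plain Python. This is a simplification that rests on several assumptions. Axis 0 selects the partition, and the free axes are flattened in row-major order, matching the F1-major flattening used in the transpose example later in this guide. Every partition uses the same start offset. `on_chip_location`, `start_offset`, and `elem_bytes` are hypothetical names. In practice, the Neuron compiler assigns the real addresses.

```python
# Simplified model (an assumption for illustration, not the Neuron compiler's
# allocator) of where an element of a tile [P, F1, F2, ...] lives on chip.
from math import prod

PMAX = 128  # number of SBUF/PSUM partitions


def on_chip_location(index, shape, start_offset, elem_bytes):
    """Return (partition, byte_offset) of element `index` in a tile of `shape`."""
    if len(index) != len(shape):
        raise ValueError("index rank must match tile rank")
    if shape[0] > PMAX:
        raise ValueError(f"partition dimension {shape[0]} exceeds {PMAX}")
    partition, free_index = index[0], index[1:]
    if not 0 <= partition < shape[0]:
        raise IndexError("partition index out of range")
    flat = 0
    for i, size in zip(free_index, shape[1:]):
        if not 0 <= i < size:
            raise IndexError("free index out of range")
        flat = flat * size + i  # row-major: the last free axis is contiguous
    # Same start offset in every partition used by the tile.
    return partition, start_offset + flat * elem_bytes


def bytes_per_partition(shape, elem_bytes):
    """Free-dimension footprint of the tile in each partition it occupies."""
    return prod(shape[1:]) * elem_bytes


if __name__ == "__main__":
    shape = (128, 32, 512)  # float32 tile, 4 bytes per element
    print(on_chip_location((5, 1, 2), shape, start_offset=0, elem_bytes=4))  # (5, 2056)
    print(bytes_per_partition(shape, elem_bytes=4))                          # 65536
```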

Similar to other domain-specific languages that operate on tensors, NKI defines a contraction axis of a tensor as the axis over which reduction is performed, for example the summation axis in a dot product. NKI also defines a parallel axis as an axis over which the same operation is performed on all elements. For example, if we take a [100, 200] matrix and sum each row independently to get an output of shape [100, 1], then the row-axis (axis[0], left-most) is the parallel axis, and the column-axis (axis[1], right-most) is the contraction axis.
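Writing the row-sum example out in plain Python makes it clear which axis plays which role. This is an illustration only, not NKI code. A small 3 x 4 matrix stands in for the [100, 200] one.

```python
# Plain-Python illustration of the row-sum example: axis 0 (rows) is the
# parallel axis and axis 1 (columns) is the contraction axis.

def row_sums(matrix):
    """Reduce [R, C] -> [R, 1] by summing along the contraction axis (axis 1)."""
    # Each output row depends only on its own input row, so rows are independent
    # and could be processed in parallel (for example, one row per partition).
    return [[sum(row)] for row in matrix]


if __name__ == "__main__":
    m = [[r * 10 + c for c in range(4)] for r in range(3)]
    print(row_sums(m))  # [[6], [46], [86]]
```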

To summarize, the partition and free dimensions of an NKI tensor dictate how the tensor is physically stored in the 2D on-chip memories, while the parallel and contraction axes of a tensor are logical axes determined by the computation performed on the tensor.

The NeuronCore compute engines impose two layout constraints:

  • [LC#1] For matrix multiplication operations, the contraction axis of both input tiles must be mapped to the P dimension.

  • [LC#2] For operations that are not matrix multiplication operations, such as scalar or vector operations, the parallel axis should be mapped to the P dimension.

LC#1 means that to perform a matrix multiplication of shapes [M, K] and [K, N], Tensor Engine (the engine performing this operation) requires the K dimension to be mapped to the partition dimension in SBUF for both input matrices. Therefore, you need to pass shapes [K, M] and [K, N] into the nki.isa.nc_matmul API, as the partition dimension is always the left-most dimension for an input tile to any NKI compute API.
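As a plain-Python reference (this is not the NKI API), the function below computes the [M, N] product from a [K, M] left-hand operand and a [K, N] right-hand operand. In both operands, the contraction axis K is on axis 0, as LC#1 requires. The names `matmul_from_transposed_lhs` and `transpose` are hypothetical.

```python
# Plain-Python reference (not NKI code) for the LC#1 operand convention.

def matmul_from_transposed_lhs(lhs_t, rhs):
    """Return lhs_t^T @ rhs, where lhs_t: [K, M] and rhs: [K, N] -> [M, N]."""
    k, m, n = len(lhs_t), len(lhs_t[0]), len(rhs[0])
    assert len(rhs) == k, "both operands must have K on axis 0 (the partition axis)"
    # out[i][j] = sum over the contraction axis of lhs_t[kk][i] * rhs[kk][j]
    return [[sum(lhs_t[kk][i] * rhs[kk][j] for kk in range(k)) for j in range(n)]
            for i in range(m)]


def transpose(x):
    """Swap the two axes of a 2D nested list."""
    return [list(col) for col in zip(*x)]


if __name__ == "__main__":
    a = [[1, 2, 3],
         [4, 5, 6]]      # [M=2, K=3]
    b = [[7, 8],
         [9, 10],
         [11, 12]]       # [K=3, N=2]
    # A high-level matmul taking [M, K] x [K, N] would do this transpose for you.
    print(matmul_from_transposed_lhs(transpose(a), b))  # [[58, 64], [139, 154]]
```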

To help developers get started with NKI quickly, NKI also provides a high-level API nki.language.matmul that can take [M, K] and [K, N] input shapes and invoke the necessary layout shuffling on the input data before sending it to the Tensor Engine matmul instruction.

LC#2, on the other hand, is applicable to many instructions supported on Vector, Scalar and GpSimd Engines. See nki.isa.tensor_reduce API as an example.

Tile size considerations

Besides layout constraints, NeuronCore hardware further imposes three tile-size constraints in NKI:

  • [TC#1] The P dimension size of a tile in both SBUF and PSUM must never exceed nki.language.tile_size.pmax == 128.

  • [TC#2] For tiles in PSUM, the F dimension size must not exceed nki.language.tile_size.psum_fmax == 512.

  • [TC#3] The F dimension size of a matrix multiplication input tile must not exceed nki.language.tile_size.gemm_stationary_fmax == 128 on the left-hand side (LHS), or nki.language.tile_size.gemm_moving_fmax == 512 on the right-hand side (RHS).
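The three constraints can be encoded in a small checker. `check_tile` is a hypothetical helper written only for illustration and is not part of NKI. The NKI compiler performs its own checks. The limits below are the values quoted above, and tile shapes are given as (P, F).

```python
# Hypothetical pre-flight checker (not part of NKI) encoding TC#1-TC#3.

PMAX = 128                  # nki.language.tile_size.pmax
PSUM_FMAX = 512             # nki.language.tile_size.psum_fmax
GEMM_STATIONARY_FMAX = 128  # nki.language.tile_size.gemm_stationary_fmax (LHS)
GEMM_MOVING_FMAX = 512      # nki.language.tile_size.gemm_moving_fmax (RHS)


def check_tile(shape, buffer="sbuf", matmul_operand=None):
    """Return a list of violated constraints for a 2D tile of shape (P, F)."""
    p, f = shape
    errors = []
    if p > PMAX:
        errors.append(f"TC#1: partition size {p} exceeds {PMAX}")
    if buffer == "psum" and f > PSUM_FMAX:
        errors.append(f"TC#2: PSUM free size {f} exceeds {PSUM_FMAX}")
    if matmul_operand == "lhs" and f > GEMM_STATIONARY_FMAX:
        errors.append(f"TC#3: LHS free size {f} exceeds {GEMM_STATIONARY_FMAX}")
    if matmul_operand == "rhs" and f > GEMM_MOVING_FMAX:
        errors.append(f"TC#3: RHS free size {f} exceeds {GEMM_MOVING_FMAX}")
    return errors


if __name__ == "__main__":
    print(check_tile((128, 512)))                        # []
    print(check_tile((256, 512)))                        # TC#1 violation
    print(check_tile((128, 1024), buffer="psum"))        # TC#2 violation
    print(check_tile((128, 256), matmul_operand="lhs"))  # TC#3 violation
```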

You are responsible for breaking your tensors into tiles that satisfy these tile-size constraints. If the constraints are not met, NKI kernel compilation throws a SyntaxError indicating which constraint is violated. For example, below we show a simple kernel that applies the exponential function to every element of an input tensor. To start, let's write a kernel that expects a hard-coded shape of (128, 512) for both the input and output tensors:

import neuronxcc.nki.language as nl
from torch_neuronx import nki_jit

@nki_jit
def tensor_exp_kernel_(in_tensor, out_tensor):
  """NKI kernel to compute elementwise exponential of an input tensor

  Args:
      in_tensor: an input tensor of shape [128,512]
      out_tensor: an output tensor of shape [128,512]
  """
  # Generate indices for the input/output tensors
  i_p = nl.arange(128)[:, None]
  i_f = nl.arange(512)[None, :]

  # Load input data from HBM to on-chip memory
  in_tile = nl.load(in_tensor[i_p, i_f])

  # perform the computation:
  out_tile = nl.exp(in_tile)

  # store the results back to HBM
  nl.store(out_tensor[i_p, i_f], value=out_tile)


if __name__ == "__main__":
  import torch
  from torch_xla.core import xla_model as xm

  device = xm.xla_device()

  shape = (128, 512)
  in_tensor = torch.ones(shape, dtype=torch.bfloat16).to(device=device)
  out_tensor = torch.zeros(shape, dtype=torch.bfloat16).to(device=device)
  tensor_exp_kernel_(in_tensor, out_tensor)

  print(out_tensor) # an implicit XLA barrier/mark-step

As expected, the output tensor is an element-wise exponentiation of the input tensor (a tensor of ones):

tensor([[2.7188, 2.7188, 2.7188, ..., 2.7188, 2.7188, 2.7188],
        [2.7188, 2.7188, 2.7188, ..., 2.7188, 2.7188, 2.7188],
        [2.7188, 2.7188, 2.7188, ..., 2.7188, 2.7188, 2.7188],
        ...,
        [2.7188, 2.7188, 2.7188, ..., 2.7188, 2.7188, 2.7188],
        [2.7188, 2.7188, 2.7188, ..., 2.7188, 2.7188, 2.7188],
        [2.7188, 2.7188, 2.7188, ..., 2.7188, 2.7188, 2.7188]],
        device='xla:1', dtype=torch.bfloat16)

Now let’s examine what happens if the input/output tensor shapes do not match the shape of the compute kernel. As an example, we can change the input and output tensor shape from [128,512] to [256,512]:

if __name__ == "__main__":
  import torch
  from torch_xla.core import xla_model as xm

  device = xm.xla_device()

  shape = (256, 512) # Previously (128, 512)
  in_tensor = torch.ones(shape, dtype=torch.bfloat16).to(device=device)
  out_tensor = torch.zeros(shape, dtype=torch.bfloat16).to(device=device)
  tensor_exp_kernel_(in_tensor, out_tensor)

  print(out_tensor) # an implicit XLA barrier/mark-step

Since the compute kernel expects (128, 512) input/output tensors but we passed (256, 512) tensors instead, the bottom half of the output tensor contains garbage data:

tensor([[2.7188, 2.7188, 2.7188, ..., 2.7188, 2.7188, 2.7188],
        [2.7188, 2.7188, 2.7188, ..., 2.7188, 2.7188, 2.7188],
        [2.7188, 2.7188, 2.7188, ..., 2.7188, 2.7188, 2.7188],
        ...,
        [0.5273, 0.6055, 0.4336, ..., 0.9648, 0.9414, 0.4062],
        [0.7109, 0.2539, 0.7227, ..., 0.7344, 0.2539, 0.1211],
        [0.8867, 0.2109, 0.8789, ..., 0.8477, 0.2227, 0.1406]],
        device='xla:1', dtype=torch.bfloat16)

We could try to fix this by changing the tile size inside the compute kernel to (256, 512) as well and see what happens. (Note that this violates tile-size constraint TC#1.)

import neuronxcc.nki.language as nl
from torch_neuronx import nki_jit

@nki_jit
def tensor_exp_kernel_(in_tensor, out_tensor):
  """NKI kernel to compute elementwise exponential of an input tensor

  Args:
      in_tensor: an input tensor of shape [256,512]
      out_tensor: an output tensor of shape [256,512]
  """
  # Generate indices for the input/output tensors
  i_p = nl.arange(256)[:, None] # Previously nl.arange(128)
  i_f = nl.arange(512)[None, :]

  # Load input data from HBM to on-chip memory
  in_tile = nl.load(in_tensor[i_p, i_f])

  # perform the computation:
  out_tile = nl.exp(in_tile)

  # store the results back to HBM
  nl.store(out_tensor[i_p, i_f], value=out_tile)


if __name__ == "__main__":
  import torch
  from torch_xla.core import xla_model as xm

  device = xm.xla_device()

  shape = (256, 512) # Previously (128, 512)
  in_tensor = torch.ones(shape, dtype=torch.bfloat16).to(device=device)
  out_tensor = torch.zeros(shape, dtype=torch.bfloat16).to(device=device)
  tensor_exp_kernel_(in_tensor, out_tensor)

  print(out_tensor) # an implicit XLA barrier/mark-step

Here, the Neuron compiler identifies the tile-size constraint violation and fails compilation with the following exception:

SyntaxError: Size of partition dimension 256 exceeds architecture limitation of 128.

Now, let’s see how NKI developers can build a kernel that properly handles (256, 512) input/output tensors with a simple loop. We can use the nki.language.tile_size.pmax constant defined in NKI as the maximum partition dimension size in a tile.

import neuronxcc.nki.language as nl
from torch_neuronx import nki_jit

@nki_jit
def tensor_exp_kernel_(in_tensor, out_tensor):
  """NKI kernel to compute elementwise exponential of an input tensor

  Args:
      in_tensor: an input tensor of shape [256,512]
      out_tensor: an output tensor of shape [256,512]
  """
  i_f = nl.arange(512)[None, :]

  for k in nl.affine_range(2):
    # Generate tensor indices for the input/output tensors
    i_p = k * nl.tile_size.pmax + nl.arange(nl.tile_size.pmax)[:, None]

    # Load input data from HBM to on-chip memory
    in_tile = nl.load(in_tensor[i_p, i_f])

    # perform the computation
    out_tile = nl.exp(in_tile)

    # store the results back to HBM
    nl.store(out_tensor[i_p, i_f], value=out_tile)

The nl.affine_range(2) call iterates over the integers [0, 1]. nl.affine_range should be the default loop iterator in NKI when the loop has no loop-carried dependency. Note that associative reductions are not considered loop-carried dependencies in this context. One example is accumulating the results of multiple matrix multiplication calls into the same output buffer using += (see Matmul Tutorial for an example). Otherwise, use nl.sequential_range to handle loop-carried dependencies. Note also that the Neuron compiler transforms any use of the Python range() API into nl.sequential_range() under the hood. See NKI iterator API for a detailed discussion of the loop iterator options in NKI.
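The distinction can be illustrated in plain Python. This does not show how NKI schedules loops. Iterations without a loop-carried dependency give the same result in any order, and so does an associative += reduction (mathematically; floating-point rounding may differ). A prefix sum, where each iteration reads the previous iteration's output, breaks when reordered. All function names here are hypothetical.

```python
# Plain-Python illustration (not NKI's scheduler) of when iterations can be reordered.
import math


def tiled_exp(rows, tile=2, order=None):
    """No loop-carried dependency: each iteration touches only its own rows."""
    out = [None] * len(rows)
    n_tiles = math.ceil(len(rows) / tile)
    for k in (order if order is not None else range(n_tiles)):
        for i in range(k * tile, min((k + 1) * tile, len(rows))):
            out[i] = math.exp(rows[i])
    return out


def total(values, order=None):
    """Associative reduction (+=): visiting order does not change the sum
    (mathematically; floating-point rounding may differ)."""
    acc = 0
    for i in (order if order is not None else range(len(values))):
        acc += values[i]
    return acc


def running_sum(values, order=None):
    """Loop-carried dependency: iteration i reads what iteration i-1 wrote."""
    out = [0] * len(values)
    for i in (order if order is not None else range(len(values))):
        out[i] = values[i] + (out[i - 1] if i > 0 else 0)
    return out


if __name__ == "__main__":
    rows = [0.0, 1.0, 2.0, 3.0]
    print(tiled_exp(rows) == tiled_exp(rows, order=[1, 0]))     # True
    print(total([1, 2, 3]), total([1, 2, 3], order=[2, 0, 1]))  # 6 6
    print(running_sum([1, 2, 3]))                   # [1, 3, 6]
    print(running_sum([1, 2, 3], order=[2, 1, 0]))  # [1, 2, 3]: needs ordering
```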

While the code above does handle (256, 512) tensors correctly, it is rather inflexible, since it supports only an input shape of (256, 512). Therefore, as a last step, we extend this kernel to handle varying input/output sizes:

import neuronxcc.nki.language as nl
from torch_neuronx import nki_jit
import math

@nki_jit
def tensor_exp_kernel_(in_tensor, out_tensor):
  """NKI kernel to compute elementwise exponential of an input tensor

  Args:
      in_tensor: an input tensor of ANY 2D shape (up to SBUF size)
      out_tensor: an output tensor of ANY 2D shape (up to SBUF size)
  """
  sz_p, sz_f = in_tensor.shape

  i_f = nl.arange(sz_f)[None, :]

  for p in nl.affine_range(math.ceil(sz_p / nl.tile_size.pmax)):
    # Generate tensor indices for the input/output tensors
    # pad index to pmax, for simplicity
    i_p = p * nl.tile_size.pmax + nl.arange(nl.tile_size.pmax)[:, None]

    # Load input data from external memory to on-chip memory
    # only read up to sz_p
    in_tile = nl.load(in_tensor[i_p, i_f], mask=(i_p<sz_p))

    # perform the computation
    out_tile = nl.exp(in_tile)

    # store the results back to external memory
    # only write up to sz_p
    nl.store(out_tensor[i_p, i_f], value=out_tile, mask=(i_p<sz_p))

The above example handles cases where in_tensor.shape[0] is not a multiple of 128 by passing a mask field into the nl.load and nl.store API calls. For more information, refer to NKI API Masking.
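The tiling and masking arithmetic of the kernel above can be traced in plain Python. This is not NKI code, and `tile_rows` is a hypothetical helper. Tile p covers the padded row indices p*128 through p*128+127, and the mask keeps only the rows where i_p < sz_p.

```python
# Plain-Python trace (not NKI code) of the kernel's tiling and masking.
import math

PMAX = 128  # value of nl.tile_size.pmax


def tile_rows(sz_p, pmax=PMAX):
    """Yield (tile_index, active_row_indices) for a tensor with sz_p rows."""
    for p in range(math.ceil(sz_p / pmax)):
        i_p = [p * pmax + r for r in range(pmax)]  # padded index, as in the kernel
        active = [i for i in i_p if i < sz_p]      # rows enabled by the mask
        yield p, active


if __name__ == "__main__":
    for p, active in tile_rows(300):
        print(p, len(active), active[0], active[-1])
    # 0 128 0 127
    # 1 128 128 255
    # 2 44 256 299
```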

Later in this guide, we'll explore another way to launch a kernel with varying input/output shapes: the single program, multiple data (SPMD) programming model. SPMD removes the need for explicit loops over tiles with variable trip counts, which can lead to cleaner and more readable code.

Tensor Indexing

As mentioned above, we can index Tensor with standard Python syntax to produce Tiles. There are two styles of indexing: Basic and Advanced Tensor Indexing. Note that currently NKI does not support mixing Basic and Advanced Tensor Indexing in the same Index tuple.

Basic Tensor Indexing

If we index a Tensor with fewer indices than it has dimensions, we get a view of the original tensor with fewer dimensions. For example:

x = nl.ndarray((2, 2, 2), dtype=nl.float32, buffer=nl.hbm)

# `x[1]` returns a view of x with shape of [2, 2]
# [[x[1, 0, 0], x[1, 0, 1]], [x[1, 1, 0], x[1, 1, 1]]]
assert x[1].shape == [2, 2]

By indexing a Tensor like this, we can generate a Tile with the partition dimension in the first dimension and feed the Tile to NKI compute APIs:

# Not a tile, so it cannot be fed directly to an NKI compute API
x = nl.ndarray((2, nl.par_dim(2), 2), dtype=nl.float32)
# Error
y = nl.exp(x)

# `x[1]` has shape [2, 2], and its first dimension is the partition dimension of the original
# tensor, so we can feed it to an NKI compute API.
y = nl.exp(x[1])

NKI also supports slicing in basic tensor indexing:

x = nl.ndarray((2, 128, 1024), dtype=nl.float32, buffer=nl.hbm)

# `x[1, :, :]` is the same as `x[1]`
assert x[1, :, :].shape == [128, 1024]

# Get a smaller view of the third dimension
assert x[1, :, 0:512].shape == [128, 512]

# `x[:, 1, 0:2]` returns a view of x with shape of [2, 2]
# [[x[0, 1, 0], x[0, 1, 1]], [x[1, 1, 0], x[1, 1, 1]]]
assert x[:, 1, 0:2].shape == [2, 2]

Advanced Tensor Indexing

So far we have only shown basic indexing in tensors. However, NeuronCore offers much more flexible tensorized memory access in its on-chip SRAMs along the free dimension. You can use this to stride through SBUF/PSUM memory efficiently in all NKI APIs that access on-chip memories. Such flexible indexing is not supported along the partition dimension, however. Also keep in mind that device memory (HBM) always performs best when accessed sequentially.

In this section, we share several use cases that benefit from advanced memory access patterns and demonstrate how to implement them in NKI.

Advanced Tensor Indexing in NKI leverages the nl.arange API.

Case #1 - Tensor split to even and odd columns

Here we split an input tensor into two output tensors, where the first output tensor gathers all the even columns from the input tensor, and the second output tensor gathers all the odd columns from the input tensor. We assume the rows of the input tensors are mapped to SBUF partitions. Therefore, we are effectively gathering elements along the free dimension of the input tensor. Fig. 11 below visualizes the input and output tensors.

Fig. 11 Tensor split to even and odd columns

import neuronxcc.nki.language as nl
from torch_neuronx import nki_jit
import math

@nki_jit
def tensor_split_kernel_(in_tensor, out_tensor_even, out_tensor_odd):
  """NKI kernel to split an input tensor into two output tensors, along the column axis.

  The even columns of the input tensor will be gathered into the first output tensor,
  and the odd columns of the input tensor will be gathered into the second output tensor.

  Args:
      in_tensor: an input tensor
      out_tensor_even: a first output tensor (will hold the even columns of the input tensor)
      out_tensor_odd: a second output tensor (will hold the odd columns of the input tensor)
  """

  # Extract tile sizes.
  sz_p, sz_f = in_tensor.shape
  sz_fout_even, sz_fout_odd = out_tensor_even.shape[1], out_tensor_odd.shape[1]

  # We assume that all three tensors have the same partition dimension size
  # and it does not exceed pmax
  assert in_tensor.shape[0] == out_tensor_even.shape[0] == out_tensor_odd.shape[0]
  assert in_tensor.shape[0] <= nl.tile_size.pmax

  # Make sure even/odd output tensors have correct free dimension size
  assert sz_fout_even == math.ceil(sz_f / 2)
  assert sz_fout_odd == math.floor(sz_f / 2)

  # Generate tensor indices for the input/output tensors
  i_p = nl.arange(sz_p)[:, None]
  i_f = nl.arange(sz_f)[None, :]
  i_fout_even = nl.arange(sz_fout_even)[None, :]
  i_fout_odd = nl.arange(sz_fout_odd)[None, :]

  # Split pattern:
  i_f_even = (2 * i_fout_even)
  i_f_odd = (2 * i_fout_odd + 1)

  # Load input data from external memory to on-chip memory
  in_tile = nl.load(in_tensor[i_p, i_f])

  # Perform the split
  # these assignments invoke copy instructions under the hood
  # which can execute on either Scalar or Vector Engine
  # (decided by compiler instruction scheduler)
  out_tile_even = in_tile[i_p, i_f_even]
  out_tile_odd = in_tile[i_p, i_f_odd]

  # Store the results back to external memory
  nl.store(out_tensor_even[i_p, i_fout_even], value=out_tile_even)
  nl.store(out_tensor_odd[i_p, i_fout_odd], value=out_tile_odd)


if __name__ == "__main__":
    import torch
    from torch_xla.core import xla_model as xm

    device = xm.xla_device()

    X, Y = 4, 5
    in_tensor = torch.arange(X * Y, dtype=torch.bfloat16).reshape(X, Y).to(device=device)

    out1_tensor = torch.zeros((X, Y-Y//2), dtype=torch.bfloat16).to(device=device)
    out2_tensor = torch.zeros((X, Y//2), dtype=torch.bfloat16).to(device=device)

    tensor_split_kernel_(in_tensor, out1_tensor, out2_tensor)
    print(in_tensor, out1_tensor, out2_tensor)

The main concept in this example is the introduction of the even (i_f_even) and odd (i_f_odd) indices. Both indices are affine expressions of the form start + stride * nl.arange(size), with a specific start offset (0 and 1, respectively) and stride (2 in both cases). This allows us to stride through the in_tile memory and copy it into both output tiles (out_tile_even and out_tile_odd) according to the desired pattern.
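To see the strided gathers without hardware, here is a plain-Python emulation. It is not NKI code, and `split_even_odd` is a hypothetical helper. Each inner list stands for one partition. The same affine indices are applied along the free dimension, with 2 * i for the even columns and 2 * i + 1 for the odd columns.

```python
# Plain-Python emulation (not NKI code) of the even/odd column split.
import math


def split_even_odd(tile):
    """Return (even_columns, odd_columns) of a 2D nested list."""
    sz_f = len(tile[0])
    sz_fout_even, sz_fout_odd = math.ceil(sz_f / 2), sz_f // 2
    even = [[row[2 * i] for i in range(sz_fout_even)] for row in tile]     # start 0, stride 2
    odd = [[row[2 * i + 1] for i in range(sz_fout_odd)] for row in tile]  # start 1, stride 2
    return even, odd


if __name__ == "__main__":
    X, Y = 4, 5  # same shape as the kernel's __main__ block
    tile = [[r * Y + c for c in range(Y)] for r in range(X)]
    even, odd = split_even_odd(tile)
    print(even[0], odd[0])  # [0, 2, 4] [1, 3]
```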

Case #2 - Transpose tensor along the f axis

In this example we transpose a tensor along two of its axes. Note, there are two main types of transposition in NKI:

  1. Transpose between the partition-dimension axis and one of the free-dimension axes, which is achieved via the nki.isa.nc_transpose API.

  2. Transpose between two free-dimension axes, which is achieved via a nki.language.copy API, with indexing manipulation in the transposed axes to re-arrange the data.

In this example, we’ll focus on the second case: consider a three-dimensional input tensor [P, F1, F2], where the P axis is mapped to the different SBUF partitions and the F1 and F2 axes are flattened and placed in each partition, with F1 being the major dimension. Our goal in this example is to transpose the F1 and F2 axes with a parallel dimension P, which would re-arrange the data within each partition. Fig. 12 below illustrates the input and output tensor layouts.

Fig. 12 Tensor F1:F2 Transpose

import neuronxcc.nki as nki
import neuronxcc.nki.language as nl


def tensor_transpose2D_kernel_(in_tensor, out_tensor, shape2D):
  """
  NKI kernel to reorder the elements on axis[1] of the input tensor.

  Every row of the input tensor is a flattened row-major 2D matrix.
  The shape2D argument defines the dimensions of the flattened matrices (#rows,#cols).
  Our goal in this kernel is to transpose these flattened 2D matrices, i.e. make them (#cols,#rows).

  Example:
      in_tensor = [a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3]
      shape2D = (3,4)
  this means that in_tensor has 3 rows and 4 columns, i.e. can be represented as:
      [a0,a1,a2,a3]
      [b0,b1,b2,b3]
      [c0,c1,c2,c3]
  after transpose, we expect to get:
      [a0,b0,c0]
      [a1,b1,c1]
      [a2,b2,c2]
      [a3,b3,c3]
  Thus, out_tensor is expected to be [a0,b0,c0,a1,b1,c1,a2,b2,c2,a3,b3,c3]

  Args:
    in_tensor: an input tensor
    out_tensor: an output (transposed) tensor
    shape2D: tuple representing the dimensions to be transposed: (#rows, #cols)
  """
  # Gather input shapes
  sz_p, _ = in_tensor.shape

  # Load input data from external memory to on-chip memory
  in_tile = nl.load(in_tensor)

  # Performing f1/f2 transpose
  # ==========================
  # The desired transpose pattern is provided as an input:
  sz_f1, sz_f2 = shape2D

  # We're going to need 3 indices to perform f1:f2 transpose.
  # - i_p0 is the parallel index
  # - i_f1 and i_f2 are both free-dim indices, and will be used to transpose between the f1/f2 axes
  i_p0 = nl.arange(sz_p)[:, None, None]
  i_f1 = nl.arange(sz_f1)[None, :, None]
  i_f2 = nl.arange(sz_f2)[None, None, :]

  # Perform the transposition via an SBUF-to-SBUF copy, with access-pattern manipulation
  # Note that we have 2D tensors and 3 indices, since we need to represent a 2D access pattern *per partition*
  # RHS traverses an F1 x F2 matrix in a row major manner
  # LHS traverses an F2 x F1 (new) matrix in a row major manner
  out_tile = nl.ndarray(shape=(sz_p, sz_f2*sz_f1), dtype=out_tensor.dtype)
  out_tile[i_p0, i_f2*sz_f1+i_f1] = nl.copy(in_tile[i_p0, i_f1*sz_f2+i_f2])

  # Finally, we store out_tile to external memory
  nl.store(out_tensor, value=out_tile)

The main concept introduced in this example is a 2D memory access pattern per partition, via additional indices. We copy in_tile into out_tile, while traversing the memory in different access patterns between the source and destination, thus achieving the desired transposition.
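Here is a plain-Python emulation of the same access patterns. It is not NKI code, and `transpose_free_dims` and `transpose_tile` are hypothetical helpers. Each inner list stands for one partition. The read index i_f1*sz_f2 + i_f2 and the write index i_f2*sz_f1 + i_f1 are the expressions used in the kernel.

```python
# Plain-Python emulation (not NKI code) of the F1:F2 transpose access pattern.

def transpose_free_dims(row, sz_f1, sz_f2):
    """Transpose one partition's flattened, row-major (sz_f1, sz_f2) matrix."""
    assert len(row) == sz_f1 * sz_f2
    out = [None] * (sz_f1 * sz_f2)
    for i_f1 in range(sz_f1):
        for i_f2 in range(sz_f2):
            out[i_f2 * sz_f1 + i_f1] = row[i_f1 * sz_f2 + i_f2]
    return out


def transpose_tile(tile, shape2D):
    """Apply the same free-dim transpose to every partition (the parallel axis)."""
    sz_f1, sz_f2 = shape2D
    return [transpose_free_dims(row, sz_f1, sz_f2) for row in tile]


if __name__ == "__main__":
    row = ["a0", "a1", "a2", "a3", "b0", "b1", "b2", "b3", "c0", "c1", "c2", "c3"]
    print(transpose_tile([row], (3, 4))[0])
    # ['a0', 'b0', 'c0', 'a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'a3', 'b3', 'c3']
```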

You may download the full runnable script from Transpose2d tutorial.

Case #3 - 2D pooling operation

Lastly, we examine a case of dimensionality reduction. We implement a 2D MaxPool operation, which is used in many vision neural networks. This operation takes C x [H,W] matrices and reduces each matrix along the H and W axes. To leverage free-dimension flexible indexing, we can map the C (parallel) axis to the P dimension and H/W (contraction) axes to the F dimension. Performing such a 2D pooling operation requires a 4D memory access pattern in the F dimension, with reduction along two axes. Fig. 13 below illustrates the input and output tensor layouts.

Fig. 13 2D-Pooling Operation (reducing on axes F2 and F4)
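For reference, the same 4D free-dimension access pattern can be emulated in plain Python on nested lists. This is not NKI code. Each channel stands in for one partition. `maxpool2d` is a hypothetical helper. It uses the index expressions pool*i1 + i2 and pool*i3 + i4 from the kernel below and takes the max over the inner window indices i2 and i4.

```python
# Plain-Python emulation (not NKI code) of the 2D max-pool access pattern.

def maxpool2d(x, pool):
    """x: C x H x W nested lists -> C x (H // pool) x (W // pool)."""
    out = []
    for chan in x:  # parallel axis: channels are independent
        h_out, w_out = len(chan) // pool, len(chan[0]) // pool
        out.append([
            [max(chan[pool * i1 + i2][pool * i3 + i4]
                 for i2 in range(pool)    # y_inner (reduced)
                 for i4 in range(pool))   # x_inner (reduced)
             for i3 in range(w_out)]      # x_outer
            for i1 in range(h_out)        # y_outer
        ])
    return out


if __name__ == "__main__":
    C, H, W, POOL = 2, 6, 6, 2  # same sizes as the kernel's __main__ block
    x = [[[c * H * W + h * W + w for w in range(W)] for h in range(H)]
         for c in range(C)]
    print(maxpool2d(x, POOL)[0])  # [[7, 9, 11], [19, 21, 23], [31, 33, 35]]
```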

import neuronxcc.nki.language as nl
from torch_neuronx import nki_jit

@nki_jit
def tensor_maxpool_kernel_(in_tensor, out_tensor, pool_size):
  """NKI kernel to compute a 2D max-pool operation

  Args:
      in_tensor: an input tensor, of dimensions C x H x W
      out_tensor: the resulting output tensor, of dimensions C x (H/P) x (W/P)
      pool_size: integer P representing a (square) pool-window size
  """

  # Get input/output dimensions
  sz_cin, sz_hin, sz_win = in_tensor.shape
  sz_cout, sz_hout, sz_wout = out_tensor.shape
  assert sz_cin == sz_cout

  # Set relevant sizes
  sz_p = sz_cin
  sz_pool = pool_size

  # Generate tensor h/w index patterns
  # 3D indexing according to [C, H, W]
  i_p = nl.arange(sz_p)[:, None, None] # partition (C) index
  i_win = nl.arange(sz_win)[None, None, :]
  i_hin = nl.arange(sz_hin)[None, :, None]

  i_wout = nl.arange(sz_wout)[None, None, :]
  i_hout = nl.arange(sz_hout)[None, :, None]

  # Generate pool index patterns (requires two extra dimensions, for the pool window)
  i_0 = nl.arange(sz_p)[:, None, None, None, None] # partition (C)
  i_1 = nl.arange(sz_hin//sz_pool)[None, :, None, None, None] # y_outer
  i_2 = nl.arange(sz_pool)[None, None, :, None, None] # y_inner
  i_3 = nl.arange(sz_win//sz_pool)[None, None, None, :, None] # x_outer
  i_4 = nl.arange(sz_pool)[None, None, None, None, :] # x_inner

  # Load input data from external memory to on-chip memory
  # Declare ndarray to force a 3D tensor (temporary requirement)
  in_tile = nl.ndarray([sz_p, sz_hin, sz_win], dtype=in_tensor.dtype)
  in_tile[:,:,:] = nl.load(in_tensor[i_p, i_hin, i_win])

  # Perform the pooling operation:
  # We use numpy's advanced indexing in order to extend in_tile to 5D, and then reduce-max over two dimensions.
  # axis[0] is the index for p_dim, and thus doesn't participate in the reduction operation.
  # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides
  # (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2].
  # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4].
  out_tile = nl.max(in_tile[i_0, sz_pool*i_1+i_2, sz_pool*i_3+i_4], axis=[2,4])

  # Store the results back to external memory
  nl.store(out_tensor[i_p, i_hout, i_wout], value=out_tile)


if __name__ == "__main__":
    import torch
    from torch_xla.core import xla_model as xm

    device = xm.xla_device()

    # Now let's run the kernel
    POOL_SIZE = 2
    C, HIN, WIN = 2, 6, 6
    HOUT, WOUT = HIN//POOL_SIZE, WIN//POOL_SIZE

    in_tensor = torch.arange(C * HIN * WIN, dtype=torch.bfloat16).reshape(C, HIN, WIN).to(device=device)
    out_tensor = torch.zeros((C, HOUT, WOUT), dtype=torch.bfloat16).to(device=device)

    tensor_maxpool_kernel_(in_tensor, out_tensor, POOL_SIZE)

    print(in_tensor, out_tensor) # an implicit XLA barrier/mark-step
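Because the kernel relies on NumPy-style advanced indexing, its pooling step can be replayed on the host to predict what the example should print. The sketch below mirrors the kernel's i_0 to i_4 index patterns with `np.arange` and applies them to the same 2 x 6 x 6 `torch.arange` input. It is a NumPy emulation for checking results, not NKI code. It ignores bfloat16 rounding, which is harmless here because all values are small integers that bfloat16 represents exactly.

```python
import numpy as np

POOL_SIZE = 2
C, HIN, WIN = 2, 6, 6

# Same input values as the torch.arange(...) tensor in the example above
in_tile = np.arange(C * HIN * WIN).reshape(C, HIN, WIN)

# Pool index patterns, mirroring i_0 .. i_4 in the kernel
i_0 = np.arange(C)[:, None, None, None, None]                  # partition (C)
i_1 = np.arange(HIN // POOL_SIZE)[None, :, None, None, None]   # y_outer
i_2 = np.arange(POOL_SIZE)[None, None, :, None, None]          # y_inner
i_3 = np.arange(WIN // POOL_SIZE)[None, None, None, :, None]   # x_outer
i_4 = np.arange(POOL_SIZE)[None, None, None, None, :]          # x_inner

# Gather a 5D view [C, HOUT, POOL, WOUT, POOL], then reduce over the in-window axes 2 and 4
out_tile = in_tile[i_0, POOL_SIZE * i_1 + i_2, POOL_SIZE * i_3 + i_4].max(axis=(2, 4))

print(out_tile.shape)  # (2, 3, 3)
print(out_tile[0])     # expected rows: 7 9 11 / 19 21 23 / 31 33 35
```

Since the input increases along both H and W, each window's maximum is its bottom-right element. This makes the expected out_tensor easy to check by eye.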

SPMD: Launching multiple instances of a kernel

So far we have discussed how to launch a single NKI kernel instance that processes the full input tensor. In this section, we discuss how to launch multiple instances of the same kernel and slice the full input tensor across kernel instances, using the single program, multiple data (SPMD) programming model.

Note

In the current NKI release, adopting the SPMD programming model has no impact on the performance of an NKI kernel and is therefore optional. An SPMD program is compiled into an executable that targets one NeuronCore, and the different instances of the SPMD program are executed serially on that NeuronCore. This is subject to change in future releases.

NKI allows users to launch multiple instances of a kernel, organized in a user-defined multi-dimensional grid. Each kernel instance uses its grid indices to select which input and output data to access. There is no restriction on the number of dimensions in an SPMD grid, nor on the size of each dimension. Each kernel instance can find its coordinates within the launch grid using the nki.language.program_id API. The Neuron compiler translates the SPMD launch grid into nested loops of compute-kernel invocations, which are then executed on the NeuronCore.
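A small host-side model makes the mapping from a launch grid to nested loops concrete. In the sketch below, `launch` and the module-level `program_id` are hypothetical stand-ins written for illustration; they are not the NKI API. They run one kernel instance per grid coordinate, serially, as described in the note above. The loop order (last axis fastest) is an arbitrary choice of this sketch, so kernel instances should not depend on it.

```python
import itertools
from typing import Callable, List, Tuple

# Hypothetical emulation of an SPMD launch; not part of NKI.
_current_ids: Tuple[int, ...] = ()

def program_id(axis: int) -> int:
    """Return the current instance's coordinate along `axis` of the launch grid."""
    return _current_ids[axis]

def launch(kernel: Callable[..., None], grid: Tuple[int, ...], *args) -> None:
    """Run `kernel` once per grid coordinate, as serial nested loops."""
    global _current_ids
    for ids in itertools.product(*(range(n) for n in grid)):
        _current_ids = ids
        kernel(*args)

def record_ids(log: List[Tuple[int, int]]) -> None:
    # Each instance only sees its own coordinates
    log.append((program_id(0), program_id(1)))

log: List[Tuple[int, int]] = []
launch(record_ids, (4, 2), log)
print(log[:3])  # [(0, 0), (0, 1), (1, 0)]
```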

As an example, we’ll compute the matrix multiplication C=A@B, where A and B have shapes (512, 128) and (128, 1024), respectively. We partition the output tensor C of shape (512, 1024) into 4x2 tiles of shape (128, 512) each, and assign the computation of each output tile to a different kernel instance. A 4x2 launch grid is chosen so that each kernel instance operates on a single (128, 128) tile of A and a single (128, 512) tile of B, while adhering to the tile-size constraints.

With a 2D 4x2 launch grid, the (i,j) kernel instance is responsible for computing the (i,j) tile of C. Computing the (i,j) tile requires the i-th block of 128 rows of A and the j-th block of 512 columns of B. This induces a four-way row-wise partitioning of A and a two-way column-wise partitioning of B, as shown in Fig. 14.
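The tile arithmetic behind the 4x2 grid can be checked on the host. Instance (i, j) reads rows 128*i to 128*i+127 of A and columns 512*j to 512*j+511 of B, and writes the matching 128x512 tile of C. The NumPy sketch below is a host-side model, not NKI code; the small random integer inputs are an arbitrary choice that keeps the comparison exact. It computes each tile independently and checks that the tiles assemble into A @ B.

```python
import numpy as np

M, K, N = 512, 128, 1024
GRID_M, GRID_N = 4, 2
TILE_M, TILE_N = M // GRID_M, N // GRID_N   # 128, 512

rng = np.random.default_rng(0)
A = rng.integers(-4, 5, size=(M, K))
B = rng.integers(-4, 5, size=(K, N))
C = np.zeros((M, N), dtype=A.dtype)

for i in range(GRID_M):          # plays the role of program_id(0)
    for j in range(GRID_N):      # plays the role of program_id(1)
        rows = slice(i * TILE_M, (i + 1) * TILE_M)   # i-th row block of A (and C)
        cols = slice(j * TILE_N, (j + 1) * TILE_N)   # j-th column block of B (and C)
        # Each instance uses one [128, 128] tile of A and one [128, 512] tile of B
        C[rows, cols] = A[rows, :] @ B[:, cols]

print(np.array_equal(C, A @ B))  # True
```

With the all-ones inputs used in the kernel's `__main__` block below, the same reasoning predicts that every element of `result` equals K = 128.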


Fig. 14 Visualization of 512x128x1024 matrix multiplication using SPMD

In this SPMD kernel example, we use the high-level nki.language.matmul API so that we can focus on the concept of SPMD without worrying about the layout requirement of the Tensor Engine (LC#1). For best performance, we suggest transposing input A and instead invoking a different NKI kernel that performs only matmul operations on the Tensor Engine using nki.isa.nc_matmul, without the extra overhead of changing input layouts to meet LC#1.
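NumPy can show what the layout requirement means for the shapes in this example. This sketch assumes the behavior described in the nki.isa.nc_matmul reference: the Tensor Engine computes stationary.T @ moving, with the contraction dimension K on the partition axis of both operands. Under that assumption, an A tile must be supplied as its transpose [K, M]. nl.matmul performs that transpose internally, as noted in the kernel comments, so passing A pre-transposed lets a kernel skip it. `tensor_engine_matmul` is a hypothetical host-side model, not an NKI function.

```python
import numpy as np

def tensor_engine_matmul(stationary: np.ndarray, moving: np.ndarray) -> np.ndarray:
    """Hypothetical model of a Tensor Engine matmul that contracts over axis 0 (partitions)."""
    assert stationary.shape[0] == moving.shape[0], "K must be the partition dim of both"
    return stationary.T @ moving

rng = np.random.default_rng(1)
A_tile = rng.standard_normal((128, 128))   # [M, K], as stored in A
B_tile = rng.standard_normal((128, 512))   # [K, N], as stored in B

# A is presented as [K, M]; transposing it up front avoids an in-kernel transpose
A_T_tile = A_tile.T
out = tensor_engine_matmul(A_T_tile, B_tile)   # [M, N] = [128, 512]
print(out.shape, np.allclose(out, A_tile @ B_tile))  # (128, 512) True
```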

import neuronxcc.nki.language as nl
from torch_neuronx import nki_jit

@nki_jit
def matmul_128x128x512_spmd(A, B, result):
  """NKI kernel to compute a 128x128x512 matrix multiplication operation.
     Use SPMD program IDs to index into the full A and B input tensors to get tiles
     for 128x128x512 matrix multiplication.

  Args:
      A: an input tensor of shape [M=512,K=128],
         a left hand side argument of the matrix multiplication
      B: an input tensor of shape [K=128,N=1024],
         a right hand side argument of the matrix multiplication
      result: the resulting output tensor of shape [M=512,N=1024]
  """
  # Define starting indices for input A and B
  i_A_row = nl.program_id(0) * 128
  i_B_col = nl.program_id(1) * 512

  # Loading the inputs (HBM->SBUF)
  A_tile = nl.load(A[i_A_row:i_A_row+128, 0:128])
  B_tile = nl.load(B[0:128, i_B_col:i_B_col+512])

  # Perform the matrix-multiplication
  # Note1: nl.matmul will invoke a transpose on A_tile before performing the actual matmul operation
  # Note2: An NKI matmul instruction always writes to PSUM in float32 data-type
  result_psum = nl.matmul(A_tile, B_tile)

  # Copy the result from PSUM back to SBUF, and cast to expected output data-type
  result_sbuf = nl.copy(result_psum, dtype=result.dtype)

  # The result of a [128,128] x [128,512] matrix multiplication has a shape of [128, 512].
  # This dictates which indices to use to address the result tile.
  nl.store(result[i_A_row:i_A_row+128, i_B_col:i_B_col+512], value=result_sbuf)


if __name__ == "__main__":
  from torch_xla.core import xla_model as xm
  import torch

  device = xm.xla_device()

  A = torch.ones((512, 128), dtype=torch.bfloat16).to(device=device)
  B = torch.ones((128, 1024), dtype=torch.bfloat16).to(device=device)
  result = torch.zeros((512, 1024), dtype=torch.bfloat16).to(device=device)

  # Launch kernel with a 2D grid
  matmul_128x128x512_spmd[4, 2](A, B, result)

  print(result) # an implicit XLA barrier/mark-step
