*This document is relevant for*: `Inf2`, `Trn1`, `Trn1n`

# Fused Mamba

In this tutorial, we implement a NKI kernel for the Mamba Large Language Model, a State Space Model (SSM) which replaces the attention of a regular Transformer model with a custom layer inspired by Recurrent Neural Networks. We will walk through the core computation step-by-step and map it to NKI APIs to form a functional kernel. Next, by scaling the input shapes of the kernel (both channel size and sequence length), we will iterate on a more hardware-efficient kernel implementation to improve the scaling efficiency.

In this tutorial, we learn about:

- Mapping different vector operations efficiently to NeuronCore compute engines, such as associative scan and element-wise operations between tensors
- Leveraging data reuse and tiling to reduce excessive data movement and keep compute engines busy
- Using neuron-profile to identify performance bottlenecks and opportunities

## PyTorch Reference Implementation

Before jumping to NKI, let’s examine the compute definition of a Mamba-v1 layer using the below PyTorch script (`mamba_torch.py`):

```
import torch
import torch_neuronx
import torch_xla.core.xla_model as xm
import os
import argparse

os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"
os.environ["NEURON_CC_FLAGS"] = " --model-type=transformer --disable-dge "


def associative_scan(deltaA, deltaB_u):
    """
    Args:
        deltaA: [batch_size, channels, state_size, seq_len]
        deltaB_u: [batch_size, channels, state_size, seq_len]

    Mamba uses an associative scan operator to aggregate information across
    time sequentially (sequence length, e.g. sequence of tokens),
    from the past to the present.
    """
    batch_size, channels, state_size, seq_len = deltaA.shape
    out = torch.empty(batch_size, channels, state_size, seq_len,
                      device=deltaA.device, dtype=deltaA.dtype)
    for i in range(seq_len):
        prev_state = out[..., i - 1] if i > 0 else 0
        out[..., i] = deltaA[..., i] * prev_state + deltaB_u[..., i]
    return out


def mamba_layer(delta, A, B, u, C):
    """
    Args:
        delta: [batch, channels, seq_len]
        u: [batch, channels, seq_len]
        A: [channels, state_size]
        B: [batch, state_size, seq_len]
        C: [batch, state_size, seq_len]
    """
    # expand the tensors so they all have the same dimensions and compute elementwise products (with broadcast)
    # deltaA and deltaB_u have shape [batch_size, channels, state_size, seq_len]
    deltaA = torch.exp(delta[:, :, None, :] * A[None, :, :, None])
    deltaB_u = delta[:, :, None, :] * B[:, None, :, :] * u[:, :, None, :]
    scan_res = associative_scan(deltaA, deltaB_u)
    # y sums over the `state_size` axis and has shape [batch_size, channels, seq_len]
    mamba_out = (C[:, None, :, :] * scan_res).sum(dim=-2)
    return mamba_out


def parse_args():
    parser = argparse.ArgumentParser(
        """Run Mamba PyTorch implementation. Hard-coded small example only since
        PyTorch implementation is very slow for larger configs.
        """)
    parser.add_argument("--mode",
                        choices=["accuracy", "perf"],
                        default="accuracy",
                        help="""Do accuracy test or perf test.
                        Accuracy test compares mamba_v1 kernel against PyTorch implementation.
                        Perf test will generate a NEFF for the PyTorch implementation in local directory
                        for a manual run of neuron-profile.
                        """)
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()

    # Toy example
    batch = 1
    seq_len = 512
    channels = 256
    state_size = 16

    dtype = torch.float32

    device = xm.xla_device()

    delta = torch.ones(batch, channels, seq_len, dtype=dtype, device=device)
    u = torch.ones(batch, channels, seq_len, dtype=dtype, device=device)

    # For numerical accuracy testing purposes, we choose negative numbers for A on purpose.
    # Otherwise, the associative scan will integrate too fast and overflow, which would
    # mask any real numerical issues in our computation.
    # A negative A will ensure we catch numerical issues when we have them.
    A = -torch.ones(channels, state_size, dtype=dtype, device=device)
    B = torch.ones(batch, state_size, seq_len, dtype=dtype, device=device)

    C = torch.ones(batch, state_size, seq_len, dtype=dtype, device=device)

    xm.mark_step()
    torch_out = mamba_layer(delta, A, B, u, C)
    xm.mark_step()
    print(torch_out)
```

The input tensor shapes are as follows:

- `delta: [batch, channels, seq_len]`
- `u: [batch, channels, seq_len]`
- `A: [channels, state_size]`
- `B: [batch, state_size, seq_len]`
- `C: [batch, state_size, seq_len]`

The key model parameters are:

- `batch`: batch size of the model.
- `seq_len`: sequence length of the model.
- `channels`: hidden size of a token.
- `state_size`: number of model states.

We use `[batch=1, seq_len=512, channels=256, state_size=16]` as a simple test case for initial performance evaluation.
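To sanity-check shapes and semantics without a device, the same layer can be sketched in plain NumPy. This is a reference sketch only (not NKI or device code), and `mamba_reference` is a hypothetical helper name, not part of any library:

```python
import numpy as np

def mamba_reference(delta, u, A, B, C):
    """NumPy sketch of the Mamba layer above, for shape/semantics checks only."""
    # delta, u: [batch, channels, seq_len]; A: [channels, state_size]
    # B, C: [batch, state_size, seq_len]
    deltaA = np.exp(delta[:, :, None, :] * A[None, :, :, None])
    deltaB_u = delta[:, :, None, :] * B[:, None, :, :] * u[:, :, None, :]
    # associative scan along seq_len: state_t = deltaA_t * state_{t-1} + deltaB_u_t
    out = np.empty_like(deltaA)
    state = np.zeros(deltaA.shape[:-1])
    for t in range(deltaA.shape[-1]):
        state = deltaA[..., t] * state + deltaB_u[..., t]
        out[..., t] = state
    # sum over the state_size axis -> [batch, channels, seq_len]
    return (C[:, None, :, :] * out).sum(axis=-2)
```

With all-ones inputs and `A = -1` (as in the toy example), each state follows the recurrence `s_t = e^{-1} * s_{t-1} + 1`, so the output per channel is `state_size * (1 - e^{-(t+1)}) / (1 - e^{-1})` at time step `t`.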

Running the above Python script will compile the PyTorch compute graph using Neuron Compiler and generate a Neuron executable file (NEFF) in the same directory. We can then profile the NEFF on a single NeuronCore using neuron-profile. The figure below is a screenshot of the profile. We see this initial PyTorch implementation takes **151.83 ms** to execute *on device*.

Zooming into a portion of the profile, we notice the compute activities on different engines (TensorE/VectorE/ScalarE/GpSimdE) are quite sparse compared to data movement activities (the qSyncIO0 and qVectorSpillReload rows):

In this seemingly “memory-bound” execution trace, the achieved DMA throughput is also extremely low, hovering around 0.33% utilization throughout execution. Therefore, we are stressing neither the compute nor the memory subsystem, hinting that the workload is running at low efficiency on the NeuronCore. In the rest of this tutorial, we will showcase how to re-write the above computation using NKI to achieve a device execution latency of **172.93 usec**, which is a **878x speedup** compared to the PyTorch reference implementation.

## Mapping Mamba Layer to NeuronCore

In this section, we will discuss how the computation can be mapped onto the NeuronCore architecture. We will also highlight the importance of choosing appropriate data layouts to achieve good compute efficiency.

Recall we have the following input tensor shapes in device memory:

- `delta: [batch_size, channels, seq_len]`
- `u: [batch_size, channels, seq_len]`
- `A: [channels, state_size]`
- `B: [batch_size, state_size, seq_len]`
- `C: [batch_size, state_size, seq_len]`

In fact, the above tensor layout has been chosen carefully based on the computation done in NeuronCore, which we will discuss in more detail below.

In Mamba models, both `seq_len` and `channels` are typically in the thousands (such as `seq_len=16K, channels=4K`), while `batch_size` and `state_size` are smaller by 2-3 orders of magnitude (such as `batch_size=4, state_size=16`). To simplify visualization of computation on multi-dimensional tensors, let’s hold the `batch` and `state_size` dimensions constant and focus on computation per batch per state. Note, the `batch_size` dimension is considered a fully parallel axis in a Mamba layer, while `state_size` is only a partially parallel axis where results from different states will be accumulated together.

By extracting the `batch` and `state_size` dimensions, we get the following input tensor shapes in device memory:

- `delta_i: [channels, seq_len]`
- `u_i: [channels, seq_len]`
- `A_i: [channels]`
- `B_i: [seq_len]`
- `C_i: [seq_len]`

Next, let’s visualize the data flow and computation using 2D matrices or vectors step-by-step.

### Step 1: Element-wise multiplication of `delta_i` and `A_i`

We have the following PyTorch reference code for Step 1:

```
# delta[batch, channels, seq_len]
# A [channels, state_size]
delta[:, :, None, :] * A[None, :, :, None]
# Holding batch and state_size constant
# delta_i: [channels, seq_len]
# A_i: [channels]
delta_i[:, :] * A_i[:]
```
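In plain NumPy terms (a sketch of the math only, not NKI code), the per-(batch, state) slice multiply is `delta_i` times `A_i` broadcast along `seq_len`, and it matches the corresponding slice of the full 4D expression:

```python
import numpy as np

channels, seq_len = 4, 6
rng = np.random.default_rng(0)
delta_i = rng.random((channels, seq_len))
A_i = rng.random(channels)

# Broadcast A_i along the free (seq_len) dimension of delta_i
deltaA_i = delta_i * A_i[:, None]

# Same as the matching slice of the full 4D expression for one batch/state
delta4 = delta_i[None, :, None, :]   # [batch=1, channels, state=1, seq_len]
A4 = A_i[None, :, None, None]        # [1, channels, 1, 1]
assert np.allclose(deltaA_i, (delta4 * A4)[0, :, 0, :])
```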

After the above transformation, the multiplication between `delta_i` and `A_i` involves a **broadcast** across the `seq_len` dimension of `delta_i`. In NKI, free-dimension broadcast can often be folded into the actual computation instruction at no additional performance cost, while partition-dimension broadcast often requires a separate instruction on TensorE (see the TensorE alternative use case in the Trainium/Inferentia2 Architecture Guide). As a result, we have two options for executing Step 1.

**Option 1: Map `seq_len` to the free dimension.** Element-wise multiplication of `delta_i` and `A_i` on NeuronCore can be done through nisa.tensor_scalar on either VectorE or ScalarE, which automatically broadcasts `A_i` along the free dimension to match the `seq_len` dimension of `delta_i`.

Note, the `channels` dimension is mapped to the SBUF partition dimension. Since the input `channels` dimension has a size of 256 in our initial setup, which exceeds the architectural limitation of `nl.tile_size.pmax=128`, we must **tile** `delta_i` in the `channels` dimension (tiled dimension denoted as `channels_tiled`) and feed one tile into `nisa.tensor_scalar` at a time. The figure below illustrates the computation done for Option 1.

As an example, the associated NKI code for batch `i_batch`, state `i_state` and tile `i_tile_channels` in `channels` is:

```
# Input shape in device memory matches the computation layout
# Device memory layout:
# delta_i: [channels, seq_len]
# A_i: [channels]
# Computation layout in SBUF:
# delta_i: [par_dim(channels), seq_len]
# A_i: [par_dim(channels)]
deltaA_i = nisa.tensor_scalar(delta_i, op0=nl.multiply, operand0=A_i)
```

Note, with this compute layout option, the `delta_i` tensor shape `[channels, seq_len]` in device memory can be loaded into SBUF efficiently with `seq_len` as the free dimension and fed into VectorE/ScalarE for computation. No extra transposes are needed.

**Option 2: Map `seq_len` to the partition dimension.** Alternatively, if we choose a transposed layout for `delta_i` in SBUF for computation, we will need a partition-dimension broadcast of `A_i` using a separate instruction on TensorE (`A_i.broadcast_to(...)`) and then a nisa.tensor_tensor operation between `delta_i` and the broadcast `A_i` on VectorE. As a reminder, we need to tile the `seq_len` dimension to meet the tile size constraint `nl.tile_size.pmax=128`. The figure below illustrates the computation done for Option 2.

The associated NKI code is as follows:

```
# Input shape in device memory does NOT match the computation layout
# Device memory layout:
# delta_i: [channels, seq_len]
# A_i: [channels]
# Computation layout in SBUF:
# delta_i: [par_dim(seq_len_tiled), channels]
# A_i: [par_dim(1), channels]
A_i_bcast = A_i.broadcast_to((nl.tile_size.pmax, channels))
deltaA_i = nisa.tensor_tensor(delta_i, A_i_bcast, op=nl.multiply)
```

Assuming the same `delta_i` device memory layout `[channels, seq_len]`, before performing the `nisa.tensor_tensor` instruction, we will need to either:

- Do a regular load of `delta_i` into SBUF using nl.load and an explicit transpose on the loaded `delta_i` using `nl.transpose` to make `seq_len` lie in the free dimension, or
- Do a transposed load of `delta_i` using nl.load_transpose2d, which is significantly less efficient in memory bandwidth usage compared to `nl.load`

If Option 2 was chosen as the compute layout, we would have incentives to define the `delta` input tensor shape as `[seq_len, channels]` in device memory instead.

From a computation perspective, Option 2 is less efficient than Option 1 because:

- Option 2 needs an extra TensorE instruction performing the partition-dimension broadcast.
- `nisa.tensor_tensor` is 2x slower than `nisa.tensor_scalar` for our input data type FP32 (see the API doc for instruction cost estimates).

Therefore, for Step 1 only, Option 1 is the winner compared to Option 2. Let’s continue with the rest of the steps to see if we need to revise this selection due to surrounding operator layout preferences.

### Step 2: Exponential of `deltaA_i`

Step 2 evaluates an exponential on `deltaA_i` from the previous step:

```
torch.exp(...)
```

In NeuronCore, evaluating an exponential function on a tensor is considered a scalar operation, which runs on ScalarE. This operation can be invoked through nl.exp or nisa.activation. However, ScalarE is able to perform a “pipelined multiply-add” on the input before evaluating a non-linear function (for details, see the Trainium/Inferentia2 Architecture Guide). In other words, we can fold Step 1 (Option 1) `nisa.tensor_scalar` and Step 2 into a single ScalarE instruction at no additional cost. This functionality is only exposed in the `nisa.activation` API. This folding is not feasible if we chose Option 2 `nisa.tensor_tensor` in Step 1. The figure below illustrates our new execution plan combining Steps 1 and 2 into `nisa.activation`:

The associated NKI code is as follows:

```
# Input shape in device memory matches the computation layout
deltaA_i = nisa.activation(op=nl.exp, data=delta_i, scale=A_i)
```

### Step 3: Element-wise multiplication of `delta_i`, `B_i` and `u_i`

PyTorch reference code for Step 3 is:

```
# delta[batch, channels, seq_len]
# B: [batch, state_size, seq_len]
# u: [batch, channels, seq_len]
delta[:, :, None, :] * B[:, None, :, :] * u[:, :, None, :]
# Holding batch and state_size constant
# delta_i: [channels, seq_len]
# B_i: [seq_len]
# u_i: [channels, seq_len]
delta_i[:, :] * B_i[None, :] * u_i[:, :]
```

This step involves similar compute layout and instruction choices as Step 1:

- `channels` is either the partition or free dimension for both `delta_i` and `u_i`
- multiplication with `B_i` is either through `nisa.tensor_tensor` or `nisa.tensor_scalar`

Since we preferred Step 1 to consume `delta_i` using `channels` as the partition dimension in previous steps, it is wise to follow the same layout choice here for `delta_i` to avoid any transposes. Given this layout choice, the multiplication with `B_i` will have to be a `nisa.tensor_tensor`. The figure below visualizes the computation in Step 3:

The associated NKI code is as follows:

```
# Input shape in device memory does NOT match the computation layout
# Device memory layout:
# delta_i: [channels, seq_len]
# u_i: [channels, seq_len]
# B_i: [seq_len]
# Computation layout in SBUF:
# delta_i: [par_dim(channels_tiled), seq_len]
# u_i: [par_dim(channels_tiled), seq_len]
# B_i: [par_dim(1), seq_len]
deltaU_i = nisa.tensor_tensor(delta_i, u_i, op=nl.multiply)
B_i_bcast = B_i.broadcast_to((nl.tile_size.pmax, seq_len))
deltaBu_i = nisa.tensor_tensor(deltaU_i, B_i_bcast, op=nl.multiply)
```
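In plain NumPy terms (a semantics sketch, not NKI code), the two `tensor_tensor` multiplies with a partition-broadcast `B_i` compute the same thing as the PyTorch reference expression:

```python
import numpy as np

channels, seq_len = 4, 6
rng = np.random.default_rng(1)
delta_i = rng.random((channels, seq_len))
u_i = rng.random((channels, seq_len))
B_i = rng.random(seq_len)

# tensor_tensor-style elementwise multiply, then multiply by the
# partition-broadcast copy of B_i (one row replicated across channels)
deltaU_i = delta_i * u_i
B_i_bcast = np.broadcast_to(B_i, (channels, seq_len))
deltaBu_i = deltaU_i * B_i_bcast

# Same result as the reference expression for one batch/state
assert np.allclose(deltaBu_i, delta_i * B_i[None, :] * u_i)
```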

### Step 4: Associative scan between `deltaA_i` and `deltaBu_i`

In this step, we use an associative scan operator between `deltaA` and `deltaBu` to aggregate information across time sequentially (sequence length, e.g. sequence of tokens), from the past to the present. Here is a PyTorch reference implementation:

```
# deltaA: [batch_size, channels, state_size, seq_len]
# deltaB_u: [batch_size, channels, state_size, seq_len]
out = torch.empty(batch_size, channels, state_size, seq_len,
                  device=deltaA.device, dtype=deltaA.dtype)
for i in range(seq_len):
    # starting state is 0
    prev_state = out[..., i - 1] if i > 0 else 0
    # multiply deltaA by the previous time step state and then add deltaB_u
    out[..., i] = deltaA[..., i] * prev_state + deltaB_u[..., i]
```

By holding the batch and state_size dimensions constant, we get `deltaA_i` and `deltaBu_i`, both with shape `[channels_tiled, seq_len]`, where `channels_tiled` is the partition dimension. The associative scan between these two tiles can be implemented in NKI naively through the following loop:

```
scan_i = nl.ndarray((channels_tiled, seq_len), ...)
# Peeling the first iteration out, which is
# equivalent to loop iterator dependent control flow within the loop
scan_i[0:channels_tiled, 0] = deltaBu_i[0:channels_tiled, 0]
for i in nl.sequential_range(seq_len - 1):
    scan_i[0:channels_tiled, i+1] = deltaA_i[0:channels_tiled, i+1] * scan_i[0:channels_tiled, i] \
                                    + deltaBu_i[0:channels_tiled, i+1]
```

Within the loop, the current implementation invokes one instruction for multiplication and another for addition. Since both
instructions are performed among tiles of shape `[channels_tiled, 1]`

, we can combine
these two instructions using nisa.tensor_scalar
which supports two operators in a pipelined fashion within an instruction at the same cost as a single operator. Below is
a new implementation that could provide 2x speedup compared to the above:

```
scan_i = nl.ndarray((channels_tiled, seq_len), dtype=deltaA_i.dtype, buffer=nl.sbuf)
scan_i[0:channels_tiled, 0] = deltaBu_i[0:channels_tiled, 0]
for i in nl.sequential_range(seq_len - 1):
    scan_i[0:channels_tiled, i+1] = nisa.tensor_scalar(
        deltaA_i[0:channels_tiled, i+1],
        op0=nl.multiply,
        operand0=scan_i[0:channels_tiled, i],
        op1=nl.add,
        operand1=deltaBu_i[0:channels_tiled, i+1])
```

However, the above loop nest will turn into `seq_len`-many instructions with input tiles that have a single element per partition in SBUF. In addition, every `nisa.tensor_scalar` instruction has a data dependency on the output of the previous instruction. As discussed in the Trainium/Inferentia2 Architecture Guide, these two traits combined in an instruction sequence are considered extremely *inefficient* on ScalarE/VectorE, where the static instruction overhead rather than useful execution time would dominate the engine timeline.

Conveniently, NKI exposes another instruction nisa.tensor_tensor_scan
on VectorE, which can perform the above loop nest in a *single* instruction by caching the intermediate scan result from
the previous time step internally in VectorE without going through SBUF.

```
scan_i = nisa.tensor_tensor_scan(deltaA_i, deltaBu_i, initial=0,
                                 op0=np.multiply, op1=np.add)
```

Note, the shape of `scan_i` is exactly the same as the inputs `deltaA_i`/`deltaBu_i`: `[channels_tiled, seq_len]`.
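For reference, a NumPy sketch of the recurrence that this single instruction computes (`tensor_tensor_scan_ref` is a hypothetical helper for checking semantics only, not a model of the hardware or its performance):

```python
import numpy as np

def tensor_tensor_scan_ref(data0, data1, initial=0.0):
    """out[:, t] = data0[:, t] * out[:, t-1] + data1[:, t], seeded with `initial`."""
    out = np.empty_like(data0)
    state = np.full(data0.shape[0], initial, dtype=data0.dtype)
    for t in range(data0.shape[1]):
        state = data0[:, t] * state + data1[:, t]
        out[:, t] = state
    return out

# With constant decay 0.5 and constant input 1, the scan has the
# closed form s_t = 2 - 0.5**t, which we can check directly.
deltaA_i = np.full((2, 5), 0.5)
deltaBu_i = np.ones((2, 5))
scan_i = tensor_tensor_scan_ref(deltaA_i, deltaBu_i)
assert scan_i.shape == deltaA_i.shape
assert np.allclose(scan_i[0], 2.0 - 0.5 ** np.arange(5))
```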

### Step 5: Element-wise multiplication of `C_i` and `scan_i`

The PyTorch reference implementation is:

```
# scan_res: [batch_size, channels, state_size, seq_len]
# C: [batch_size, state_size, seq_len]
scanC = C[:, None, :, :] * scan_res
# Holding batch and state constant
# scan_i: [channels_tiled, seq_len]
# C_i: [seq_len]
scanC_i = C_i[None, :] * scan_i[:, :]
```

You know the drill - since `channels_tiled` is the partition dimension in `scan_i` from the previous step, we need to perform a partition-dimension broadcast on `C_i` before invoking `nisa.tensor_tensor`:

The corresponding NKI code is:

```
C_i_bcast = C_i.broadcast_to((nl.tile_size.pmax, seq_len))
scanC_i = nisa.tensor_tensor(scan_i, C_i_bcast, op=nl.multiply)
```

### Step 6: Accumulation of `scanC_i` along the `state_size` dimension

So far in Steps 1-5, all the computation is logically parallel across the `state_size` dimension of a Mamba layer. The next step of computation introduces a data dependency along the `state_size` dimension for the first time. The PyTorch reference implementation is:

```
# scan_res: [batch_size, channels, state_size, seq_len]
# C: [batch_size, state_size, seq_len]
# -2 dim is state_size
scanC.sum(dim=-2)
# Holding batch constant only.
# scan_i_states: [channels_tiled, state_size, seq_len]
(scanC_i).sum(dim=-2)
```

In NKI, we can accumulate the `scanC_i` results across states element-wise using `state_size-1` `nisa.tensor_tensor` instructions:

Since we will be looping over different states, we can also declare an empty accumulation buffer `scanC_accum` of shape `[channels_tiled, seq_len]` outside of the loop structure and accumulate into this buffer at the end of every loop iteration using the `+=` operator. The use of a single accumulation buffer avoids allocating memory for `scanC_i` across all states in SBUF. The corresponding NKI code is:

```
scanC_accum = nl.zeros(...)
for i_state in nl.affine_range(state_size):
    scanC_i = ...
    scanC_accum += scanC_i
```
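The running accumulation is mathematically just a sum over the state axis, which a quick NumPy sketch (not NKI code, using random stand-in data) confirms:

```python
import numpy as np

channels_tiled, state_size, seq_len = 4, 3, 6
rng = np.random.default_rng(2)
# Stand-in for the per-state scanC_i results
scanC_states = rng.random((channels_tiled, state_size, seq_len))

# Accumulate into one buffer instead of materializing all states at once
scanC_accum = np.zeros((channels_tiled, seq_len))
for i_state in range(state_size):
    scanC_accum += scanC_states[:, i_state, :]

# Equivalent to summing over the state axis in one shot
assert np.allclose(scanC_accum, scanC_states.sum(axis=1))
```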

## Initial NKI Kernel

Putting all the pieces together from the previous section, we arrive at the below kernel implementation, `mamba_v1`:

```
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
import numpy as np

def mamba_v1(delta, u, A, B, C, output):
    """Computes the SSM operation in the Mamba model.

    :param delta: (batch_size, channels, seq_len)
    :param u: (batch_size, channels, seq_len)
    :param A: (channels, state_size)
    :param B: (batch_size, state_size, seq_len)
    :param C: (batch_size, state_size, seq_len)
    :return: (batch_size, channels, seq_len)
    """
    batch_size, channels, seq_len = delta.shape
    _, state_size = A.shape

    # We can relax this using mask parameters in all the NKI API calls
    assert channels % 128 == 0

    # Map channels to the partition dimension
    # Tile channels to comply with NKI tile size constraints
    channel_psize = nl.tile_size.pmax
    n_channel_tile = channels // channel_psize

    # Most outer loop with batch_size, parallel_for
    for i_batch in nl.affine_range(batch_size):
        # partial accumulated scanC result with processed states
        scanC_accum = nl.zeros((n_channel_tile, nl.par_dim(channel_psize), seq_len), dtype=delta.dtype)

        # Second outer loop with state_size, partial parallel
        for i_state in nl.affine_range(state_size):

            # Inner loop: tiling channels
            for i_channel_tile in nl.affine_range(n_channel_tile):
                channel_start = i_channel_tile * channel_psize

                # Load the relevant tile from delta and A
                delta_i = nl.load(delta[i_batch, channel_start:channel_start+channel_psize, 0:seq_len])
                A_i = nl.load(A[channel_start:channel_start+channel_psize, i_state])

                # Step 1&2: Element-wise multiplication of delta_i and A_i and then exponential
                deltaA = nisa.activation(op=nl.exp, data=delta_i, scale=A_i)

                # Load the relevant tile from u and B
                u_i = nl.load(u[i_batch, channel_start:channel_start+channel_psize, 0:seq_len])
                B_i = nl.load(B[i_batch, i_state:i_state+1, 0:seq_len])

                # Step 3: Element-wise multiplication of delta_i, B_i and u_i
                deltaU = nisa.tensor_tensor(delta_i, u_i, op=nl.multiply)
                B_i_bcast = B_i.broadcast_to((channel_psize, seq_len))
                deltaBu = nisa.tensor_tensor(deltaU, B_i_bcast, op=nl.multiply)

                # Step 4: Associative scan between deltaA and deltaBu
                scan_res = nki.isa.tensor_tensor_scan(deltaA, deltaBu, initial=0,
                                                      op0=np.multiply, op1=np.add)

                # Load the relevant tile from C
                C_i = nl.load(C[i_batch, i_state:i_state+1, 0:seq_len])

                # Step 5: Element-wise multiplication of scan_res and C_i
                C_i_bcast = C_i.broadcast_to((channel_psize, seq_len))
                scanC = nisa.tensor_tensor(scan_res, C_i_bcast, op=nl.multiply)

                # Step 6: Accumulation of scanC along state_size dimension
                # scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len] = nisa.tensor_tensor(
                #     scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len], scanC, op=nl.add)
                scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len] += scanC

        # Store scanC_accum for a single batch to output
        for i_channel_tile in nl.affine_range(n_channel_tile):
            channel_start = i_channel_tile * channel_psize
            nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len],
                     scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len])
```

In the above code example:

- We have three levels of loop nests. From the outer-most to inner-most:
  - Iterating over `batch`: Different batch samples perform completely different computation. The `A` tensor is the only input parameter that is shared among batch samples.
  - Iterating over `state_size`: Different states perform parallel computation until Step 6, as discussed in the previous section. Both `delta` and `u` tensors are shared across different states.
  - Iterating over `channels`: This is the inner-most dimension, where we tile the input channels dimension into `nl.tile_size.pmax=128` chunks. Both `B` and `C` tensors are shared across different `channels`.
- The kernel above assumes `channels` is a multiple of `nl.tile_size.pmax=128`. We can relax this by adding a `mask` parameter in all the NKI API calls in the kernel. To simplify the code example, we omit this change. See NKI API Masking for more information.
- We declare an empty intermediate tensor `scanC_accum` to hold partial summation from every state.
- Within the inner loop, we process data for `nl.tile_size.pmax=128` channels for one batch sample in one state:
  - We use the slicing syntax to index a tensor. For example, `delta[i_batch, channel_start:channel_start+channel_psize, 0:seq_len]` grabs data from the input `delta` tensor for the current range of channels at the current batch sample.
  - Note, in tensor slicing, the first index dimension from the left with a slicing range will be chosen as the partition dimension. When loading `B`, since we intend to load only one state’s worth of data into one partition of SBUF (discussed in Step 3), we need to explicitly slice the state using `nl.load(B[i_batch, i_state:i_state+1, 0:seq_len])`. Otherwise, `nl.load(B[i_batch, i_state, 0:seq_len])` will treat `seq_len` as the partition dimension, which is not what we planned for in Step 3 and would also trigger a NKI compilation error since `seq_len` exceeds `nl.tile_size.pmax`.
  - We accumulate partial `scanC_i` results into the accumulation buffer using the `+=` operator. This creates a loop-carried dependency for `scanC_accum` on the `i_state` loop.

### Performance Check

Let’s re-run neuron-profile on the above NKI kernel:

Hooray! This NKI kernel implementation now takes **172.93 usec**, which is a **878x** speedup compared to the reference PyTorch implementation. Based on the profile, VectorE is the busiest compute engine in the Mamba layer. This makes sense because the bulk of computation in the kernel is in `nisa.tensor_tensor`, which can only run on VectorE.

Therefore, our goal is to keep VectorE as busy as possible throughout execution. Note, every NEFF execution involves certain start-up and tear-down overhead. We can use the `Selection Summary` feature in `neuron-profile` to find out the percentage of time VectorE is busy during the actual execution period:

As indicated by the above profile, VectorE is active over **98.71%** of the time, which is rather impressive. However, remember we used small input shapes as a toy example to get started: `[batch=1, seq_len=512, channels=256, state_size=16]`. Next, let’s increase the `channels` and `seq_len` dimensions one by one and observe how VectorE efficiency changes.

## Increasing input `channels` size

Let’s increase the size of `channels` by 16x, from 256 to a more realistic value of 4096. We obtain the following profile:

The new device execution time with increased channels is now **2.34 ms**. We can see that the VectorE active duration has dropped to **92.16%** during the core execution period, compared to **98.71%** previously with the toy example. Let’s zoom into an arbitrary region of the profile to see what could be causing VectorE to go idle:

By identifying a gap where VectorE is completely idle, we can hover over the first executed instruction after the gap to find the reason for the idleness in the instruction's semaphore wait condition. In the above screenshot, the instruction is pending on `S[22]` to reach a value of 240, which is set by `qSyncIO0` activities. This means VectorE has been waiting for input tensors to be loaded before performing more computation. If you hover over `qSyncIO0` activities during the VectorE idle period, you can also see the exact input tensor name defined in NKI being loaded in the DMA:

We can find similar VectorE gaps throughout the execution trace. At this point, we can conclude that one of the reasons for the lower VectorE active time percentage is *blocking* input tensor loading (`nl.load`) activities in the DMA. Next, let’s spend some time analyzing DMA efficiency.

Zooming out, we can make several observations. First, we see two orange boxes around the `qSyncIO0` row. Hovering over the top left corners of the boxes shows two similar performance warnings for loading IO tensors:

This indicates we reload both the input `u` and `delta` tensors around 7 times. This could be inevitable when we don’t have sufficient on-chip memory (SBUF) to allow full reuse of the input data tensors. However, the profiler shows we are only hitting around 50% SBUF capacity usage throughout execution:

Therefore, the input tensor reloading is likely not justified, and we should investigate whether we can optimize the NKI kernel to avoid it.

### Minimizing data reloading by loop reordering

To understand why `delta` and `u` are being reloaded, let’s revisit our input tensor shapes:

- `delta: [batch_size, channels, seq_len]`
- `u: [batch_size, channels, seq_len]`
- `A: [channels, state_size]`
- `B: [batch_size, state_size, seq_len]`
- `C: [batch_size, state_size, seq_len]`

Let’s hold `batch_size` constant, since the majority of input tensors have completely different slices for different batch samples:

- `delta: [channels, seq_len]`
- `u: [channels, seq_len]`
- `A: [channels, state_size]`
- `B: [state_size, seq_len]`
- `C: [state_size, seq_len]`

The `delta` and `u` tensors have the same shape with `channels` as the outer dimension, while `B` and `C` have the same shape with `state_size` as the outer dimension. All four of these input tensors have `seq_len` as the inner dimension. Therefore, we say `delta/u` are reused across different states, while `B/C` are reused across different channels. Given these conflicting reuse dimensions, we further say it is more important to **prioritize reuse of `delta/u`**, because the expected size of `channels` is much higher than `state_size`:

- `state_size` is now 16 and typically stays small
- `channels` is now 4096 and typically in the thousands

In NKI, we can prioritize `delta/u` reuse through loop ordering. Recall that in the initial NKI kernel implementation, we have the following inner loops:

```
...
for i_state in nl.affine_range(state_size):
    for i_channel_tile in nl.affine_range(n_channel_tile):
        # step 1-6
        ...
```

Since these two loops are executed serially within a single NeuronCore, the loop instances will be unrolled by the Neuron Compiler. With the channel dimension as the fastest-moving dimension, we will need to load `delta/u` across all channels in the first state, and then likely reload them again in the later states due to the large total memory size of `delta` and `u` (16MB in this case).
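A quick footprint check makes the 16MB figure concrete (a sketch assuming FP32 elements and the enlarged config of `channels=4096`, `seq_len=512`, `batch=1`):

```python
channels, seq_len, bytes_per_fp32 = 4096, 512, 4

per_tensor_bytes = channels * seq_len * bytes_per_fp32  # one of delta or u
total_bytes = 2 * per_tensor_bytes                      # delta + u together

assert per_tensor_bytes == 8 * 1024 * 1024   # 8 MiB each
assert total_bytes == 16 * 1024 * 1024       # 16 MiB combined
```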

To prioritize reuse of `delta/u`, we should reorder the above loop nests. To further enforce the reuse, we can hoist the `nl.load` calls for `delta/u` outside of the `i_state` inner loop:

```
...
for i_channel_tile in nl.affine_range(n_channel_tile):
    delta_i = nl.load(...)
    u_i = nl.load(...)
    for i_state in nl.affine_range(state_size):
        # step 1-6
        ...
```

As a side effect of this loop re-ordering, we can also spot a loop fusion opportunity, since we now have two `i_channel_tile` loop nests at the same level:

```
scanC_accum = nl.zeros((n_channel_tile, nl.par_dim(channel_psize), seq_len), ...)
...
# First i_channel_tile loop
for i_channel_tile in nl.affine_range(n_channel_tile):
    delta_i = nl.load(...)
    u_i = nl.load(...)
    for i_state in nl.affine_range(state_size):
        # step 1-6
        ...

# Second i_channel_tile loop
for i_channel_tile in nl.affine_range(n_channel_tile):
    nl.store(..., scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len])
...
```

By fusing the two `i_channel_tile` loop nests into a single loop nest, we can pull the declaration of `scanC_accum` inside the `i_channel_tile` loop and further reduce the `scanC_accum` size requirement by a factor of `n_channel_tile`:

```
...
# Fused i_channel_tile loop
for i_channel_tile in nl.affine_range(n_channel_tile):
    scanC_accum = nl.zeros((nl.par_dim(channel_psize), seq_len), ...)
    delta_i = nl.load(...)
    u_i = nl.load(...)
    for i_state in nl.affine_range(state_size):
        # step 1-6
        ...
    nl.store(..., scanC_accum[0:channel_psize, 0:seq_len])
...
```

Let’s modify our initial NKI kernel implementation accordingly to get `mamba_v2`:

```
def mamba_v2(delta, u, A, B, C, output):
    """Computes the SSM operation in the Mamba model.

    :param delta: (batch_size, channels, seq_len)
    :param u: (batch_size, channels, seq_len)
    :param A: (channels, state_size)
    :param B: (batch_size, state_size, seq_len)
    :param C: (batch_size, state_size, seq_len)
    :return: (batch_size, channels, seq_len)
    """
    batch_size, channels, seq_len = delta.shape
    _, state_size = A.shape

    assert channels % 128 == 0

    # Map channels to the partition dimension
    # Tile channels to comply with NKI tile size constraints
    channel_psize = nl.tile_size.pmax
    n_channel_tile = channels // channel_psize

    # Outer-most loop over batch_size, parallel_for
    for i_batch in nl.affine_range(batch_size):

        # Second outer loop: tiling channels
        for i_channel_tile in nl.affine_range(n_channel_tile):
            channel_start = i_channel_tile * channel_psize

            # Partially accumulated scanC result over processed states
            scanC_accum = nl.zeros((nl.par_dim(channel_psize), seq_len), dtype=delta.dtype)

            # Load delta/u once to be reused across states
            delta_i = nl.load(delta[i_batch, channel_start:channel_start+channel_psize, 0:seq_len])
            u_i = nl.load(u[i_batch, channel_start:channel_start+channel_psize, 0:seq_len])

            # Inner loop over state_size, partially parallel
            for i_state in nl.affine_range(state_size):
                # Load the relevant tile from A
                A_i = nl.load(A[channel_start:channel_start+channel_psize, i_state])

                # Step 1&2: Element-wise multiplication of delta_i and A_i, then exponential
                deltaA = nisa.activation(op=nl.exp, data=delta_i, scale=A_i)

                # Load the relevant tile from B
                B_i = nl.load(B[i_batch, i_state:i_state+1, 0:seq_len])

                # Step 3: Element-wise multiplication of delta_i, B_i and u_i
                deltaU = nisa.tensor_tensor(delta_i, u_i, op=nl.multiply)
                B_i_bcast = B_i.broadcast_to((channel_psize, seq_len))
                deltaBu = nisa.tensor_tensor(deltaU, B_i_bcast, op=nl.multiply)

                # Step 4: Associative scan between deltaA and deltaBu
                scan_res = nki.isa.tensor_tensor_scan(deltaA, deltaBu, initial=0,
                                                      op0=np.multiply, op1=np.add)

                # Load the relevant tile from C
                C_i = nl.load(C[i_batch, i_state:i_state+1, 0:seq_len])

                # Step 5: Element-wise multiplication of scan_res and C_i
                C_i_bcast = C_i.broadcast_to((channel_psize, seq_len))
                scanC = nisa.tensor_tensor(scan_res, C_i_bcast, op=nl.multiply)

                # Step 6: Accumulation of scanC along the state_size dimension
                scanC_accum[0:channel_psize, 0:seq_len] += scanC

            # Store scanC_accum for a single batch to output
            nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len],
                     scanC_accum[0:channel_psize, 0:seq_len])
```
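As a sanity check on this per-state decomposition, here is a NumPy sketch (small hypothetical shapes, plain Python instead of the NKI API) that runs steps 1-6 one state at a time and verifies the accumulated result against a direct hidden-state recurrence in the style of the PyTorch reference:

```python
import numpy as np

rng = np.random.default_rng(0)
channel_psize, state_size, seq_len = 8, 4, 32

# Hypothetical inputs for a single batch and channel tile
delta = rng.uniform(0.1, 0.5, (channel_psize, seq_len))
u = rng.normal(size=(channel_psize, seq_len))
A = -rng.uniform(0.5, 1.0, (channel_psize, state_size))
B = rng.normal(size=(state_size, seq_len))
C = rng.normal(size=(state_size, seq_len))

# Steps 1-6, one state at a time (mirrors the i_state loop in mamba_v2)
scanC_accum = np.zeros((channel_psize, seq_len))
for i_state in range(state_size):
    deltaA = np.exp(delta * A[:, i_state:i_state+1])   # steps 1&2
    deltaBu = delta * u * B[i_state]                   # step 3, B broadcast over channels
    scan = np.empty_like(deltaA)                       # step 4: associative scan over seq_len
    state = np.zeros(channel_psize)
    for t in range(seq_len):
        state = deltaA[:, t] * state + deltaBu[:, t]
        scan[:, t] = state
    scanC_accum += scan * C[i_state]                   # steps 5&6

# Direct hidden-state recurrence over all states at once, as in the reference
hidden = np.zeros((channel_psize, state_size))
y = np.empty((channel_psize, seq_len))
for t in range(seq_len):
    hidden = np.exp(delta[:, t:t+1] * A) * hidden \
             + (delta[:, t] * u[:, t])[:, None] * B[:, t]
    y[:, t] = hidden @ C[:, t]

assert np.allclose(scanC_accum, y)
```

The two orderings compute the same result because each state's recurrence is independent; the kernel simply processes them one at a time and accumulates the `C`-weighted outputs.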

We recapture the profile for the new kernel implementation:

The device execution time is now **1.61 ms**, which is a **31%** reduction in latency compared to our initial kernel implementation.
We can also see VectorE active duration is back up to 99.63% and the performance warnings on input tensor reloading are
now gone. In case you are curious, the above loop reordering optimization alone provides around 30% of latency reduction,
while the loop fusion optimization contributes the remaining 1% performance boost. This makes sense because the loop reordering
addresses our key performance concern around input data reloading, while reducing intermediate tensor size is only a nice-to-have
given we were quite low on SBUF usage to begin with.

## Increasing input `seq_len` size#

Next, let’s increase the input `seq_len` by **16x**, from 512 to 8192, and recompile the above NKI kernel. Below is the associated performance profile:

The new profile now takes **53.33 ms**, which is **33x longer** than the previous profile. VectorE active duration has dropped down to a new low: 58.93%. Compared to the profile captured with a smaller `seq_len`, we notice new DMA activity rows, `qSyncSpillReload0` and `qVectorSpillReload0`, which are associated with data movement traffic for intermediate data spilled from SBUF into device memory or reloaded back into SBUF. Zooming into a smaller portion of the profile:

We can see VectorE enters idle states due to a blocking semaphore wait on `qSyncSpillReload0` activities, which indicates the extra spill/reload is indeed degrading overall computation performance. In addition, we can see low SBUF usage, peaking at merely 50%. Computation and data movement are also not overlapped properly, leading to low average utilization of both the compute engines and DMA throughput across the overall timeline.

Intuitively, increasing the kernel's `seq_len` increases the active tile sizes of input and intermediate tensors in the free dimension, which can cause severe fragmentation in SBUF and excessive data movement to spill/reload tensors in SBUF. To mitigate these inefficiencies, we must **tile** the `seq_len` dimension in our NKI kernel through a new loop level.
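To quantify the pressure, here is a quick calculation (assuming FP32 and the shapes above) of the per-partition SBUF footprint of the three tensors that persist across the `i_state` loop when `seq_len` is not tiled:

```python
seq_len = 8192
fp32_bytes = 4
n_persistent = 3  # scanC_accum, delta_i and u_i each span the full seq_len

per_partition_kib = seq_len * fp32_bytes * n_persistent / 1024
print(per_partition_kib)  # 96.0 KiB per SBUF partition
```

At 96 KiB per partition, these three tensors alone consume half of the available SBUF capacity, leaving little room for the remaining intermediates.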

### Mitigate spilling by tiling `seq_len`#

We have **three** key considerations when adding this new loop level:

1. tile size selection,
2. loop-carried dependency handling,
3. loop ordering with other loop nests.

**Tile size of `seq_len`.** Since we were able to achieve close to 100% VectorE utilization with `seq_len=512` in our toy example, let's set the tile size `seq_len_fsize` to 512 as a starting point. We can revisit this decision as needed once we obtain a new profile.

**Loop-carried dependency.** Splitting `seq_len` into chunks is straightforward for all computation steps except Step 4. In the associative scan operation, the next loop iteration requires results from the previous iteration for computation. As a result, we introduce another loop-carried dependency here with the scan tiles. This dependency can be handled through the `initial` input parameter:

```
scan_init = nl.zeros((channel_psize, 1), ...)
for i_seq_len_tile in nl.static_range(seq_len // seq_len_fsize):
    scan_i = nisa.tensor_tensor_scan(deltaA, deltaBu, initial=scan_init,
                                     op0=np.multiply, op1=np.add)
    scan_init = scan_i[0:channel_psize, seq_len_fsize-1]
```

Note, we choose to use `static_range` instead of `affine_range` due to the new loop-carried dependency.
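To convince ourselves that carrying the last scan column into the next tile's `initial` is correct, here is a NumPy sketch (hypothetical shapes, plain Python instead of the NKI API) comparing a tiled scan against a single full-sequence scan:

```python
import numpy as np

def scan(deltaA, deltaBu, initial):
    """out[:, t] = deltaA[:, t] * out[:, t-1] + deltaBu[:, t], seeded with `initial`."""
    out = np.empty_like(deltaA)
    state = initial
    for t in range(deltaA.shape[1]):
        state = deltaA[:, t] * state + deltaBu[:, t]
        out[:, t] = state
    return out

rng = np.random.default_rng(0)
channels, seq_len, tile = 4, 512, 128
deltaA = rng.uniform(0.5, 1.0, (channels, seq_len))
deltaBu = rng.normal(size=(channels, seq_len))

# One full-sequence scan
full = scan(deltaA, deltaBu, np.zeros(channels))

# Tiled scan: seed each tile with the last column of the previous tile
out = np.empty_like(deltaA)
scan_init = np.zeros(channels)
for s in range(0, seq_len, tile):
    out[:, s:s+tile] = scan(deltaA[:, s:s+tile], deltaBu[:, s:s+tile], scan_init)
    scan_init = out[:, s+tile-1]

assert np.allclose(full, out)
```

Because the scan operator is associative, chaining tiles through their final state reproduces the full-sequence result exactly.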

**Loop ordering.** Recall from our latest NKI kernel implementation, we have the following loop nest:

```
...
for i_batch in nl.affine_range(batch_size):
    for i_channel_tile in nl.affine_range(n_channel_tile):
        scanC_accum = nl.zeros((nl.par_dim(channel_psize), seq_len), ...)
        delta_i = nl.load(delta[i_batch, channel_start:channel_start+channel_psize, 0:seq_len])
        u_i = nl.load(u[i_batch, channel_start:channel_start+channel_psize, 0:seq_len])
        for i_state in nl.affine_range(state_size):
            A_i = nl.load(A[channel_start:channel_start+channel_psize, i_state])
            B_i = nl.load(B[i_batch, i_state:i_state+1, 0:seq_len])
            C_i = nl.load(C[i_batch, i_state:i_state+1, 0:seq_len])
            deltaA = ...
            deltaBu = ...
            scanC = ...
            ...
            scanC_accum += ...
        nl.store(..., scanC_accum[0:channel_psize, 0:seq_len])
...
```

Let’s denote the above loop ordering as `[batch_size, n_channel_tile, state_size]`; our key question here is where to insert `seq_len` into this list.

Appending `seq_len` to the above list, that is, making `seq_len` the new inner-most loop, would involve the least amount of code changes to our current NKI kernel. However, it also leads to the least SBUF usage reduction, since this loop ordering won't tile the `scanC_accum`, `delta_i` and `u_i` tensors. Given `seq_len=8192` and FP32 data types, these three tensors will occupy 8192*4B*3 = 96 KiB/partition, half of the available SBUF capacity. Let's go ahead and experiment with this loop ordering in a new kernel `mamba_v3`:

With the above profile, the kernel now takes **27.8 ms**, which is a **48%** reduction in latency compared to no `seq_len` tiling. VectorE is now 94.85% active, and we no longer have spilling-related DMA activities.

Finally, the key advantage of Mamba over Transformer models is that Mamba's computation and latency should scale linearly with respect to `seq_len`, instead of quadratically as in Transformers. Let's plot the measured kernel latencies across different `seq_len` values up to 8K (what we have optimized so far) and compare them against "perfect latencies" assuming linear scaling from `seq_len=512`. We evaluate scaling efficiency using `perfect latency / measured latency`, a higher-is-better metric. To showcase the importance of the last `seq_len` tiling optimization for scaling `seq_len`, we also compare scaling efficiency for `mamba_v2` (no `seq_len` tiling) and `mamba_v3` (`seq_len` tiling).

| seq_len | Perfect Latency (ms) | mamba_v2 Measured Latency (ms) | mamba_v2 Scaling Efficiency | mamba_v3 Measured Latency (ms) | mamba_v3 Scaling Efficiency |
|---|---|---|---|---|---|
| 512 | N/A | 1.6 | N/A | 1.6 | N/A |
| 1024 | 3.2 | 4.4 | 72.73% | 3.3 | 96.97% |
| 2048 | 6.4 | 8.9 | 71.91% | 6.6 | 96.97% |
| 3072 | 9.6 | 13.1 | 73.28% | 10.1 | 95.05% |
| 4096 | 12.8 | 17.6 | 72.73% | 13.3 | 96.24% |
| 5120 | 16 | 23.7 | 67.51% | 17.3 | 92.49% |
| 6144 | 19.2 | 27.5 | 69.82% | 19.6 | 97.96% |
| 7168 | 22.4 | 41.3 | 54.24% | 24.2 | 92.56% |
| 8192 | 25.6 | 52.2 | 49.04% | 27.8 | 92.09% |
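The efficiency columns can be reproduced from the measured latencies; a quick sketch assuming perfect latency scales linearly from the 1.6 ms measured at `seq_len=512` (a subset of rows shown):

```python
base_seq_len, base_ms = 512, 1.6
measured_ms = {  # seq_len: (mamba_v2, mamba_v3), copied from the table above
    1024: (4.4, 3.3), 2048: (8.9, 6.6), 4096: (17.6, 13.3), 8192: (52.2, 27.8),
}
for seq_len, (v2, v3) in sorted(measured_ms.items()):
    perfect = base_ms * seq_len / base_seq_len
    # e.g. seq_len=8192: 25.6/52.2 -> 49.04%, 25.6/27.8 -> 92.09%
    print(seq_len, f"{perfect / v2:.2%}", f"{perfect / v3:.2%}")
```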

The above data shows the last NKI kernel implementation, `mamba_v3`, can reach 90%+ scaling efficiency up to 8K `seq_len`. To support even larger `seq_len`, we will need more aggressive tiling by pulling the `seq_len` loop level further towards the outer-loop level, tiling more input/intermediate tensors to keep spilling low and VectorE busy.

## Download All Source Code#

Click the links to download source code of the kernels and the testing code discussed in this tutorial.

PyTorch reference implementation:

`mamba_torch.py`

Three versions of NKI kernels:

`mamba_nki_kernels.py`

You can also view the source code in the GitHub repository nki_samples

### Example usage of the scripts:#

**Performance mode**

Run PyTorch reference implementation to generate a NEFF for profiling:

```
python3 mamba_torch.py --mode perf
```

Check performance numbers of mamba_v1/mamba_v2/mamba_v3:

```
python3 mamba_nki_kernels.py --mode perf --version v1 v2 v3 --batch 1 --seq_len 2048 --channels 512 --state_size 16
```

**Accuracy mode**

Check mamba_v1 NKI kernel accuracy against PyTorch implementation:

```
python3 mamba_torch.py --mode accuracy
```

Check optimized Mamba kernel (mamba_v2, mamba_v3) accuracy against mamba_v1:

```
python3 mamba_nki_kernels.py --mode accuracy --version v1 v2 v3 --batch 1 --seq_len 2048 --channels 512 --state_size 16
```
