This document is relevant for: Inf2, Trn1, Trn1n

Fused Mamba#

In this tutorial, we implement an NKI kernel for the Mamba Large Language Model, a State Space Model (SSM) that replaces the attention layer of a regular Transformer model with a custom layer inspired by Recurrent Neural Networks. We walk through the core computation step by step and map it to NKI APIs to form a functional kernel. Then, by scaling the kernel's input shapes (both channel size and sequence length), we iterate toward a more hardware-efficient kernel implementation that scales better.

In this tutorial, we learn about:

  • Mapping different vector operations efficiently to NeuronCore compute engines, such as associative scan and element-wise operations between tensors

  • Leveraging data reuse and tiling to reduce excessive data movement and keep compute engines busy

  • Using neuron-profile to identify performance bottlenecks and opportunities

PyTorch Reference Implementation#

Before jumping to NKI, let’s examine the compute definition of a Mamba-v1 layer using the PyTorch script below (mamba_torch.py):

import torch
import torch_neuronx
import torch_xla.core.xla_model as xm
import os
import argparse

os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"
os.environ["NEURON_CC_FLAGS"] = " --model-type=transformer --disable-dge "


def associative_scan(deltaA, deltaB_u):
    """
    Args:
        deltaA: [batch_size, channels, state_size, seq_len]
        deltaB_u: [batch_size, channels, state_size, seq_len]

    Mamba uses an associative scan operator to aggregate information across
    time sequentially (sequence length, e.g. sequence of tokens),
    from the past to the present.
    """
    batch_size, channels, state_size, seq_len = deltaA.shape
    out = torch.empty(batch_size, channels, state_size, seq_len,
                      device=deltaA.device, dtype=deltaA.dtype)
    for i in range(seq_len):
        # The starting state is 0
        prev_state = out[..., i - 1] if i > 0 else 0
        out[..., i] = deltaA[..., i] * prev_state + deltaB_u[..., i]
    return out


def mamba_layer(delta, A, B, u, C):
    """
    Args:
        delta: [batch, channels, seq_len]
        u: [batch, channels, seq_len]
        A: [channels, state_size]
        B: [batch, state_size, seq_len]
        C: [batch, state_size, seq_len]
    """
    # expand the tensors so they all have the same dimensions and compute elementwise products (with broadcast)
    # deltaA and deltaB_u have shape [batch_size, channels, state_size, seq_len]
    deltaA = torch.exp(delta[:, :, None, :] * A[None, :, :, None])
    deltaB_u = delta[:, :, None, :] * B[:, None, :, :] * u[:, :, None, :]
    scan_res = associative_scan(deltaA, deltaB_u)
    # mamba_out sums over the `state_size` axis and has shape [batch_size, channels, seq_len]
    mamba_out = (C[:, None, :, :] * scan_res).sum(dim=-2)
    return mamba_out


def parse_args():
    # Pass the text as `description`; the first positional argument of ArgumentParser is `prog`
    parser = argparse.ArgumentParser(
        description="""Run Mamba PyTorch implementation. Hard-coded small example only since
        PyTorch implementation is very slow for larger configs.
        """)
    parser.add_argument("--mode",
                        choices=["accuracy", "perf"],
                        default="accuracy",
                        help="""Do accuracy test or perf test.
                                Accuracy test compares mamba_v1 kernel against PyTorch implementation.
                                Perf test will generate a NEFF for the PyTorch implementation in local directory
                                for a manual run of neuron-profile.
                             """)
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()

    # Toy example
    batch = 1
    seq_len = 512
    channels = 256
    state_size = 16

    dtype = torch.float32

    device = xm.xla_device()

    delta = torch.ones(batch, channels, seq_len, dtype=dtype, device=device)
    u = torch.ones(batch, channels, seq_len, dtype=dtype, device=device)

    # For numerical accuracy testing purposes, we choose negative numbers for A on purpose.
    # Otherwise, the associative scan will integrate too fast and overflow, which would
    # mask any real numerical issues in our computation.
    # A negative A will ensure we catch numerical issues when we have them.
    A = -torch.ones(channels, state_size, dtype=dtype, device=device)
    B = torch.ones(batch, state_size, seq_len, dtype=dtype, device=device)

    C = torch.ones(batch, state_size, seq_len, dtype=dtype, device=device)

    xm.mark_step()
    torch_out = mamba_layer(delta, A, B, u, C)
    xm.mark_step()
    print(torch_out)

The input tensor shapes are as follows:

  • delta: [batch, channels, seq_len]

  • u: [batch, channels, seq_len]

  • A: [channels, state_size]

  • B: [batch, state_size, seq_len]

  • C: [batch, state_size, seq_len]

The key model parameters are:

  • batch: batch size of the model.

  • seq_len: sequence length of the model.

  • channels: hidden size of a token.

  • state_size: number of model states.

We use [batch=1, seq_len=512, channels = 256, state_size = 16] as a simple test case for initial performance evaluation.
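
With the all-ones inputs used in the script (and A set to -1), the reference recurrence reduces to a geometric series, so the expected output can be checked by hand. The sketch below is plain Python, not part of the tutorial script; the helper names expected_toy_output and scan_by_recurrence are hypothetical.

```python
import math


def expected_toy_output(t, state_size=16):
    """Closed-form mamba_out[..., t] for delta = u = B = C = 1 and A = -1."""
    a = math.exp(-1.0)                          # deltaA = exp(delta * A)
    scan_t = (1.0 - a ** (t + 1)) / (1.0 - a)   # sum of a**k for k = 0..t
    return state_size * scan_t                  # every state contributes the same value


def scan_by_recurrence(seq_len):
    """Same scan, computed step by step as in associative_scan."""
    a, b = math.exp(-1.0), 1.0                  # deltaA and deltaB_u for the toy inputs
    out, prev = [], 0.0
    for _ in range(seq_len):
        prev = a * prev + b
        out.append(prev)
    return out


states = scan_by_recurrence(8)
for t, s in enumerate(states):
    print(t, 16 * s, expected_toy_output(t))
```

Under these assumptions the output starts at 16 for the first token and approaches 16 / (1 - e^-1), about 25.3, as the sequence gets longer, which makes a convenient sanity check when comparing a kernel against the reference.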

Running the above Python script will compile the PyTorch compute graph using Neuron Compiler and generate a Neuron executable file (NEFF) in the same directory. We can then profile the NEFF on a single NeuronCore using neuron-profile. The figure below is a screenshot of the profile. We see this initial PyTorch implementation takes 151.83 ms to execute on device.

../../../_images/mamba_torch_ref.png

Fig. 83 Profile of Mamba PyTorch Implementation#

Zooming into a portion of the profile, we notice the compute activities on different engines (TensorE/VectorE/ScalarE/GpSimdE) are quite sparse compared to data movement activities (the qSyncIO0 and qVectorSpillReload rows):

../../../_images/mamba_torch_ref_zoomed.png

Fig. 84 Profile of Mamba PyTorch Implementation (Zoomed-in)#

In this seemingly “memory-bound” execution trace, the achieved DMA throughput is also extremely low, hovering around 0.33% utilization throughout execution. Therefore, we are stressing neither the compute nor the memory subsystem, which hints that the workload is running at low efficiency on the NeuronCore. In the rest of this tutorial, we will showcase how to rewrite the above computation using NKI to achieve a device execution latency of 172.93 usec, which is an 878x speedup compared to the PyTorch reference implementation.
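
As a quick check of the quoted speedup, the two latencies reported in this tutorial can be divided directly. This is only arithmetic on the numbers above; the variable names are illustrative.

```python
torch_latency_us = 151.83 * 1000.0   # 151.83 ms reported for the PyTorch reference
nki_latency_us = 172.93              # 172.93 usec reported for the NKI kernel

speedup = torch_latency_us / nki_latency_us
print(f"speedup = {speedup:.1f}x")
```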

Mapping Mamba Layer to NeuronCore#

In this section, we will discuss how the computation can be mapped onto the NeuronCore architecture. We will also highlight the importance of choosing appropriate data layouts to achieve good compute efficiency.

Recall we have the following input tensor shapes in device memory:

  • delta: [batch_size, channels, seq_len]

  • u: [batch_size, channels, seq_len]

  • A: [channels, state_size]

  • B: [batch_size, state_size, seq_len]

  • C: [batch_size, state_size, seq_len]

In fact, the above tensor layout has been chosen carefully based on the computation done in NeuronCore, which we will discuss in more detail below.

In Mamba models, both seq_len and channels are typically in the thousands (such as seq_len=16K, channels=4K), while batch_size and state_size are smaller by 2-3 orders of magnitude (such as batch_size=4, state_size=16). To simplify visualization of computation on multi-dimensional tensors, let’s hold the batch and state_size dimensions constant and focus on the computation per batch per state. Note, the batch_size dimension is considered a fully parallel axis in a Mamba layer, while state_size is only a partially parallel axis because results from different states are accumulated together.

By fixing the batch and state_size indices, we get the following input tensor shapes in device memory:

  • delta_i: [channels, seq_len]

  • u_i:     [channels, seq_len]

  • A_i:     [channels]

  • B_i:     [seq_len]

  • C_i:     [seq_len]

Next, let’s visualize the data flow and computation using 2D matrices or vectors step-by-step.

Step 1: Element-wise multiplication of delta_i and A_i#

We have the following PyTorch reference code for Step 1:

# delta[batch, channels, seq_len]
# A    [channels, state_size]
delta[:, :, None, :] * A[None, :, :, None]

# Holding batch and state_size constant
# delta_i: [channels, seq_len]
# A_i:     [channels]
delta_i[:, :] * A_i[:]

After the above transformation, multiplying delta_i by A_i requires broadcasting A_i across the seq_len dimension of delta_i. In NKI, a free-dimension broadcast can often be folded into the compute instruction at no additional performance cost, while a partition-dimension broadcast often requires a separate instruction on TensorE (see TensorE alternative use case in Trainium/Inferentia2 Architecture Guide). As a result, we have two options for executing Step 1.

Option 1: Map seq_len to free dimension. Element-wise multiplication of delta_i and A_i on NeuronCore can be done through nisa.tensor_scalar on either VectorE or ScalarE, which automatically broadcasts A_i along the free dimension to match the seq_len dimension of delta_i.

Note, the channels dimension is mapped to the SBUF partition dimension. Since the input channels dimension has a size of 256 in our initial setup, which exceeds the architectural limit of nl.tile_size.pmax=128, we must tile delta_i in the channels dimension (tiled dimension denoted as channels_tiled) and feed one tile into nisa.tensor_scalar at a time. The figure below illustrates the computation done for Option 1.

../../../_images/mamba_step1_opt1.png

Fig. 85 Step 1, Option 1: nisa.tensor_scalar#

As an example, the associated NKI code for batch i_batch, state i_state and tile i_tile_channels in channels is:

# Input shape in device memory matches the computation layout
# Device memory layout:
# delta_i: [channels, seq_len]
# A_i:     [channels]

# Computation layout in SBUF:
# delta_i: [par_dim(channels), seq_len]
# A_i:     [par_dim(channels)]

deltaA_i = nisa.tensor_scalar(delta_i, op0=nl.multiply, operand0=A_i)

Note, with this compute layout option, the delta_i tensor shape [channels, seq_len] in device memory can be loaded into SBUF efficiently with seq_len as the free dimension and fed into VectorE/ScalarE for computation. No extra transposes are needed.
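
To make the Option 1 layout concrete, here is a plain-Python sketch in which nested lists stand in for SBUF tiles (rows are partitions, columns are the free dimension). The helpers channel_tiles and tensor_scalar_multiply are hypothetical stand-ins for the tiling and for the broadcast behaviour described above; they are not NKI APIs.

```python
PMAX = 128  # mirrors the nl.tile_size.pmax value quoted above


def channel_tiles(channels, pmax=PMAX):
    """Return (start, stop) row ranges, one per partition-dimension tile."""
    assert channels % pmax == 0, "this sketch assumes an exact multiple of pmax"
    return [(start, start + pmax) for start in range(0, channels, pmax)]


def tensor_scalar_multiply(tile, per_partition_scalar):
    """Multiply each row by its own scalar, broadcast along the free dimension."""
    return [[x * s for x in row] for row, s in zip(tile, per_partition_scalar)]


channels, seq_len = 256, 4
delta_i = [[float(c + t) for t in range(seq_len)] for c in range(channels)]
A_i = [-(c + 1.0) for c in range(channels)]

deltaA_i = []
for start, stop in channel_tiles(channels):
    # One tile of at most PMAX channels is processed per call
    deltaA_i.extend(tensor_scalar_multiply(delta_i[start:stop], A_i[start:stop]))

print(channel_tiles(channels))
```

With channels=256 the loop runs twice, once per 128-channel tile, and each row of delta_i is scaled by the single A_i value of its channel.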

Option 2: Map seq_len to partition dimension. Alternatively, if we choose a transposed layout for delta_i in SBUF for computation, we will need a partition-dimension broadcast of A_i using a separate instruction on TensorE (A_i.broadcast_to(...)) and then a nisa.tensor_tensor operation between delta_i and the broadcast A_i on VectorE. As a reminder, we need to tile the seq_len dimension to meet the tile size constraint nl.tile_size.pmax=128. The figure below illustrates the computation done for Option 2.

../../../_images/mamba_step1_opt2.png

Fig. 86 Step 1, Option 2: p-dim broadcast + nisa.tensor_tensor#

The associated NKI code is as follows:

# Input shape in device memory does NOT match the computation layout
# Device memory layout:
# delta_i: [channels, seq_len]
# A_i:     [channels]

# Computation layout in SBUF:
# delta_i: [par_dim(seq_len_tiled), channels]
# A_i:     [par_dim(1), channels]

A_i_bcast = A_i.broadcast_to((nl.tile_size.pmax, channels))
deltaA_i = nisa.tensor_tensor(delta_i, A_i_bcast, op=nl.multiply)

Assuming the same delta_i device memory layout [channels, seq_len], before performing the nisa.tensor_tensor instruction, we will need to either:

  • Do a regular load of delta_i into SBUF using nl.load and an explicit transpose on the loaded delta_i using nl.transpose to make seq_len lie in the partition dimension, or

  • Do a transposed load of delta_i using nl.load_transpose2d, which is significantly less efficient in memory bandwidth usage compared to nl.load

If Option 2 were chosen as the compute layout, we would have an incentive to define the delta input tensor shape as [seq_len, channels] in device memory instead.

From a computation perspective, Option 2 is less efficient than Option 1 because:

  1. Option 2 needs an extra TensorE instruction performing partition dimension broadcast.

  2. nisa.tensor_tensor is 2x slower than nisa.tensor_scalar for our input data type FP32 (see API doc for instruction cost estimates).

Therefore, for Step 1 only, Option 1 is the winner compared to Option 2. Let’s continue with the rest of the steps to see if we need to revise this selection due to surrounding operator layout preferences.

Step 2: Exponential of deltaA_i.#

Step 2 is evaluating exponential on deltaA_i from the previous step:

torch.exp(...)

In NeuronCore, evaluating an exponential function on a tensor is considered a scalar operation, which runs on ScalarE. This operation can be invoked through nl.exp or nisa.activation. However, ScalarE is able to perform a “pipelined multiply-add” on the input before evaluating a non-linear function (see Trainium/Inferentia2 Architecture Guide for details). In other words, we can fold Step 1 (Option 1) nisa.tensor_scalar and Step 2 into a single ScalarE instruction at no additional cost. This functionality is only exposed in the nisa.activation API. This folding is not feasible if we choose Option 2 nisa.tensor_tensor in Step 1. The figure below illustrates our new execution plan to combine Steps 1 and 2 into nisa.activation:

../../../_images/mamba_step2.png

Fig. 87 Step 1&2: nisa.activation#

The associated NKI code is as follows:

# Input shape in device memory matches the computation layout
deltaA_i = nisa.activation(op=nl.exp, data=delta_i, scale=A_i)
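
The folding of Steps 1 and 2 can be checked numerically outside NKI. The sketch below is plain Python, with lists standing in for tiles; activation_exp is a hypothetical helper that mimics the behaviour described above (a per-row scale applied before exp), not the NKI API itself.

```python
import math


def activation_exp(data, scale):
    """exp(data * scale) with one scale value per row, evaluated in a single pass."""
    return [[math.exp(x * s) for x in row] for row, s in zip(data, scale)]


delta_i = [[0.5, 1.0, 2.0], [0.1, 0.2, 0.3]]   # [channels, seq_len]
A_i = [-1.0, -2.0]                              # [channels]

# Two-step version: Step 1 (multiply) followed by Step 2 (exp)
step1 = [[x * s for x in row] for row, s in zip(delta_i, A_i)]
two_step = [[math.exp(x) for x in row] for row in step1]

# Fused version: multiply and exp in one pass
fused = activation_exp(delta_i, A_i)

print(fused == two_step)
```

Both paths produce the same values; the benefit claimed in the text is the saved instruction, not a different result.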

Step 3: Element-wise multiplication of delta_i, B_i and u_i.#

PyTorch reference code for Step 3 is:

# delta[batch, channels, seq_len]
# B:   [batch, state_size, seq_len]
# u:   [batch, channels, seq_len]
delta[:, :, None, :] * B[:, None, :, :] * u[:, :, None, :]

# Holding batch and state_size constant
# delta_i: [channels, seq_len]
# B_i:     [seq_len]
# u_i:     [channels, seq_len]
delta_i[:, :] * B_i[None, :] * u_i[:, :]

This step involves similar compute layout and instruction choices as Step 1:

  • channels is either partition or free dimension for both delta_i and u_i

  • multiplication with B_i is either through nisa.tensor_tensor or nisa.tensor_scalar

Since Step 1 consumes delta_i with channels as the partition dimension, it is wise to follow the same layout choice here for delta_i to avoid any transposes. Given this layout choice, the multiplication with B_i has to be a nisa.tensor_tensor. The figure below visualizes the computation in Step 3:

../../../_images/mamba_step3.png

Fig. 88 Step 3: p-dim broadcast + 2x nisa.tensor_tensor#

The associated NKI code is as follows:

# Input shape in device memory does NOT match the computation layout
# Device memory layout:
# delta_i: [channels, seq_len]
# u_i:     [channels, seq_len]
# B_i:     [seq_len]

# Computation layout in SBUF:
# delta_i: [par_dim(channels_tiled), seq_len]
# u_i:     [par_dim(channels_tiled), seq_len]
# B_i:     [par_dim(1), seq_len]

deltaU_i = nisa.tensor_tensor(delta_i, u_i, op=nl.multiply)
B_i_bcast = B_i.broadcast_to((nl.tile_size.pmax, seq_len))
deltaBu_i = nisa.tensor_tensor(deltaU_i, B_i_bcast, op=nl.multiply)
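
The difference from Step 1 is that B_i varies along seq_len and is identical for every channel, so it has to be replicated across rows rather than along them. A minimal plain-Python sketch of that data flow follows; broadcast_rows and tensor_tensor_multiply are hypothetical helpers, not NKI calls.

```python
def broadcast_rows(row, n_rows):
    """Replicate one [1, seq_len] row across n_rows partitions."""
    return [list(row) for _ in range(n_rows)]


def tensor_tensor_multiply(x, y):
    """Element-wise product of two 2D lists with identical shapes."""
    assert len(x) == len(y) and all(len(a) == len(b) for a, b in zip(x, y))
    return [[a * b for a, b in zip(rx, ry)] for rx, ry in zip(x, y)]


delta_i = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]   # [channels_tiled, seq_len]
u_i = [[0.5, 0.5, 0.5], [2.0, 2.0, 2.0]]       # [channels_tiled, seq_len]
B_i = [10.0, 20.0, 30.0]                        # [seq_len]

deltaU_i = tensor_tensor_multiply(delta_i, u_i)
B_i_bcast = broadcast_rows(B_i, len(delta_i))
deltaBu_i = tensor_tensor_multiply(deltaU_i, B_i_bcast)

print(deltaBu_i)
```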

Step 4: Associative scan between deltaA_i and deltaBu_i#

In this step, we use an associative scan operator between deltaA and deltaBu to aggregate information across time sequentially (sequence length, e.g. sequence of tokens), from the past to the present. Here is a PyTorch reference implementation:

# deltaA:   [batch_size, channels, state_size, seq_len]
# deltaB_u: [batch_size, channels, state_size, seq_len]
out = torch.empty(batch_size, channels, state_size, seq_len,
                  device=deltaA.device, dtype=deltaA.dtype)

for i in range(seq_len):
    # starting state is 0
    prev_state = out[..., i - 1] if i > 0 else 0
    # multiply deltaA by the previous time step state and then add deltaB_u
    out[..., i] = deltaA[..., i] * prev_state + deltaB_u[..., i]

By holding batch and state_size dimensions constant, we get deltaA_i and deltaBu_i both with [channels_tiled, seq_len], where channels_tiled is the partition dimension. The associative scan between these two tile shapes can be implemented in NKI naively through the following loop:

scan_i = nl.ndarray((channels_tiled, seq_len), ...)

# Peel the first iteration out of the loop; this is equivalent to
# control flow inside the loop that depends on the loop iterator
scan_i[0:channels_tiled, 0] = deltaBu_i[0:channels_tiled, 0]

for i in nl.sequential_range(seq_len - 1):
    scan_i[0:channels_tiled, i+1] = (deltaA_i[0:channels_tiled, i+1] * scan_i[0:channels_tiled, i]
                                     + deltaBu_i[0:channels_tiled, i+1])

Within the loop, the current implementation invokes one instruction for multiplication and another for addition. Since both instructions are performed among tiles of shape [channels_tiled, 1], we can combine these two instructions using nisa.tensor_scalar which supports two operators in a pipelined fashion within an instruction at the same cost as a single operator. Below is a new implementation that could provide 2x speedup compared to the above:

scan_i = nl.ndarray((channels_tiled, seq_len), dtype=deltaA_i.dtype, buffer=nl.sbuf)
scan_i[0:channels_tiled, 0] = deltaBu_i[0:channels_tiled, 0]

for i in nl.sequential_range(seq_len - 1):
    # Multiply by the previous state, then add deltaBu_i, in one instruction
    scan_i[0:channels_tiled, i+1] = nisa.tensor_scalar(
        deltaA_i[0:channels_tiled, i+1],
        op0=nl.multiply,
        operand0=scan_i[0:channels_tiled, i],
        op1=nl.add,
        operand1=deltaBu_i[0:channels_tiled, i+1])

However, the above loop nest will turn into seq_len instructions, each with input tiles that have a single element per partition in SBUF. In addition, every nisa.tensor_scalar instruction has a data dependency on the output of the previous instruction. As discussed in the Trainium/Inferentia2 Architecture Guide, these two traits combined in an instruction sequence are considered extremely inefficient on ScalarE/VectorE, because the static instruction overhead, rather than the useful execution time, dominates the engine timeline.
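
A rough instruction count shows the scale of the problem. The sketch below assumes, for illustration only, that each NKI statement in the two loops above lowers to exactly one engine instruction; vector_instruction_counts is a hypothetical helper.

```python
def vector_instruction_counts(seq_len):
    """Instructions issued per [channels_tiled, seq_len] tile for each loop variant."""
    first_column_copy = 1  # the peeled first iteration
    return {
        # one multiply and one add per remaining time step
        "naive_loop": first_column_copy + 2 * (seq_len - 1),
        # multiply and add fused into one nisa.tensor_scalar per time step
        "fused_loop": first_column_copy + (seq_len - 1),
    }


counts = vector_instruction_counts(seq_len=512)
print(counts)
```

Under that assumption, the fused loop still issues 512 dependent instructions for every channel tile and every state, each touching one element per partition.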

Conveniently, NKI exposes another instruction nisa.tensor_tensor_scan on VectorE, which can perform the above loop nest in a single instruction by caching the intermediate scan result from the previous time step internally in VectorE without going through SBUF.

scan_i = nisa.tensor_tensor_scan(deltaA_i, deltaBu_i, initial=0,
                                 op0=np.multiply, op1=np.add)

Note, the shape of scan_i is exactly the same as the input deltaA_i/deltaBu_i: [channels_tiled, seq_len].
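
For reference, the recurrence that the scan has to reproduce can be written row by row in plain Python. This mirrors the PyTorch loop above on one [channels_tiled, seq_len] tile; scan_rows is a hypothetical helper, and whether the hardware instruction matches it bit for bit is not verified here.

```python
def scan_rows(deltaA_i, deltaBu_i, initial=0.0):
    """Row-wise scan: out[t] = deltaA[t] * out[t-1] + deltaBu[t], with out[-1] = initial."""
    out = []
    for a_row, b_row in zip(deltaA_i, deltaBu_i):
        prev, row = initial, []
        for a, b in zip(a_row, b_row):
            prev = a * prev + b
            row.append(prev)
        out.append(row)
    return out


deltaA_i = [[0.5, 0.5, 0.5, 0.5], [2.0, 1.0, 0.0, 3.0]]
deltaBu_i = [[1.0, 1.0, 1.0, 1.0], [1.0, 2.0, 3.0, 4.0]]

scan_i = scan_rows(deltaA_i, deltaBu_i)
print(scan_i)
```

Each row is scanned independently, which is why the channels dimension can sit in the partition dimension while the sequential dependency runs along the free dimension.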

Step 5: Element-wise multiplication of C_i and scan_i#

The PyTorch reference implementation is:

# scan_res: [batch_size, channels, state_size, seq_len]
# C:        [batch_size, state_size, seq_len]
scanC = C[:, None, :, :] * scan_res

# Holding batch and state constant
# scan_i: [channels_tiled, seq_len]
# C_i:    [seq_len]
scanC_i = C_i[None, :] * scan_i[:, :]

You know the drill: since channels_tiled is the partition dimension in scan_i from the previous step, we need to perform a partition-dimension broadcast on C_i before invoking nisa.tensor_tensor:

../../../_images/mamba_step5.png

Fig. 89 Step 5: p-dim broadcast + nisa.tensor_tensor#

The corresponding NKI code is:

C_i_bcast = C_i.broadcast_to((nl.tile_size.pmax, seq_len))
scanC_i = nisa.tensor_tensor(scan_i, C_i_bcast, op=nl.multiply)

Step 6: Accumulation of scanC_i along state_size dimension#

So far in Steps 1-5, all the computation is logically parallel across the state_size dimension in a Mamba layer. The next step introduces a data dependency along the state_size dimension for the first time. The PyTorch reference implementation is:

# scan_res: [batch_size, channels, state_size, seq_len]
# C:        [batch_size, state_size, seq_len]
# -2 dim is state_size
scanC.sum(dim=-2)

# Holding batch constant only.
# scanC_i_states: [channels_tiled, state_size, seq_len]
scanC_i_states.sum(dim=-2)

In NKI, we can accumulate the scanC_i results across states element-wise using state_size-1 number of nisa.tensor_tensor instructions:

../../../_images/mamba_step6.png

Fig. 90 Step 6: state_size-1 number of nisa.tensor_tensor#

Since we will be looping over different states, we can also declare a zero-initialized accumulation buffer scanC_accum of shape [channels_tiled, seq_len] outside of the loop structure and accumulate into this buffer at the end of every loop iteration using the += operator. The use of a single accumulation buffer avoids allocating memory for scanC_i across all states in SBUF. The corresponding NKI code is:

scanC_accum = nl.zeros(...)

for i_state in nl.affine_range(state_size):
    scanC_i = ...
    scanC_accum += scanC_i
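
The running accumulation gives the same result as materializing every state and summing afterwards, while only ever holding one [channels_tiled, seq_len] buffer. A small plain-Python sketch with made-up values illustrates this; add_into and the scanC_states data are hypothetical.

```python
def add_into(accum, tile):
    """Element-wise accum += tile for 2D lists of equal shape."""
    for acc_row, row in zip(accum, tile):
        for j, value in enumerate(row):
            acc_row[j] += value


state_size, channels_tiled, seq_len = 4, 2, 3

# Made-up per-state results, indexed as scanC_states[state][channel][t]
scanC_states = [[[float(s + c * t) for t in range(seq_len)]
                 for c in range(channels_tiled)]
                for s in range(state_size)]

# Running accumulation: one buffer, updated once per state
scanC_accum = [[0.0] * seq_len for _ in range(channels_tiled)]
for i_state in range(state_size):
    add_into(scanC_accum, scanC_states[i_state])

# Reference: sum over the state axis after materializing every state
reference = [[sum(scanC_states[s][c][t] for s in range(state_size))
              for t in range(seq_len)]
             for c in range(channels_tiled)]

print(scanC_accum == reference)
```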

Initial NKI Kernel#

Putting all the pieces together from the previous section, we can arrive at the below kernel implementation mamba_v1:

import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as nisa
import numpy as np

def mamba_v1(delta, u, A, B, C, output):
    """Computes the SSM operation in the Mamba model.

    :param delta: (batch_size, channels, seq_len)
    :param u: (batch_size, channels, seq_len)
    :param A: (channels, state_size)
    :param B: (batch_size, state_size, seq_len)
    :param C: (batch_size, state_size, seq_len)
    :param output: (batch_size, channels, seq_len), written by the kernel
    """
    batch_size, channels, seq_len = delta.shape
    _, state_size = A.shape

    # We can relax this using mask parameters in all the NKI API calls
    assert channels % 128 == 0

    # Map channels to the partition dimension
    # Tile channels to comply with NKI tile size constraints
    channel_psize = nl.tile_size.pmax
    n_channel_tile = channels // channel_psize

    # Outermost loop over batch_size, parallel_for
    for i_batch in nl.affine_range(batch_size):
        # partial accumulated scanC result with processed states
        scanC_accum = nl.zeros((n_channel_tile, nl.par_dim(channel_psize), seq_len), dtype=delta.dtype)

        # Second outer loop over state_size, partially parallel
        for i_state in nl.affine_range(state_size):

            # Inner loop: tiling channels
            for i_channel_tile in nl.affine_range(n_channel_tile):
                channel_start = i_channel_tile * channel_psize

                # Load the relevant tile from delta and A
                delta_i = nl.load(delta[i_batch, channel_start:channel_start+channel_psize, 0:seq_len])
                A_i = nl.load(A[channel_start:channel_start+channel_psize, i_state])

                # Step 1&2: Element-wise multiplication of delta_i and A_i and then exponential
                deltaA = nisa.activation(op=nl.exp, data=delta_i, scale=A_i)

                # Load the relevant tile from u and B
                u_i = nl.load(u[i_batch, channel_start:channel_start+channel_psize, 0:seq_len])
                B_i = nl.load(B[i_batch, i_state:i_state+1, 0:seq_len])

                # Step 3: Element-wise multiplication of delta_i, B_i and u_i
                deltaU = nisa.tensor_tensor(delta_i, u_i, op=nl.multiply)
                B_i_bcast = B_i.broadcast_to((channel_psize, seq_len))
                deltaBu = nisa.tensor_tensor(deltaU, B_i_bcast, op=nl.multiply)

                # Step 4: Associative scan between deltaA and deltaBu
                scan_res = nisa.tensor_tensor_scan(deltaA, deltaBu, initial=0,
                        op0=np.multiply, op1=np.add)

                # Load the relevant tile from C
                C_i = nl.load(C[i_batch, i_state:i_state+1, 0:seq_len])

                # Step 5: Element-wise multiplication of scan_res and C_i
                C_i_bcast = C_i.broadcast_to((channel_psize, seq_len))
                scanC = nisa.tensor_tensor(scan_res, C_i_bcast, op=nl.multiply)

                # Step 6: Accumulation of scanC along state_size dimension
                # scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len] = nisa.tensor_tensor(
                #         scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len], scanC, op=nl.add)
                scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len] += scanC

        # Store scanC_accum for a single batch to output
        for i_channel_tile in nl.affine_range(n_channel_tile):
            channel_start = i_channel_tile * channel_psize
            nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len],
                    scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len])

In the above code example,

  • We have three levels of loop nests. From the outer-most to inner-most:
    • Iterating over batch: Different batch samples perform independent computation on different data. The A tensor is the only input parameter that is shared among batch samples.

    • Iterating over state_size: Different states perform parallel computation until Step 6 as discussed in the previous section. Both delta and u tensors are shared across different states.

    • Iterating over channels: This is the innermost dimension, where we tile the input channels dimension into chunks of nl.tile_size.pmax=128. Both B and C tensors are shared across different channels.

  • The kernel above assumes channels is a multiple of nl.tile_size.pmax=128. We can relax this by adding a mask parameter to all the NKI API calls in the kernel. To simplify the code example, we omit this change. See NKI API Masking for more information.

  • We declare an empty intermediate tensor scanC_accum to hold partial summation from every state.

  • Within the inner loop, we process data for nl.tile_size.pmax=128 channels for one batch sample in one state.
    • We use the slicing syntax to index a tensor. For example, delta[i_batch, channel_start:channel_start+channel_psize, 0:seq_len] grabs data from the input delta tensor for the current range of channels at the current batch sample.

    • Note, in tensor slicing, the first index dimension from the left with a slicing range will be chosen as the partition dimension. When loading B, since we intend to load only one state’s worth of data into one partition of SBUF (discussed in Step 3), we need to explicitly slice the state using: nl.load(B[i_batch, i_state:i_state+1, 0:seq_len]). Otherwise, nl.load(B[i_batch, i_state, 0:seq_len]) will treat seq_len as the partition dimension, which is not what we planned for in Step 3 and would also trigger an NKI compilation error since seq_len exceeds nl.tile_size.pmax.

    • We accumulate partial scanC_i results into the accumulation buffer using the += operator. This creates a loop-carried dependency for scanC_accum on the i_state loop.

Performance Check#

Let’s re-run neuron-profile on the above NKI kernel:

../../../_images/mamba_v1_profile.png

Fig. 91 Profile of initial Mamba kernel implementation mamba_v1#

Hooray! This NKI kernel implementation now takes 172.93 usec, which is an 878x speedup compared to the reference PyTorch implementation. Based on the profile, VectorE is the busiest compute engine in the Mamba layer. This makes sense because the bulk of computation in the kernel is in nisa.tensor_tensor, which can only run on VectorE.

Therefore, our goal is to keep VectorE as busy as possible throughout execution. Note, every NEFF execution involves certain start-up and tear-down overhead. We can use the Selection Summary feature in neuron-profile to find out the percentage of time VectorE is busy during the actual execution period:

../../../_images/mamba_v1_profile_zoomed.png

Fig. 92 Profile of initial Mamba kernel implementation mamba_v1 (zoomed in)#

As indicated by the above profile, VectorE is active over 98.71% of the time, which is rather impressive. However, remember we used small input shapes as a toy example to get started: [batch=1, seq_len=512, channels = 256, state_size = 16]. Next, let’s increase the channels and seq_len dimensions one by one and observe how VectorE efficiency changes.

Increasing input channels size#

Let’s increase the size of channels by 16x, from 256 to a more realistic value 4096. We obtain the following profile:

../../../_images/mamba_v1_profile_4k_chan.png

Fig. 93 Profile of mamba_v1 kernel with 4K channels#

The new device execution time with increased channels is now 2.34 ms. We can see that VectorE active duration has dropped to 92.16% during the core execution period, compared to 98.71% previously with the toy example. Let’s zoom into an arbitrary region of the profile to see what could be causing VectorE to go idle:

../../../_images/mamba_v1_profile_4k_chan_sem.png

Fig. 94 mamba_v1 kernel blocking on input tensor loading#

By identifying a gap where VectorE is completely idle, we can hover over the first instruction executed after the gap and read its semaphore wait condition to find the reason for the idleness. In the above screenshot, the instruction is pending on S[22] to reach a value of 240, which is set by qSyncIO0 activities. This means VectorE has been waiting for input tensors to be loaded before performing more computation. If you hover over qSyncIO0 activities during the VectorE idle period, you can also see the exact input tensor name defined in NKI being loaded in the DMA:

../../../_images/mamba_v1_profile_4k_chan_load_var.png

Fig. 95 DMA loading tensor u in mamba_v1 profile#

We can find similar VectorE gaps throughout the execution trace. At this point, we can conclude that one reason for the lower VectorE active time percentage is blocking input tensor loading (nl.load) activities in the DMA. Next, let’s spend some time analyzing DMA efficiency.

Zooming out, we can make several observations. First, we see two orange boxes around the qSyncIO0 row. Hovering over the top left corners of the boxes shows two similar performance warnings for loading IO tensors:

../../../_images/mamba_v1_profile_4k_chan_reload.png

Fig. 96 Performance warnings for reloading u and delta tensors#

This indicates we reload both the input u and delta tensors around 7 times. This could be inevitable when we don’t have sufficient on-chip memory (SBUF) to allow full reuse of the input data tensors. However, the profiler shows we are only hitting around 50% capacity usage throughout execution:

../../../_images/mamba_v1_profile_4k_chan_sb.png

Fig. 97 Low SBUF usage#

Therefore, the input tensor reloading is likely not justified, and we should investigate whether we can optimize the NKI kernel to avoid it.

Minimizing data reloading by loop reordering#

To understand why delta and u are being reloaded, let’s revisit our input tensor shapes:

  • delta: [batch_size, channels, seq_len]

  • u:     [batch_size, channels, seq_len]

  • A:     [channels, state_size]

  • B:     [batch_size, state_size, seq_len]

  • C:     [batch_size, state_size, seq_len]

Let’s hold batch_size constant since the majority of input tensors have completely different slices for different batch samples:

  • delta: [channels, seq_len]

  • u:     [channels, seq_len]

  • A:     [channels, state_size]

  • B:     [state_size, seq_len]

  • C:     [state_size, seq_len]

delta and u have the same shape with channels as the outer dimension, while B and C have the same shape with state_size as the outer dimension. All four of these input tensors have seq_len as the inner dimension. Therefore, delta/u are reused across different states, while B/C are reused across different channels. Given these conflicting reuse dimensions, it is more important to prioritize reuse of delta/u because the expected size of channels is much larger than state_size:

  • state_size is now 16 and typically stays small

  • channels is now 4096 and typically in the thousands

In NKI, we can prioritize delta/u reuse through loop ordering. Recall in the initial NKI kernel implementation, we have the following inner loops:

...
for i_state in nl.affine_range(state_size):
    for i_channel_tile in nl.affine_range(n_channel_tile):
        # step 1-6
...

Since these two loops are executed serially within a single NeuronCore, the loop instances will be unrolled by the Neuron Compiler. With channels as the fastest-varying (innermost) loop, we need to load delta/u across all channels for the first state, and then likely reload them for later states because of the large total memory size of delta and u (16 MB in this case).

To prioritize reuse of delta/u, we should reorder the above loop nests. To further enforce the reuse, we can hoist the nl.load calls for delta/u outside of the i_state inner loop:

...
for i_channel_tile in nl.affine_range(n_channel_tile):
    delta_i = nl.load(...)
    u_i = nl.load(...)

    for i_state in nl.affine_range(state_size):
        # step 1-6
...
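To make the effect of the hoisting concrete, the sketch below estimates delta/u DMA traffic for one batch sample under both load placements. It is a rough model, not a measurement: it assumes every nl.load in the loop body moves its tile from device memory, and the helper name delta_u_load_bytes is hypothetical. The un-hoisted figure is therefore an upper bound; the profile above reports around 7 reloads, which is below it.

```python
# Rough model of delta/u DMA traffic for one batch sample.
# Assumption: each nl.load in the loop body moves its tile from device memory.

def delta_u_load_bytes(channels, seq_len, state_size, hoisted,
                       channel_psize=128, dtype_bytes=4):
    """Bytes moved for delta and u together (hypothetical helper)."""
    n_channel_tile = channels // channel_psize
    tile_bytes = channel_psize * seq_len * dtype_bytes
    if hoisted:
        # Loads sit outside the i_state loop: one load per channel tile.
        loads_per_tensor = n_channel_tile
    else:
        # Loads sit inside both loops: one load per (state, channel tile).
        loads_per_tensor = state_size * n_channel_tile
    return 2 * loads_per_tensor * tile_bytes  # factor 2: delta and u


MiB = 1024 * 1024
worst_case = delta_u_load_bytes(4096, 512, 16, hoisted=False)
hoisted = delta_u_load_bytes(4096, 512, 16, hoisted=True)
print(f"loads inside i_state loop: up to {worst_case // MiB} MiB")
print(f"loads hoisted:             {hoisted // MiB} MiB")
```

With the shapes used in this tutorial (channels=4096, seq_len=512, state_size=16, FP32), the hoisted variant moves the 16 MiB of delta/u exactly once.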

As a side effect of this loop re-ordering, we can also spot a loop fusion opportunity since we have two i_channel_tile loop nests at the same level now:

scanC_accum = nl.zeros((n_channel_tile, nl.par_dim(channel_psize), seq_len), ...)
...

# First i_channel_tile loop
for i_channel_tile in nl.affine_range(n_channel_tile):
    delta_i = nl.load(...)
    u_i = nl.load(...)

    for i_state in nl.affine_range(state_size):
        # step 1-6

# Second i_channel_tile loop
for i_channel_tile in nl.affine_range(n_channel_tile):
    nl.store(..., scanC_accum[i_channel_tile, 0:channel_psize, 0:seq_len])

...

By fusing the two i_channel_tile loop nests into a single loop nest, we can pull the declaration of scanC_accum inside the i_channel_tile loop and further reduce the scanC_accum size requirement by a factor of n_channel_tile:

...

# Fused i_channel_tile loop
for i_channel_tile in nl.affine_range(n_channel_tile):
    scanC_accum = nl.zeros((nl.par_dim(channel_psize), seq_len), ...)

    delta_i = nl.load(...)
    u_i = nl.load(...)

    for i_state in nl.affine_range(state_size):
        # step 1-6

    nl.store(..., scanC_accum[0:channel_psize, 0:seq_len])

...
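The size reduction from fusing the loops can be estimated with a few lines of arithmetic. This is a sketch under one assumption: the leading n_channel_tile dimension of the original scanC_accum is laid out along the free dimension of each partition, so it multiplies the per-partition footprint. The helper name is hypothetical.

```python
# Per-partition footprint of scanC_accum before and after loop fusion.
# Assumption: dimensions ahead of the partition dimension are stored along
# the free dimension, so they multiply the bytes needed in each partition.

def scanC_accum_bytes_per_partition(seq_len, n_channel_tile=1, dtype_bytes=4):
    """Hypothetical helper: bytes of scanC_accum held in one SBUF partition."""
    return n_channel_tile * seq_len * dtype_bytes


KiB = 1024
n_channel_tile = 4096 // 128  # channels // channel_psize
before = scanC_accum_bytes_per_partition(512, n_channel_tile)
after = scanC_accum_bytes_per_partition(512)
print(f"before fusion: {before // KiB} KiB/partition")
print(f"after fusion:  {after // KiB} KiB/partition")
```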

Let’s modify our initial NKI kernel implementation accordingly to get mamba_v2:

def mamba_v2(delta, u, A, B, C, output):
    """Computes the SSM operation in the Mamba model.

    :param delta: (batch_size, channels, seq_len)
    :param u: (batch_size, channels, seq_len)
    :param A: (channels, state_size)
    :param B: (batch_size, state_size, seq_len)
    :param C: (batch_size, state_size, seq_len)
    :return: (batch_size, channels, seq_len)
    """
    batch_size, channels, seq_len = delta.shape
    _, state_size = A.shape

    assert channels % 128 == 0

    # Map channels to the partition dimension
    # Tile channels to comply with NKI tile size constraints
    channel_psize = nl.tile_size.pmax
    n_channel_tile = channels // channel_psize

    # Most outer loop with batch_size, parallel_for
    for i_batch in nl.affine_range(batch_size):

        # Second outer loop: tiling channels
        for i_channel_tile in nl.affine_range(n_channel_tile):
            channel_start = i_channel_tile * channel_psize

            # partial accumulated scanC result with processed states
            scanC_accum = nl.zeros((nl.par_dim(channel_psize), seq_len), dtype=delta.dtype)

            # Load delta/u once to be reused across states
            delta_i = nl.load(delta[i_batch, channel_start:channel_start+channel_psize, 0:seq_len])
            u_i = nl.load(u[i_batch, channel_start:channel_start+channel_psize, 0:seq_len])

            # Inner loop with state_size, partial parallel
            for i_state in nl.affine_range(state_size):
                # Load the relevant tile from A
                A_i = nl.load(A[channel_start:channel_start+channel_psize, i_state])

                # Step 1&2: Element-wise multiplication of delta_i and A_i and then exponential
                deltaA = nisa.activation(op=nl.exp, data=delta_i, scale=A_i)

                # Load the relevant tile from B
                B_i = nl.load(B[i_batch, i_state:i_state+1, 0:seq_len])

                # Step 3: Element-wise multiplication of delta_i, B_i and u_i
                deltaU = nisa.tensor_tensor(delta_i, u_i, op=nl.multiply)
                B_i_bcast = B_i.broadcast_to((channel_psize, seq_len))
                deltaBu = nisa.tensor_tensor(deltaU, B_i_bcast, op=nl.multiply)

                # Step 4: Associative scan between deltaA and deltaBu
                scan_res = nki.isa.tensor_tensor_scan(deltaA, deltaBu, initial=0,
                        op0=np.multiply, op1=np.add)

                # Load the relevant tile from C
                C_i = nl.load(C[i_batch, i_state:i_state+1, 0:seq_len])

                # Step 5: Element-wise multiplication of scan_res and C_i
                C_i_bcast = C_i.broadcast_to((channel_psize, seq_len))
                scanC = nisa.tensor_tensor(scan_res, C_i_bcast, op=nl.multiply)

                # Step 6: Accumulation of scanC along state_size dimension
                scanC_accum[0:channel_psize, 0:seq_len] += scanC

            # Store scanC_accum for a single batch to output
            nl.store(output[i_batch, channel_start:channel_start+channel_psize, 0:seq_len],
                    scanC_accum[0:channel_psize, 0:seq_len])
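As a sanity check that swapping the i_state and i_channel_tile loops does not change the result, the following plain-Python model runs steps 1-6 for one batch sample in both loop orders. It is only a sketch of the math, not NKI code: ssm_reference and its channel_outer flag are hypothetical names, tensors are nested lists, and channels are not tiled.

```python
import math
import random


def ssm_reference(delta, u, A, B, C, channel_outer=True):
    """Steps 1-6 for one batch sample on nested Python lists (hypothetical helper).

    delta, u: [channels][seq_len]; A: [channels][state_size];
    B, C: [state_size][seq_len]. Returns [channels][seq_len].
    """
    channels, seq_len = len(delta), len(delta[0])
    state_size = len(A[0])
    out = [[0.0] * seq_len for _ in range(channels)]

    def body(c, s):
        state = 0.0  # matches initial=0 in the scan
        for t in range(seq_len):
            deltaA = math.exp(delta[c][t] * A[c][s])    # steps 1-2
            deltaBu = delta[c][t] * u[c][t] * B[s][t]   # step 3
            state = deltaA * state + deltaBu            # step 4: scan
            out[c][t] += state * C[s][t]                # steps 5-6

    if channel_outer:  # loop order used by mamba_v2
        for c in range(channels):
            for s in range(state_size):
                body(c, s)
    else:              # loop order of the initial kernel
        for s in range(state_size):
            for c in range(channels):
                body(c, s)
    return out


# Small random problem: 4 channels, 3 states, 8 time steps.
random.seed(0)
rand = lambda rows, cols: [[random.uniform(-1, 1) for _ in range(cols)]
                           for _ in range(rows)]
delta, u, A = rand(4, 8), rand(4, 8), rand(4, 3)
B, C = rand(3, 8), rand(3, 8)
same = (ssm_reference(delta, u, A, B, C, channel_outer=True)
        == ssm_reference(delta, u, A, B, C, channel_outer=False))
print("loop orders agree:", same)
```

Each output element accumulates its states in the same order under both loop nests, so in this model the reordering only changes when data is touched, not what is computed.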

We recapture the profile for the new kernel implementation:

../../../_images/mamba_v2.png

Fig. 98 Profile of mamba_v2 kernel with loop reordering optimization#

The device execution time is now 1.61 ms, a 31% reduction in latency compared to our initial kernel implementation. VectorE active duration is back up to 99.63%, and the performance warnings on input tensor reloading are gone. In case you are curious, the loop reordering optimization alone provides around a 30% latency reduction, while the loop fusion optimization contributes the remaining 1%. This makes sense because loop reordering addresses our key performance concern, input data reloading, while reducing intermediate tensor size is only a nice-to-have given that SBUF usage was quite low to begin with.

Increasing input seq_len size#

Next, let’s increase the input seq_len by 16x, from 512 to 8192 and recompile the above NKI kernel. Below is the associated performance profile:

../../../_images/mamba_v2_8K_seqlen.png

Fig. 99 Profile of mamba_v2 kernel with 8K seq_len#

The kernel now takes 53.33 ms, which is 33x longer than in the previous profile. VectorE active duration has dropped to a new low of 58.93%. Compared to the profile captured with the smaller seq_len, we notice new DMA activity rows, qSyncSpillReload0 and qVectorSpillReload0, which carry data movement traffic for spilling intermediate data from SBUF into device memory and reloading it back into SBUF. Zooming into a smaller portion of the profile:

../../../_images/mamba_v2_8K_seqlen_zoomed.png

Fig. 100 Poor overlap of computation and data movement#

We can see VectorE enters idle states due to a blocking semaphore wait for qSyncSpillReload0 activities, which indicates the extra spill/reload is indeed degrading overall computation performance. In addition, we can see low SBUF usage peaking at merely 50%. Computation and data movement are also not overlapped properly, leading to low average utilization in both compute engines and DMA throughput in the overall timeline.

Intuitively, increasing seq_len increases the active tile sizes of input and intermediate tensors in the free dimension, which can cause severe SBUF fragmentation and excessive data movement to spill and reload tensors. To mitigate these inefficiencies, we must tile the seq_len dimension in our NKI kernel through a new loop level.

Mitigate spilling by tiling seq_len#

We have three key considerations when adding this new loop level:

  1. tile size selection

  2. loop-carried dependency handling

  3. loop ordering with other loop nests

Tile size of seq_len. With seq_len=512 in our earlier toy example, we achieved close to 100% VectorE utilization, so let’s set the tile size seq_len_fsize to 512 as a starting point. We can revisit this decision as needed once we obtain a new profile.

Loop-carried dependency. Splitting seq_len into chunks is straightforward for all computation steps except Step 4. In the associative scan, each seq_len tile needs the final scan result of the previous tile, so tiling introduces another loop-carried dependency across the scan tiles. This dependency can be handled through the initial input parameter of the scan instruction:

scan_init = nl.zeros((channel_psize, 1), ...)

for i_seq_len_tile in nl.static_range(seq_len // seq_len_fsize):
    # deltaA and deltaBu are the tiles for the current seq_len chunk
    scan_i = nisa.tensor_tensor_scan(deltaA, deltaBu, initial=scan_init,
                                     op0=np.multiply, op1=np.add)
    # Carry the last scan column into the next chunk
    scan_init = scan_i[0:channel_psize, seq_len_fsize-1]

Note that we use static_range instead of affine_range because of the new loop-carried dependency.
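The carried-state idea can be checked outside NKI with a small plain-Python model. In this sketch, scan mirrors the recurrence out[t] = a[t] * out[t-1] + b[t] used in Step 4 for a single row, and tiled_scan feeds the last element of each chunk into the next chunk as its initial value. Both function names are hypothetical and operate on Python lists rather than NKI tiles.

```python
def scan(a, b, initial=0.0):
    """out[t] = a[t] * out[t-1] + b[t], with out[-1] = initial."""
    out, prev = [], initial
    for a_t, b_t in zip(a, b):
        prev = a_t * prev + b_t
        out.append(prev)
    return out


def tiled_scan(a, b, tile):
    """Same scan, computed chunk by chunk with a carried initial state."""
    out, carry = [], 0.0
    for start in range(0, len(a), tile):
        chunk = scan(a[start:start + tile], b[start:start + tile], initial=carry)
        carry = chunk[-1]  # last element of this chunk seeds the next one
        out.extend(chunk)
    return out


a = [0.5, 2.0, 1.0, 0.25, 1.5, 0.75, 1.0, 0.5]
b = [1.0, -1.0, 0.5, 2.0, 0.0, 1.0, -0.5, 3.0]
print(tiled_scan(a, b, tile=4) == scan(a, b))  # True
```

Because each chunk must wait for the carry from the previous one, the chunks cannot be processed in an arbitrary order, which is the dependency the paragraph above describes.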

Loop ordering. Recall from our latest NKI kernel implementation, we have the following loop nest:

...
for i_batch in nl.affine_range(batch_size):

    for i_channel_tile in nl.affine_range(n_channel_tile):
        scanC_accum = nl.zeros((nl.par_dim(channel_psize), seq_len), ...)

        delta_i = nl.load(delta[i_batch, channel_start:channel_start+channel_psize, 0:seq_len])
        u_i = nl.load(u[i_batch, channel_start:channel_start+channel_psize, 0:seq_len])

        for i_state in nl.affine_range(state_size):
            A_i = nl.load(A[channel_start:channel_start+channel_psize, i_state])

            B_i = nl.load(B[i_batch, i_state:i_state+1, 0:seq_len])
            C_i = nl.load(C[i_batch, i_state:i_state+1, 0:seq_len])

            deltaA = ...
            deltaBu = ...
            scanC = ...
            ...
            scanC_accum += ...

        nl.store(..., scanC_accum[0:channel_psize, 0:seq_len])
...

Let’s denote the above loop ordering as [batch_size, n_channel_tile, state_size], and our key question here is where to insert seq_len in this list.

Appending seq_len to the above list, that is, making seq_len the new innermost loop, involves the fewest code changes to our current NKI kernel. However, it also yields the smallest reduction in SBUF usage, since this loop ordering does not tile the scanC_accum, delta_i and u_i tensors. Given seq_len=8192 and FP32 data types, these three tensors occupy 8192 * 4 B * 3 = 96 KiB/partition, half of the available SBUF capacity. Let’s go ahead and experiment with this loop ordering in a new kernel, mamba_v3:

../../../_images/mamba_v3.png

Fig. 101 Profile of mamba_v3 kernel with seq_len tiling optimization#

In the above profile, the kernel now takes 27.8 ms, which is a 48% reduction in latency compared to no seq_len tiling. VectorE is now 94.85% active, and we no longer have spilling-related DMA activities.
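The SBUF estimate for the three untiled tensors in this loop ordering (scanC_accum, delta_i and u_i) can be reproduced with the arithmetic below. The 192 KiB per-partition capacity is an assumption inferred from the statement that 96 KiB is half of the available SBUF; the helper name is hypothetical.

```python
KiB = 1024
# Assumption: 192 KiB of SBUF per partition, inferred from "96 KiB is half".
SBUF_BYTES_PER_PARTITION = 192 * KiB


def untiled_bytes_per_partition(seq_len, n_tensors=3, dtype_bytes=4):
    """Bytes per partition for tensors spanning the full seq_len (hypothetical)."""
    return seq_len * dtype_bytes * n_tensors


used = untiled_bytes_per_partition(8192)
fraction = used / SBUF_BYTES_PER_PARTITION
print(f"{used // KiB} KiB/partition, {fraction:.0%} of SBUF")
```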

Finally, a key advantage of Mamba over Transformer models is that its computation and latency should scale linearly with seq_len rather than quadratically. Let’s therefore compare the measured kernel latencies for seq_len up to 8K (the range we have optimized so far) against “perfect latencies” that assume linear scaling from seq_len=512. We evaluate scaling efficiency as perfect latency / measured latency, where higher is better. To showcase the importance of the seq_len tiling optimization, we also compare the scaling efficiency of mamba_v2 (no seq_len tiling) and mamba_v3 (seq_len tiling).

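| seq_len | Perfect Latency (ms) | mamba_v2 Measured Latency (ms) | mamba_v2 Scaling Efficiency | mamba_v3 Measured Latency (ms) | mamba_v3 Scaling Efficiency |
|---------|----------------------|--------------------------------|-----------------------------|--------------------------------|-----------------------------|
| 512     | N/A                  | 1.6                            | N/A                         | 1.6                            | N/A                         |
| 1024    | 3.2                  | 4.4                            | 72.73%                      | 3.3                            | 96.97%                      |
| 2048    | 6.4                  | 8.9                            | 71.91%                      | 6.6                            | 96.97%                      |
| 3072    | 9.6                  | 13.1                           | 73.28%                      | 10.1                           | 95.05%                      |
| 4096    | 12.8                 | 17.6                           | 72.73%                      | 13.3                           | 96.24%                      |
| 5120    | 16                   | 23.7                           | 67.51%                      | 17.3                           | 92.49%                      |
| 6144    | 19.2                 | 27.5                           | 69.82%                      | 19.6                           | 97.96%                      |
| 7168    | 22.4                 | 41.3                           | 54.24%                      | 24.2                           | 92.56%                      |
| 8192    | 25.6                 | 52.2                           | 49.04%                      | 27.8                           | 92.09%                      |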

The above data shows the last NKI kernel implementation mamba_v3 can reach 90%+ scaling efficiency up to 8K seq_len. To support even larger seq_len, we will need more aggressive tiling by pulling the seq_len loop level further towards the outer-loop level to tile more input/intermediate tensors to keep spilling low and VectorE busy.
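The scaling-efficiency figures reported for mamba_v2 and mamba_v3 follow directly from the definition perfect latency / measured latency, with the perfect latency scaled linearly from the 1.6 ms measured at seq_len=512. A minimal sketch, using a hypothetical helper name and two of the measured data points:

```python
def scaling_efficiency(seq_len, measured_ms, base_seq_len=512, base_ms=1.6):
    """Perfect (linearly scaled) latency divided by measured latency."""
    perfect_ms = base_ms * seq_len / base_seq_len
    return perfect_ms / measured_ms


# (seq_len, mamba_v2 measured ms, mamba_v3 measured ms) from the measurements
rows = [(1024, 4.4, 3.3), (8192, 52.2, 27.8)]
for seq_len, v2_ms, v3_ms in rows:
    print(seq_len,
          f"v2={scaling_efficiency(seq_len, v2_ms):.2%}",
          f"v3={scaling_efficiency(seq_len, v3_ms):.2%}")
```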

Download All Source Code#

Click the links to download source code of the kernels and the testing code discussed in this tutorial.

You can also view the source code in the GitHub repository nki_samples.

Example usage of the scripts:#

Performance mode

Run PyTorch reference implementation to generate a NEFF for profiling:

python3 mamba_torch.py --mode perf

Check performance numbers of mamba_v1/mamba_v2/mamba_v3:

python3 mamba_nki_kernels.py --mode perf --version v1 v2 v3 --batch 1 --seq_len 2048 --channels 512 --state_size 16

Accuracy mode

Check mamba_v1 NKI kernel accuracy against PyTorch implementation:

python3 mamba_torch.py --mode accuracy

Check optimized Mamba kernel (mamba_v2, mamba_v3) accuracy against mamba_v1:

python3 mamba_nki_kernels.py --mode accuracy --version v1 v2 v3 --batch 1 --seq_len 2048 --channels 512 --state_size 16
