This document is relevant for: Inf2, Trn1, Trn1n

Fused Self Attention#

In this tutorial, we implement a kernel to perform the self attention used in Stable Diffusion 2.1 (SD2.1) from Stability AI. The model is available here. In doing so, we learn about:

  • The NKI syntax and programming model

  • Layout, tiling, and memory management considerations when performing attention computation in NKI

  • Fusion techniques for implementing an efficient attention kernel

Background#

In SD2.1, the core computation of the self attention is the following. Given

  • Q: (seqlen, d_head)

  • K: (seqlen, d_head)

  • V: (seqlen, d_head)

where d_head and seqlen represent the head dimension and the sequence length of the model. The batch dimensions have been removed for simplicity. We would like to compute

\[\begin{aligned}S &= Q\,K^T\\R &= \text{softmax}(S)\,V\end{aligned}\]

When generating images of size 512x512,

\[\begin{aligned}\text{seqlen} &= 4096\\\text{d\_head} &= 64\end{aligned}\]

We assume the data type of all inputs and outputs to be bfloat16.
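Concretely, ignoring hardware constraints for a moment, the computation is just two matrix multiplications with a row-wise softmax in between. A minimal NumPy sketch of the math above is shown below (float32 stands in for bfloat16, which NumPy does not support natively):

import numpy as np

seqlen, d_head = 4096, 64
rng = np.random.default_rng(0)
Q = rng.standard_normal((seqlen, d_head)).astype(np.float32)
K = rng.standard_normal((seqlen, d_head)).astype(np.float32)
V = rng.standard_normal((seqlen, d_head)).astype(np.float32)

S = Q @ K.T                                    # (seqlen, seqlen) attention scores
P = np.exp(S - S.max(axis=-1, keepdims=True))  # numerically stable softmax
P /= P.sum(axis=-1, keepdims=True)
R = P @ V                                      # (seqlen, d_head) final result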

Naive Algorithm#

../../../_images/attention-tut-qk-matmul.png

Fig. 80 Naively multiplying Q and K.T produces a large intermediate matrix#

Fig. 80 shows what happens if we compute the attention naively. We would first compute S = Q * K.T, which has a size of [4096, 4096]. Since the result is in bfloat16, this intermediate matrix occupies 4096 * 4096 * 2 bytes = 32MB, far exceeding the total space available in the SBUF (24MB on NeuronCore-v2). This means we have to spill data from SBUF to HBM after S is computed and load it back into SBUF when we compute the softmax. This leads to a lot of data movement between HBM and SBUF, degrading performance.
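As a quick sanity check on those numbers (assuming 2 bytes per bfloat16 element and the 24MB SBUF figure quoted above):

seqlen = 4096
bytes_per_elem = 2                       # bfloat16
s_bytes = seqlen * seqlen * bytes_per_elem
print(s_bytes / 2**20)                   # 32.0 MB for S, vs. 24 MB of SBUF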

Fusion to Save SBUF Space#

To avoid exhausting SBUF space, we would like to avoid computing the entire multiplication of Q and K.T at once. One way is to fuse the softmax computation with the second matrix multiplication.

../../../_images/attention-tut-algo.png

Fig. 81 We only need to compute S1 to produce r1#

As shown in Fig. 81, to produce one block r1 of the final result, we only need to compute the highlighted strip S1.

Recall that the TensorEngine on NeuronCore-v2 can process a maximum contraction dimension of 128, and the free dimension of the left-hand-side matrix also has a maximum of 128. In the matrix multiplication S1 = q1 * K.T, as labelled in Fig. 81, the free dimension of q1 is therefore 128, and S1 has a shape of [128, 4096]. The size of S1 is then 128 * 4096 * 2 bytes = 1MB, 32 times smaller than the full intermediate matrix.

We can then produce the entire result by looping over the tiles in Q.
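A shape-level NumPy sketch of this outer loop may help before turning to the kernel (this is illustrative only, not NKI code): each iteration materializes a single [128, 4096] strip instead of the full [4096, 4096] score matrix.

import numpy as np

def fused_attention_sketch(Q, K, V, q_tile_size=128):
  seqlen, d_head = Q.shape
  R = np.empty_like(Q)
  for i in range(0, seqlen, q_tile_size):
    S1 = Q[i:i + q_tile_size] @ K.T                    # [128, seqlen] strip
    P1 = np.exp(S1 - S1.max(axis=-1, keepdims=True))   # softmax numerator
    R[i:i + q_tile_size] = (P1 @ V) / P1.sum(axis=-1, keepdims=True)
  return R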

Softmax implementation#

../../../_images/attention-tut-tiled.png

Fig. 82 Softmax implementation#

We need to perform the softmax activation on Q*K.T; the scheme is shown in Fig. 82. We first compute the partial row-wise maximum on each s_i tile to produce m1, m2, ..., then find the global row-wise maximum m of S by taking the row-wise maximum over m1, m2, .... After subtracting m from s1, s2, ..., we compute the natural exponential and sum the results to find the row-wise sum rs.

In a regular softmax, we would divide each s1, s2, ... by rs. However, here we can delay the division until after we compute r1, due to the associativity of scalar-matrix multiplication. Since r1 ([128, d_head]) is much smaller than the full strip of exponentials ([128, seqlen]), we save FLOPS by delaying the division. This is also a major optimization deployed in FlashAttention-v2.

We finally multiply each s_i with v_i and sum the products together to get r1. By looping over the tiles in Q, we produce the entire result R.
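The scheme is easy to check numerically. Below is a NumPy sketch (not NKI code) of the tiled softmax with delayed division for a single 128-row tile of Q, using the same names s_i, m, rs and r1 as the figures; the shapes follow the SD2.1 setting.

import numpy as np

rng = np.random.default_rng(0)
q1 = rng.standard_normal((128, 64))          # one 128-row tile of Q
K  = rng.standard_normal((4096, 64))
V  = rng.standard_normal((4096, 64))

S1 = q1 @ K.T                                # [128, 4096] strip for this tile
s_tiles = np.split(S1, 4096 // 128, axis=1)  # s_1, s_2, ..., each [128, 128]
v_tiles = np.split(V, 4096 // 128, axis=0)   # matching [128, 64] tiles of V

# Partial row-wise maxima m1, m2, ..., then the global row-wise maximum m
m = np.max([s.max(axis=1) for s in s_tiles], axis=0)[:, None]

# Exponentials with m subtracted, and the row-wise sum rs
e_tiles = [np.exp(s - m) for s in s_tiles]
rs = np.sum([e.sum(axis=1) for e in e_tiles], axis=0)[:, None]

# Multiply-and-accumulate against V first, divide by rs only once at the end
r1 = sum(e @ v for e, v in zip(e_tiles, v_tiles)) / rs

# Same result as an untiled softmax followed by the V matmul
p = np.exp(S1 - S1.max(axis=1, keepdims=True))
ref = (p / p.sum(axis=1, keepdims=True)) @ V
assert np.allclose(r1, ref)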

Compute kernel#

# NKI imports used by the kernel
import numpy as np

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
from neuronxcc.nki.language import par_dim


def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_causal_mask=False,
                                           mixed_percision=True):
  """
  Fused self attention kernel for small head dimension Stable Diffusion workload,
  simplified for this tutorial.

  Computes softmax(QK^T)V. Decoder model can optionally include a causal mask
  application. Does not include QKV projection, output projection, dropout,
  residual connection, etc.

  This kernel is designed to be used for Stable Diffusion models where the
  d_head is smaller than or equal to 128. An assertion is thrown if `d_head` does
  not satisfy the requirement.

  IO tensor layouts:
   - q_ref: shape   (seq_q, d_head)
   - k_ref: shape   (seq_k, d_head)
   - v_ref: shape   (seq_v, d_head)
   - out_ref: shape (seq_q, d_head)
   - We use seq_q, seq_k and seq_v just for clarity; this kernel requires
     seq_q == seq_k == seq_v

  IO tensor dtypes:
   - This kernel assumes all IO tensors have the same dtype
   - If mixed_percision is True, then all Tensor Engine operations will be performed in
     bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
     will be in the same type as the inputs.
  """
  # Use q_ref dtype as the intermediate tensor dtype
  # Assume all IO tensors have the same dtype
  kernel_dtype = q_ref.dtype
  pe_in_dt = nl.bfloat16 if mixed_percision else np.float32
  assert q_ref.dtype == k_ref.dtype == v_ref.dtype == out_ref.dtype

  # Shape checking
  seqlen, d_head = q_ref.shape
  assert d_head <= 128, "Cannot use this kernel for d_head > 128"
  assert tuple(q_ref.shape) == (seqlen, d_head), 'Input shape mismatch!'
  assert tuple(k_ref.shape) == (seqlen, d_head), 'Input shape mismatch!'
  assert tuple(v_ref.shape) == (seqlen, d_head), \
    f'Input shape mismatch! Expected: {(seqlen, d_head)} Actual: {tuple(v_ref.shape)}'
  assert tuple(out_ref.shape) == (seqlen, d_head), 'Output shape mismatch!'

  # Softmax scaling factor, multiplied onto Q
  softmax_scale = 0.125

  q_seq_n_tiles, q_seq_tile_size = seqlen // 128, 128
  k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128
  # No tiling on d_head dimension since the dimension of d_head fits in SB
  d_head_tile_size = d_head
  v_seq_n_tiles, v_seq_tile_size = seqlen // 128, 128

  ###################################
  # Step 1. transpose(tensor_v)
  ###################################
  # Buffer for v matrix transposed
  # Pre-fetch and keep it in SBUF throughout different softmax tiles
  trans_v = nl.ndarray((par_dim(v_seq_tile_size), v_seq_n_tiles, d_head), dtype=pe_in_dt)

  for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
    ip_v = nl.arange(v_seq_tile_size)[:, None]
    if_v = nl.arange(d_head_tile_size)[None, :]
    trans_v[ip_v, i_k_seq_tile, if_v] = nl.load(
      v_ref[i_k_seq_tile * k_seq_tile_size + ip_v, if_v],
      dtype=pe_in_dt)

  q_local = nl.ndarray((q_seq_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=pe_in_dt)
  ip_q = nl.arange(d_head_tile_size)[:, None]
  if_q = nl.arange(q_seq_tile_size)[None, :]
  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):
    q_local[i_q_seq_tile, ip_q, if_q] = nl.load_transpose2d(
      q_ref[i_q_seq_tile * q_seq_tile_size + nl.arange(q_seq_tile_size)[:, None],
            nl.arange(d_head_tile_size)[None, :]
      ],
      dtype=pe_in_dt) * softmax_scale

  k_local = nl.ndarray((k_seq_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=pe_in_dt)
  ip_k = nl.arange(d_head_tile_size)[:, None]
  if_k = nl.arange(k_seq_tile_size)[None, :]
  for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
    k_local[i_k_seq_tile, ip_k, if_k] = nl.load_transpose2d(
      k_ref[i_k_seq_tile * k_seq_tile_size + nl.arange(k_seq_tile_size)[:, None],
            nl.arange(d_head_tile_size)[None, :]],
      dtype=pe_in_dt)

  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):  # indent = 2
    # A SBUF buffer for an independent softmax tile
    qk_res_buf = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=kernel_dtype)

    neg_max_res = nl.ndarray((par_dim(q_seq_tile_size), k_seq_n_tiles), dtype=kernel_dtype)
    ip_max = nl.arange(q_seq_tile_size)[:, None]
    if_max = nl.arange(k_seq_n_tiles)[None, :]

    # Loop over RHS free of matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):  # indent = 4

      # Since the K^T tile is the RHS, the q_seq_len dimension will be P in the result
      # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F]
      qk_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size),
                         dtype=np.float32, buffer=nl.psum)

      # Tensor indices for accessing qk result in k_seq_tile_size
      ip_qk = nl.arange(q_seq_tile_size)[:, None]
      if_qk = nl.arange(k_seq_tile_size)[None, :]

      ##############################################################
      # Step 2. matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
      ##############################################################
      qk_psum[ip_qk, if_qk] += nisa.nc_matmul(moving=k_local[i_k_seq_tile, ip_k, if_k],
                                              stationary=q_local[i_q_seq_tile, ip_q, if_q])

      ###################################
      # Step 3. Apply optional causal mask
      ###################################
      if use_causal_mask:
        # Magic number -9984.0 to replace -inf similar to what neuronx-cc uses
        qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.affine_select(
          pred=(i_q_seq_tile * q_seq_tile_size + ip_qk >= i_k_seq_tile * k_seq_tile_size + if_qk),
          on_true_tile=qk_psum[ip_qk, if_qk], on_false_value=-9984.0, dtype=kernel_dtype)
      else:
        # Simply send psum result back to sbuf
        qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nl.copy(qk_psum[ip_qk, if_qk],
                                                                            dtype=kernel_dtype)

      ###################################
      # Step 4. Softmax
      ###################################
      neg_max_res[ip_max, i_k_seq_tile] = nisa.tensor_reduce(
        np.max, data=qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk],
        axis=(1,), dtype=kernel_dtype, negate=True)

    neg_max_res_final = nisa.tensor_reduce(
      np.min, data=neg_max_res[ip_max, if_max],
      axis=(1,), dtype=kernel_dtype, negate=False)

    ip_softmax = nl.arange(q_seq_tile_size)[:, None]
    if_softmax = nl.arange(seqlen)[None, :]
    ip_sum_res = nl.arange(q_seq_tile_size)[:, None]
    if_sum_res = nl.arange(d_head_tile_size)[None, :]

    softmax_res = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=pe_in_dt)
    sum_divisor = nl.ndarray((par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype)

    # Simply use a large tile of seq_len in size since this is a "blocking" instruction
    # Assuming the compiler will merge exp and reduce_add into a single instruction on ACT
    exp_res = nisa.activation(np.exp,
                              data=qk_res_buf[ip_softmax, if_softmax],
                              bias=neg_max_res_final, scale=1.0)

    sum_res = nisa.tensor_reduce(np.add, data=exp_res, axis=(1,),
                                 dtype=kernel_dtype)
    softmax_res[ip_softmax, if_softmax] = nl.copy(exp_res, dtype=pe_in_dt)

    sum_reciprocal_broadcast = (1.0 / sum_res).broadcast_to((q_seq_tile_size, d_head_tile_size))
    sum_divisor[ip_sum_res, if_sum_res] = nl.copy(sum_reciprocal_broadcast, dtype=kernel_dtype)

    # Buffer for transposed softmax results (kept in SBUF across the contraction loop)
    trans_softmax_res = nl.ndarray(
      (par_dim(k_seq_tile_size), k_seq_n_tiles, q_seq_tile_size),
      dtype=pe_in_dt)

    # Result psum buffer has the hidden dim as P
    attn_res_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size),
                             dtype=np.float32, buffer=nl.psum)

    ip_scores_t = nl.arange(k_seq_tile_size)[:, None]
    if_scores_t = nl.arange(q_seq_tile_size)[None, :]
    # Loop over matmul_1 contraction
    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
      ###################################
      # Step 5. transpose(softmax_res)
      ###################################
      ip_scores = nl.arange(q_seq_tile_size)[:, None]
      if_scores = nl.arange(k_seq_tile_size)[None, :]

      trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t] = nisa.nc_transpose(
        softmax_res[ip_scores, i_k_seq_tile * k_seq_tile_size + if_scores])

    ip_out = nl.arange(d_head_tile_size)[:, None]
    if_out = nl.arange(q_seq_tile_size)[None, :]
    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
      ######################################################################
      # Step 6. matmul_1(stationary=trans_v, moving=trans_softmax_res, contract=seqlen_v=seqlen_k)
      ######################################################################
      ip_v_t = nl.arange(k_seq_tile_size)[:, None]
      if_v_t = nl.arange(d_head_tile_size)[None, :]
      attn_res_psum[ip_out, if_out] += \
        nisa.nc_matmul(moving=trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t],
                       stationary=trans_v[ip_v_t, i_k_seq_tile, if_v_t])

    attn_res_sbuf = nl.copy(attn_res_psum[ip_out, if_out], dtype=kernel_dtype)

    attn_res_div = attn_res_sbuf * nisa.nc_transpose(sum_divisor[ip_sum_res, if_sum_res])

    nl.store(
      out_ref[i_q_seq_tile * q_seq_tile_size + if_out, ip_out],
      value=attn_res_div)

Launching kernel and testing correctness#

Below we write a reference PyTorch implementation of the attention and verify the NKI kernel output against this reference.

import torch
from torch_neuronx.xla_impl.ops import nki_jit
from torch_xla.core import xla_model as xm

from sd_attention_nki_kernels import fused_self_attn_for_SD_small_head_size


if __name__ == "__main__":

  device = xm.xla_device()

  def cpu_golden_attn(q, k, v):
      softmax_scale = 0.125
      q_scaled = q * softmax_scale
      raw_score = torch.matmul(q_scaled, k.transpose(1, 0))

      norm_score = torch.nn.functional.softmax(raw_score, dim=-1)

      return torch.matmul(norm_score, v)

  q_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
  k_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
  v_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
  output_nki = torch.zeros((4096, 64), dtype=torch.float32).to(device=device)

  nki_func = nki_jit(func=fused_self_attn_for_SD_small_head_size)
  nki_func(q_tensor, k_tensor, v_tensor, output_nki)

  output_torch = cpu_golden_attn(q_tensor, k_tensor, v_tensor)

  allclose = torch.allclose(output_torch, output_nki, atol=1e-5, rtol=1e-3)

  if allclose:
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")

  assert allclose

Download All Source Code#

Click the links to download the source code of the kernels and the testing code discussed in this tutorial.

You can also view the source code in the GitHub repository nki_samples.

Example usage of the scripts:#

Performance mode

Check performance numbers of the attention kernel:

python3 sd_attention_nki_kernels.py --mode perf

Accuracy mode

Run PyTorch reference implementation and check correctness:

python3 sd_attention_torch.py

Run baremetal mode and check correctness:

python3 sd_attention_nki_kernels.py --mode accuracy

This document is relevant for: Inf2, Trn1, Trn1n