This document is relevant for: Inf2, Trn1, Trn2

Fused Self Attention#

In this tutorial, we implement a kernel to perform the self attention used in Stable Diffusion 2.1 (SD2.1) from Stability AI. The model is available here. In doing so, we learn about:

  • The NKI syntax and programming model

  • Layout, tiling, and memory management considerations when performing attention computation in NKI

  • Fusion techniques for implementing an efficient attention kernel

Background#

In SD2.1, the core computation of the self attention is the following. Given

  • Q: (seqlen, d_head)

  • K: (seqlen, d_head)

  • V: (seqlen, d_head)

where d_head and seqlen represent the head dimension and the sequence length of the model. The batch dimensions have been removed for simplicity. We would like to compute

\[
\begin{aligned}
S &= Q \cdot K^T \\
R &= \text{softmax}(S) \cdot V
\end{aligned}
\]

When generating images of size 512x512,

\[
\begin{aligned}
\text{seqlen} &= 4096 \\
\text{d\_head} &= 64
\end{aligned}
\]

We assume the data type of all inputs and outputs to be bfloat16.
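
To make the shapes and the math concrete, the following is a minimal NumPy sketch of this reference computation (plain Python, not NKI; float32 is used since NumPy has no native bfloat16, and the 1/sqrt(d_head) = 0.125 softmax scaling factor applied by the kernel later in this tutorial is included):

import numpy as np

def reference_attention(q, k, v, softmax_scale=0.125):
    # q, k, v: (seqlen, d_head)
    s = (q * softmax_scale) @ k.T             # S = Q * K.T, shape (seqlen, seqlen)
    s = s - s.max(axis=-1, keepdims=True)     # numerically stable softmax
    p = np.exp(s)
    p = p / p.sum(axis=-1, keepdims=True)
    return p @ v                              # R = softmax(S) * V, shape (seqlen, d_head)

seqlen, d_head = 4096, 64
q = np.random.rand(seqlen, d_head).astype(np.float32)
k = np.random.rand(seqlen, d_head).astype(np.float32)
v = np.random.rand(seqlen, d_head).astype(np.float32)
out = reference_attention(q, k, v)            # (4096, 64)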

Naive Algorithm#

[Image: attention-tut-qk-matmul.png]

Fig. 84 Naively multiplying Q and K.T produces a large intermediate matrix#

Fig. 84 shows what happens if we compute the attention naively. We would first compute S = Q * K.T, which has a shape of [4096, 4096]. Since the result is in bfloat16, this intermediate matrix occupies 4096 * 4096 * 2 bytes = 32MB, far exceeding the total space available in SBUF (24MB on NeuronCore-v2). This means we would have to spill data from SBUF to HBM after S is computed and load it back into SBUF to compute the softmax. The resulting data movement between HBM and SBUF degrades performance.

Fusion to Save SBUF Space#

To avoid exhausting SBUF space, we want to avoid materializing the entire product of Q and K.T at once. One way to achieve this is to fuse the softmax computation with the second matrix multiplication.

[Image: attention-tut-algo.png]

Fig. 85 We only need to compute S1 to produce r1#

As shown in Fig. 85, to produce the block r1 of the final result, we only need to compute the highlighted strip S1.

Recall that the TensorEngine on NeuronCore-v2 supports a contraction dimension of at most 128, and that the free dimension of the left-hand side matrix is also limited to 128. In the matrix multiplication S1 = q1 * K.T, as labeled in Fig. 85, the free dimension of q1 should therefore be 128, giving S1 a shape of [128, 4096]. The size of S1 is then 128 * 4096 * 2 bytes = 1MB, 32 times smaller than the full intermediate matrix.
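
The memory arithmetic above can be checked with a few lines of Python; the numbers below simply restate the sizes already discussed (bfloat16 elements, 24MB of SBUF on NeuronCore-v2):

bytes_per_elem = 2                                       # bfloat16
seqlen, q_tile_rows = 4096, 128

full_s_bytes = seqlen * seqlen * bytes_per_elem          # 33,554,432 bytes = 32MB (naive S)
strip_s1_bytes = q_tile_rows * seqlen * bytes_per_elem   # 1,048,576 bytes = 1MB (one strip S1)
sbuf_bytes = 24 * 1024 * 1024                            # 24MB SBUF on NeuronCore-v2

print(full_s_bytes > sbuf_bytes)                         # True: naive S does not fit in SBUF
print(full_s_bytes // strip_s1_bytes)                    # 32: S1 is 32x smaller than S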

We can then produce the entire result by looping over the tiles in Q.

Softmax implementation#

[Image: attention-tut-tiled.png]

Fig. 86 Softmax implementation#

We need to perform the softmax activation on Q*K.T; the scheme is shown in Fig. 86. We first compute the partial row-wise maximum of each s_i tile to produce m1, m2, ..., then find the global row-wise maximum m of S by taking the row-wise maximum over m1, m2, .... After subtracting m from s1, s2, ..., we compute the natural exponential of each tile and sum the results to obtain the row-wise sum rs.

In a regular softmax, we would divide each of s1, s2, ... by rs. Here, however, we can delay the division until after r1 is computed, thanks to the associativity of scalar-matrix multiplication. Since r1 ([128, 64]) is much smaller than the full strip of s1, s2, ... ([128, 4096]), delaying the division saves FLOPS. This is also a major optimization deployed in FlashAttention-v2.

Finally, we multiply each s_i with the corresponding v_i and sum the products to get r1. By looping over the tiles in Q, we produce the entire result R.
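
To make the scheme concrete, below is a small NumPy sketch (plain Python, not NKI; the softmax scaling factor is omitted for brevity) of the computation for a single 128-row tile q1: partial row-wise maxima, the global maximum, exponentials, the row-wise sum rs, accumulation of s_i @ v_i into r1, and the delayed division at the end. It is checked against an untiled softmax for the same tile.

import numpy as np

seqlen, d_head, tile = 4096, 64, 128
rng = np.random.default_rng(0)
q1 = rng.standard_normal((tile, d_head)).astype(np.float32)   # one 128-row tile of Q
K = rng.standard_normal((seqlen, d_head)).astype(np.float32)
V = rng.standard_normal((seqlen, d_head)).astype(np.float32)

n_tiles = seqlen // tile
s_tiles = [q1 @ K[i * tile:(i + 1) * tile].T for i in range(n_tiles)]   # s_1, s_2, ..., each (128, 128)

# Partial row-wise maxima m_1, m_2, ... and the global row-wise maximum m
m = np.max(np.stack([s.max(axis=-1) for s in s_tiles]), axis=0)         # (128,)

# Exponentials, row-wise sum rs, and accumulation of s_i @ v_i into r1
rs = np.zeros(tile, dtype=np.float32)
r1 = np.zeros((tile, d_head), dtype=np.float32)
for i, s in enumerate(s_tiles):
    e = np.exp(s - m[:, None])
    rs += e.sum(axis=-1)
    r1 += e @ V[i * tile:(i + 1) * tile]

r1 /= rs[:, None]                       # delayed division by the row-wise sum

# Matches the untiled softmax(q1 @ K.T) @ V for this tile
S = q1 @ K.T
P = np.exp(S - S.max(axis=-1, keepdims=True))
ref = (P / P.sum(axis=-1, keepdims=True)) @ V
assert np.allclose(r1, ref, rtol=1e-3, atol=1e-4)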

Compute kernel#

# Imports required by this kernel
import numpy as np
import neuronxcc.nki as nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
from neuronxcc.nki.language import par_dim


@nki.jit
def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False,
                                           mixed_precision=True):
  """
  Fused self attention kernel for small head dimension Stable Diffusion workload,
  simplified for this tutorial.

  Computes softmax(QK^T)V. Decoder model can optionally include a causal mask
  application. Does not include QKV projection, output projection, dropout,
  residual connection, etc.

  This kernel is designed to be used for Stable Diffusion models where
  d_head is smaller than or equal to 128. An assertion is raised if `d_head` does
  not satisfy the requirement.

  IO tensor layouts:
   - q_ref: shape   (seq_q, d_head)
   - k_ref: shape   (seq_k, d_head)
   - v_ref: shape   (seq_v, d_head)
   - out_ref: shape (seq_q, d_head)
   - We use seq_q, seq_k and seq_v just for clarity; this kernel requires
     seq_q == seq_k == seq_v

  IO tensor dtypes:
   - This kernel assumes all IO tensors have the same dtype
   - If mixed_precision is True, then all Tensor Engine operations will be performed in
     bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
     will be in the same type as the inputs.
  """
  # Use q_ref dtype as the intermediate tensor dtype
  # Assume all IO tensors have the same dtype
  kernel_dtype = q_ref.dtype
  pe_in_dt = nl.bfloat16 if mixed_precision else np.float32
  assert q_ref.dtype == k_ref.dtype == v_ref.dtype

  # Shape checking
  seqlen, d_head = q_ref.shape
  assert d_head <= 128, "Cannot use this kernel for d_head > 128"
  assert tuple(q_ref.shape) == (seqlen, d_head), 'Input shape mismatch!'
  assert tuple(k_ref.shape) == (seqlen, d_head), 'Input shape mismatch!'
  assert tuple(v_ref.shape) == (seqlen, d_head), \
    f'Input shape mismatch! Expected: {(seqlen, d_head)} Actual: {tuple(v_ref.shape)}'
  out_ref = nl.ndarray((seqlen, d_head), dtype=q_ref.dtype, buffer=nl.shared_hbm)

  # Softmax scaling factor, multiplied onto Q
  softmax_scale = 0.125

  q_seq_n_tiles, q_seq_tile_size = seqlen // 128, 128
  k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128
  # No tiling on d_head dimension since d_head fits in SBUF
  d_head_tile_size = d_head
  v_seq_n_tiles, v_seq_tile_size = seqlen // 128, 128

  ###################################
  # Step 1. transpose(tensor_v)
  ###################################
  # Buffer for v matrix transposed
  # Pre-fetch and keep it in SBUF throughout different softmax tiles
  trans_v = nl.ndarray((par_dim(v_seq_tile_size), v_seq_n_tiles, d_head), dtype=pe_in_dt)

  for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
    ip_v = nl.arange(v_seq_tile_size)[:, None]
    if_v = nl.arange(d_head_tile_size)[None, :]
    trans_v[ip_v, i_k_seq_tile, if_v] = nl.load(
      v_ref[i_k_seq_tile * k_seq_tile_size + ip_v, if_v],
      dtype=pe_in_dt)

  q_local = nl.ndarray((q_seq_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=pe_in_dt)
  ip_q = nl.arange(d_head_tile_size)[:, None]
  if_q = nl.arange(q_seq_tile_size)[None, :]
  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):
    q_local[i_q_seq_tile, ip_q, if_q] = nl.load_transpose2d(
      q_ref[i_q_seq_tile * q_seq_tile_size + nl.arange(q_seq_tile_size)[:, None],
            nl.arange(d_head_tile_size)[None, :]
      ],
      dtype=pe_in_dt) * softmax_scale

  k_local = nl.ndarray((k_seq_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=pe_in_dt)
  ip_k = nl.arange(d_head_tile_size)[:, None]
  if_k = nl.arange(k_seq_tile_size)[None, :]
  for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
    k_local[i_k_seq_tile, ip_k, if_k] = nl.load_transpose2d(
      k_ref[i_k_seq_tile * k_seq_tile_size + nl.arange(k_seq_tile_size)[:, None],
            nl.arange(d_head_tile_size)[None, :]],
      dtype=pe_in_dt)

  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):  # indent = 2
    # An SBUF buffer for an independent softmax tile
    qk_res_buf = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=kernel_dtype)

    neg_max_res = nl.ndarray((par_dim(q_seq_tile_size), k_seq_n_tiles), dtype=kernel_dtype)
    ip_max = nl.arange(q_seq_tile_size)[:, None]
    if_max = nl.arange(k_seq_n_tiles)[None, :]

    # Loop over RHS free of matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):  # indent = 4

      # Since the K^T tile is the RHS, the q_seq_len dimension will be P in the result
      # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F]
      qk_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size),
                         dtype=np.float32, buffer=nl.psum)

      # Tensor indices for accessing qk result in k_seq_tile_size
      ip_qk = nl.arange(q_seq_tile_size)[:, None]
      if_qk = nl.arange(k_seq_tile_size)[None, :]

      ##############################################################
      # Step 2. matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
      ##############################################################
      qk_psum[ip_qk, if_qk] += nisa.nc_matmul(moving=k_local[i_k_seq_tile, ip_k, if_k],
                                              stationary=q_local[i_q_seq_tile, ip_q, if_q])

      ###################################
      # Step 3. Apply optional causal mask
      ###################################
      if use_causal_mask:
        # Magic number -9984.0 to replace -inf similar to what neuronx-cc uses
        qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.affine_select(
          pred=(i_q_seq_tile * q_seq_tile_size + ip_qk >= i_k_seq_tile * k_seq_tile_size + if_qk),
          on_true_tile=qk_psum[ip_qk, if_qk], on_false_value=-9984.0, dtype=kernel_dtype)
      else:
        # Simply send psum result back to sbuf
        qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nl.copy(qk_psum[ip_qk, if_qk],
                                                                            dtype=kernel_dtype)

      ###################################
      # Step 4. Softmax
      ###################################
      neg_max_res[ip_max, i_k_seq_tile] = nisa.tensor_reduce(
        np.max, data=qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk],
        axis=(1,), dtype=kernel_dtype, negate=True)

    neg_max_res_final = nisa.tensor_reduce(
      np.min, data=neg_max_res[ip_max, if_max],
      axis=(1,), dtype=kernel_dtype, negate=False)

    ip_softmax = nl.arange(q_seq_tile_size)[:, None]
    if_softmax = nl.arange(seqlen)[None, :]
    ip_sum_res = nl.arange(q_seq_tile_size)[:, None]
    if_sum_res = nl.arange(d_head_tile_size)[None, :]

    softmax_res = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=pe_in_dt)
    sum_divisor = nl.ndarray((par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype)

    # Simply use a large tile of seq_len in size since this is a "blocking" instruction
    # Assuming the compiler will merge exp and reduce_add into a single instruction on ACT
    exp_res = nisa.activation(np.exp,
                              data=qk_res_buf[ip_softmax, if_softmax],
                              bias=neg_max_res_final, scale=1.0)

    sum_res = nisa.tensor_reduce(np.add, data=exp_res, axis=(1,),
                                 dtype=kernel_dtype)
    softmax_res[ip_softmax, if_softmax] = nl.copy(exp_res, dtype=pe_in_dt)

    sum_reciprocal_broadcast = (1.0 / sum_res).broadcast_to((q_seq_tile_size, d_head_tile_size))
    sum_divisor[ip_sum_res, if_sum_res] = nl.copy(sum_reciprocal_broadcast, dtype=kernel_dtype)

    # Buffer for transposed softmax results, kept in SBUF in pe_in_dt
    trans_softmax_res = nl.ndarray(
      (par_dim(k_seq_tile_size), k_seq_n_tiles, q_seq_tile_size),
      dtype=pe_in_dt)

    # Result psum buffer has the hidden dim as P
    attn_res_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size),
                             dtype=np.float32, buffer=nl.psum)

    ip_scores_t = nl.arange(k_seq_tile_size)[:, None]
    if_scores_t = nl.arange(q_seq_tile_size)[None, :]
    # Loop over matmul_1 contraction
    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
      ###################################
      # Step 5. transpose(softmax_res)
      ###################################
      ip_scores = nl.arange(q_seq_tile_size)[:, None]
      if_scores = nl.arange(k_seq_tile_size)[None, :]

      trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t] = nisa.nc_transpose(
        softmax_res[ip_scores, i_k_seq_tile * k_seq_tile_size + if_scores])

    ip_out = nl.arange(d_head_tile_size)[:, None]
    if_out = nl.arange(q_seq_tile_size)[None, :]
    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
      ######################################################################
      # Step 6. matmul_1(stationary=trans_v, moving=trans_softmax_res, contract=seqlen_v=seqlen_k)
      ######################################################################
      ip_v_t = nl.arange(k_seq_tile_size)[:, None]
      if_v_t = nl.arange(d_head_tile_size)[None, :]
      attn_res_psum[ip_out, if_out] += \
        nisa.nc_matmul(moving=trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t],
                       stationary=trans_v[ip_v_t, i_k_seq_tile, if_v_t])

    attn_res_sbuf = nl.copy(attn_res_psum[ip_out, if_out], dtype=kernel_dtype)

    attn_res_div = attn_res_sbuf * nisa.nc_transpose(sum_divisor[ip_sum_res, if_sum_res])

    nl.store(
      out_ref[i_q_seq_tile * q_seq_tile_size + if_out, ip_out],
      value=attn_res_div)

  return out_ref

Launching kernel and testing correctness#

Below we write a reference PyTorch implementation of the attention and verify the output of our NKI kernel against this reference.

import torch
from torch_xla.core import xla_model as xm

from sd_attention_nki_kernels import fused_self_attn_for_SD_small_head_size


if __name__ == "__main__":

  device = xm.xla_device()

  def cpu_golden_attn(q, k, v):
      softmax_scale = 0.125
      q_scaled = q * softmax_scale
      raw_score = torch.matmul(q_scaled, k.transpose(1, 0))

      norm_score = torch.nn.functional.softmax(raw_score, dim=-1)

      return torch.matmul(norm_score, v)

  q_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
  k_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
  v_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)

  output_nki = fused_self_attn_for_SD_small_head_size(q_tensor, k_tensor, v_tensor)

  output_torch = cpu_golden_attn(q_tensor, k_tensor, v_tensor)

  allclose = torch.allclose(output_torch, output_nki, atol=1e-5, rtol=1e-3)

  if allclose:
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")

  assert allclose

Download All Source Code#

Click the links to download the source code of the kernels and the testing code discussed in this tutorial.

You can also view the source code in the GitHub repository nki_samples.

Example usage of the scripts:#

Performance mode

Check performance numbers of the attention kernel:

python3 sd_attention_nki_kernels.py --mode perf

Accuracy mode

Run PyTorch reference implementation and check correctness:

python3 sd_attention_torch.py

Run baremetal mode and check correctness:

python3 sd_attention_nki_kernels.py --mode accuracy

This document is relevant for: Inf2, Trn1, Trn2