This document is relevant for: Inf2, Trn1, Trn1n
Fused Self Attention#
In this tutorial, we implement a kernel to perform the self attention seen in Stable Diffusion 2.1 (SD2.1) from Stability AI. The model is available here. In doing so, we learn about:
The NKI syntax and programming model
Layout, tiling, and memory management considerations when performing attention computation in NKI
Fusion techniques for implementing an efficient attention kernel
Background#
In SD2.1, the core computation of the self attention is the following. Given

Q: (seqlen, d_head)
K: (seqlen, d_head)
V: (seqlen, d_head)

where d_head and seqlen represent the head dimension and the sequence length of the model (the batch dimensions have been removed for simplicity), we would like to compute

R = softmax(Q * K.T / sqrt(d_head)) * V

When generating images of size 512x512, seqlen is 4096 and d_head is 64. We assume the data type of all inputs and outputs to be bfloat16.
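For concreteness, the computation can be written as a few lines of NumPy. This is a minimal reference sketch only; the variable names and the float32 dtype are illustrative and not part of the kernel.

import numpy as np

seqlen, d_head = 4096, 64                # 512x512 image generation in SD2.1
softmax_scale = 1.0 / np.sqrt(d_head)    # 0.125

Q = np.random.rand(seqlen, d_head).astype(np.float32)
K = np.random.rand(seqlen, d_head).astype(np.float32)
V = np.random.rand(seqlen, d_head).astype(np.float32)

S = (Q * softmax_scale) @ K.T                    # (seqlen, seqlen) attention scores
P = np.exp(S - S.max(axis=-1, keepdims=True))    # numerically stable row-wise softmax
P = P / P.sum(axis=-1, keepdims=True)
R = P @ V                                        # (seqlen, d_head) attention output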
Naive Algorithm#
Fig. 80 shows the scenario if we compute the attention naively. We would first compute S = Q * K.T, which has a size of [4096, 4096]. Since the result is in bfloat16, this intermediate matrix has a size of 4096 * 4096 * 2 bytes = 32MB, far exceeding the total space available in SBUF (24MB on NeuronCore-v2). This means that we have to spill data from SBUF to HBM after S is computed, and load it back into SBUF when we compute the softmax. This leads to a lot of data movement between HBM and SBUF, degrading performance.
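The arithmetic behind that 32MB figure (a quick sketch, not part of the kernel code):

seqlen = 4096
bytes_per_elem = 2                          # bfloat16
s_bytes = seqlen * seqlen * bytes_per_elem  # full S = Q * K.T
sbuf_bytes = 24 * 1024 * 1024               # SBUF capacity on NeuronCore-v2

print(s_bytes // (1024 * 1024))             # 32 (MB)
print(s_bytes > sbuf_bytes)                 # True: S alone does not fit in SBUF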
Fusion to Save SBUF Space#
To avoid exhausting SBUF space, we would like to avoid computing the entirety of the multiplication of Q and K.T at once. One way is to fuse the softmax computation with the second matrix multiplication. As shown in Fig. 81, to produce the block r1 of the final result, we only need to compute the highlighted strip S1.
Recall that the TensorEngine on NeuronCore-v2 supports a contraction dimension of at most 128, and that the free dimension of the left-hand-side matrix also has a maximum of 128. In the matrix multiplication S1 = q1 * K.T, as labelled in Fig. 81, the size of the free dimension of q1 should therefore be 128, and S1 has a shape of [128, 4096]. The size of S1 is thus 128 * 4096 * 2 bytes = 1MB, which is 32 times smaller than the full intermediate matrix. We can then produce the entire result by looping over the tiles in Q, as sketched below.
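The loop structure of this fusion can be sketched in plain NumPy. The function and variable names below are illustrative only and merely mirror the tiling shown in Fig. 81; they are not part of the NKI kernel.

import numpy as np

def fused_attention_sketch(Q, K, V, tile_size=128):
    """NumPy sketch of the fusion scheme: one 128-row strip of S at a time."""
    seqlen, d_head = Q.shape
    softmax_scale = 1.0 / np.sqrt(d_head)
    R = np.empty_like(Q)
    for start in range(0, seqlen, tile_size):
        q_i = Q[start:start + tile_size] * softmax_scale     # q1: (128, d_head)
        s_i = q_i @ K.T                                      # S1: (128, seqlen), ~1MB in bfloat16
        p_i = np.exp(s_i - s_i.max(axis=-1, keepdims=True))  # row-wise softmax of the strip
        p_i = p_i / p_i.sum(axis=-1, keepdims=True)
        R[start:start + tile_size] = p_i @ V                 # r1: one (128, d_head) block of R
    return R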
Softmax implementation#
We need to perform the softmax activation on Q * K.T; the scheme is shown in Fig. 82. We first compute a partial row-wise maximum on each s_i tile to produce m1, m2, ..., then we find the global row-wise maximum m of S by computing the row-wise maximum over m1, m2, .... After subtracting m from s1, s2, ..., we compute the natural exponential and sum the results to find the row-wise sum rs.

In a regular softmax, we would divide each s1, s2, ... by rs. However, here we can delay the division until after we compute r1, due to the associativity of scalar-matrix multiplication. Since r1 is much smaller than s1, s2, ..., we save FLOPS by delaying the division. This is also a major optimization deployed in FlashAttention-v2.

We finally multiply s_i and v_i, and sum the products together to get r1. By looping over the tiles in Q, we produce the entire result R.
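A small NumPy check of the delayed division (an illustrative sketch, not part of the kernel): because rs is a per-row scalar, dividing after the second matmul gives the same result, and dividing r1 (128 x d_head) requires far fewer divisions than dividing the full strip (128 x seqlen).

import numpy as np

rng = np.random.default_rng(0)
s = rng.normal(size=(128, 4096)).astype(np.float32)  # one row-strip of scaled Q * K.T
v = rng.normal(size=(4096, 64)).astype(np.float32)

m  = s.max(axis=-1, keepdims=True)   # global row-wise maximum m
e  = np.exp(s - m)                   # exponentials of the shifted scores
rs = e.sum(axis=-1, keepdims=True)   # row-wise sum rs

r_regular = (e / rs) @ v             # divide first: 128 x 4096 divisions
r_delayed = (e @ v) / rs             # divide r1 instead: 128 x 64 divisions

print(np.allclose(r_regular, r_delayed, rtol=1e-4, atol=1e-5))  # True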
Compute kernel#
import numpy as np

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
from neuronxcc.nki.language import par_dim


def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, out_ref, use_causal_mask=False,
                                           mixed_percision=True):
  """
  Fused self attention kernel for small head dimension Stable Diffusion workload,
  simplified for this tutorial.

  Computes softmax(QK^T)V. Decoder model can optionally include a causal mask
  application. Does not include QKV projection, output projection, dropout,
  residual connection, etc.

  This kernel is designed to be used for Stable Diffusion models where the
  d_head is smaller than or equal to 128. An assertion is thrown if `d_head`
  does not satisfy the requirement.

  IO tensor layouts:
   - q_ptr: shape   (seq_q, d_head)
   - k_ptr: shape   (seq_k, d_head)
   - v_ptr: shape   (seq_v, d_head)
   - out_ptr: shape (seq_q, d_head)
   - We use seq_q, seq_k and seq_v just for clarity; this kernel requires
     seq_q == seq_k == seq_v

  IO tensor dtypes:
   - This kernel assumes all IO tensors have the same dtype
   - If mixed_percision is True, then all Tensor Engine operations will be performed in
     bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
     will be in the same type as the inputs.
  """
  # Use q_ref dtype as the intermediate tensor dtype
  # Assume all IO tensors have the same dtype
  kernel_dtype = q_ref.dtype
  pe_in_dt = nl.bfloat16 if mixed_percision else np.float32
  assert q_ref.dtype == k_ref.dtype == v_ref.dtype == out_ref.dtype

  # Shape checking
  seqlen, d_head = q_ref.shape
  assert d_head <= 128, "Cannot use this kernel for d_head > 128"
  assert tuple(q_ref.shape) == (seqlen, d_head), 'Input shape mismatch!'
  assert tuple(k_ref.shape) == (seqlen, d_head), 'Input shape mismatch!'
  assert tuple(v_ref.shape) == (seqlen, d_head), \
    f'Input shape mismatch! Expected: {(seqlen, d_head)} Actual: {tuple(v_ref.shape)}'
  assert tuple(out_ref.shape) == (seqlen, d_head), 'Output shape mismatch!'

  # Softmax scaling factor, multiplied onto Q
  softmax_scale = 0.125

  q_seq_n_tiles, q_seq_tile_size = seqlen // 128, 128
  k_seq_n_tiles, k_seq_tile_size = seqlen // 128, 128
  # No tiling on d_head dimension since the dimension of d_head fits in SB
  d_head_tile_size = d_head
  v_seq_n_tiles, v_seq_tile_size = seqlen // 128, 128

  ###################################
  # Step 1. transpose(tensor_v)
  ###################################
  # Buffer for v matrix transposed
  # Pre-fetch and keep it in SBUF throughout different softmax tiles
  trans_v = nl.ndarray((par_dim(v_seq_tile_size), v_seq_n_tiles, d_head), dtype=pe_in_dt)

  for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
    ip_v = nl.arange(v_seq_tile_size)[:, None]
    if_v = nl.arange(d_head_tile_size)[None, :]
    trans_v[ip_v, i_k_seq_tile, if_v] = nl.load(
      v_ref[i_k_seq_tile * k_seq_tile_size + ip_v, if_v],
      dtype=pe_in_dt)

  q_local = nl.ndarray((q_seq_n_tiles, par_dim(d_head_tile_size), q_seq_tile_size), dtype=pe_in_dt)
  ip_q = nl.arange(d_head_tile_size)[:, None]
  if_q = nl.arange(q_seq_tile_size)[None, :]
  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):
    q_local[i_q_seq_tile, ip_q, if_q] = nl.load_transpose2d(
      q_ref[i_q_seq_tile * q_seq_tile_size + nl.arange(q_seq_tile_size)[:, None],
            nl.arange(d_head_tile_size)[None, :]
      ],
      dtype=pe_in_dt) * softmax_scale

  k_local = nl.ndarray((k_seq_n_tiles, par_dim(d_head_tile_size), k_seq_tile_size), dtype=pe_in_dt)
  ip_k = nl.arange(d_head_tile_size)[:, None]
  if_k = nl.arange(k_seq_tile_size)[None, :]
  for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
    k_local[i_k_seq_tile, ip_k, if_k] = nl.load_transpose2d(
      k_ref[i_k_seq_tile * k_seq_tile_size + nl.arange(k_seq_tile_size)[:, None],
            nl.arange(d_head_tile_size)[None, :]],
      dtype=pe_in_dt)

  for i_q_seq_tile in nl.affine_range(q_seq_n_tiles):  # indent = 2
    # A SBUF buffer for an independent softmax tile
    qk_res_buf = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=kernel_dtype)

    neg_max_res = nl.ndarray((par_dim(q_seq_tile_size), k_seq_n_tiles), dtype=kernel_dtype)
    ip_max = nl.arange(q_seq_tile_size)[:, None]
    if_max = nl.arange(k_seq_n_tiles)[None, :]

    # Loop over RHS free of matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):  # indent = 4

      # Since the K^T tile is the RHS, the q_seq_len dimension will be P in the result
      # PSUM buffer shape: [q_seq_tile_size P, k_seq_tile_size F]
      qk_psum = nl.zeros((par_dim(q_seq_tile_size), k_seq_tile_size),
                         dtype=np.float32, buffer=nl.psum)

      # Tensor indices for accessing qk result in k_seq_tile_size
      ip_qk = nl.arange(q_seq_tile_size)[:, None]
      if_qk = nl.arange(k_seq_tile_size)[None, :]

      ##############################################################
      # Step 2. matmul(stationary=tensor_q, moving=tensor_k, contract=d_head)
      ##############################################################
      qk_psum[ip_qk, if_qk] += nisa.nc_matmul(moving=k_local[i_k_seq_tile, ip_k, if_k],
                                              stationary=q_local[i_q_seq_tile, ip_q, if_q])

      ###################################
      # Step 3. Apply optional causal mask
      ###################################
      if use_causal_mask:
        # Magic number -9984.0 to replace -inf similar to what neuronx-cc uses
        qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nisa.affine_select(
          pred=(i_q_seq_tile * q_seq_tile_size + ip_qk >= i_k_seq_tile * k_seq_tile_size + if_qk),
          on_true_tile=qk_psum[ip_qk, if_qk], on_false_value=-9984.0, dtype=kernel_dtype)
      else:
        # Simply send psum result back to sbuf
        qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk] = nl.copy(qk_psum[ip_qk, if_qk],
                                                                            dtype=kernel_dtype)

      ###################################
      # Step 4. Softmax
      ###################################
      neg_max_res[ip_max, i_k_seq_tile] = nisa.tensor_reduce(
        np.max, data=qk_res_buf[ip_qk, i_k_seq_tile * k_seq_tile_size + if_qk],
        axis=(1,), dtype=kernel_dtype, negate=True)

    neg_max_res_final = nisa.tensor_reduce(
      np.min, data=neg_max_res[ip_max, if_max],
      axis=(1,), dtype=kernel_dtype, negate=False)

    ip_softmax = nl.arange(q_seq_tile_size)[:, None]
    if_softmax = nl.arange(seqlen)[None, :]
    ip_sum_res = nl.arange(q_seq_tile_size)[:, None]
    if_sum_res = nl.arange(d_head_tile_size)[None, :]

    softmax_res = nl.ndarray((par_dim(q_seq_tile_size), seqlen), dtype=pe_in_dt)
    sum_divisor = nl.ndarray((par_dim(q_seq_tile_size), d_head_tile_size), dtype=kernel_dtype)

    # Simply use a large tile of seq_len in size since this is a "blocking" instruction
    # Assuming the compiler will merge exp and reduce_add into a single instruction on ACT
    exp_res = nisa.activation(np.exp,
                              data=qk_res_buf[ip_softmax, if_softmax],
                              bias=neg_max_res_final, scale=1.0)

    sum_res = nisa.tensor_reduce(np.add, data=exp_res, axis=(1,),
                                 dtype=kernel_dtype)
    softmax_res[ip_softmax, if_softmax] = nl.copy(exp_res, dtype=pe_in_dt)

    sum_reciprocal_broadcast = (1.0 / sum_res).broadcast_to((q_seq_tile_size, d_head_tile_size))
    sum_divisor[ip_sum_res, if_sum_res] = nl.copy(sum_reciprocal_broadcast, dtype=kernel_dtype)

    # Buffer for transposed softmax results (FP32 in PSUM)
    trans_softmax_res = nl.ndarray(
      (par_dim(k_seq_tile_size), k_seq_n_tiles, q_seq_tile_size),
      dtype=pe_in_dt)

    # Result psum buffer has the hidden dim as P
    attn_res_psum = nl.zeros((par_dim(d_head_tile_size), q_seq_tile_size),
                             dtype=np.float32, buffer=nl.psum)

    ip_scores_t = nl.arange(k_seq_tile_size)[:, None]
    if_scores_t = nl.arange(q_seq_tile_size)[None, :]
    # Loop over matmul_1 contraction
    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
      ###################################
      # Step 5. transpose(softmax_res)
      ###################################
      ip_scores = nl.arange(q_seq_tile_size)[:, None]
      if_scores = nl.arange(k_seq_tile_size)[None, :]

      trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t] = nisa.nc_transpose(
        softmax_res[ip_scores, i_k_seq_tile * k_seq_tile_size + if_scores])

    ip_out = nl.arange(d_head_tile_size)[:, None]
    if_out = nl.arange(q_seq_tile_size)[None, :]
    for i_k_seq_tile in nl.affine_range(k_seq_n_tiles):
      ######################################################################
      # Step 6. matmul_1(stationary=trans_v, moving=trans_softmax_res, contract=seqlen_v=seqlen_k)
      ######################################################################
      ip_v_t = nl.arange(k_seq_tile_size)[:, None]
      if_v_t = nl.arange(d_head_tile_size)[None, :]
      attn_res_psum[ip_out, if_out] += \
        nisa.nc_matmul(moving=trans_softmax_res[ip_scores_t, i_k_seq_tile, if_scores_t],
                       stationary=trans_v[ip_v_t, i_k_seq_tile, if_v_t])

    attn_res_sbuf = nl.copy(attn_res_psum[ip_out, if_out], dtype=kernel_dtype)

    attn_res_div = attn_res_sbuf * nisa.nc_transpose(sum_divisor[ip_sum_res, if_sum_res])

    nl.store(
      out_ref[i_q_seq_tile * q_seq_tile_size + if_out, ip_out],
      value=attn_res_div)
Launching kernel and testing correctness#
Below we write a reference PyTorch implementation of the attention and verify our NKI kernel output against the reference implementation.
import torch
from torch_neuronx.xla_impl.ops import nki_jit
from torch_xla.core import xla_model as xm

from sd_attention_nki_kernels import fused_self_attn_for_SD_small_head_size


if __name__ == "__main__":

  device = xm.xla_device()

  def cpu_golden_attn(q, k, v):
    softmax_scale = 0.125
    q_scaled = q * softmax_scale
    raw_score = torch.matmul(q_scaled, k.transpose(1, 0))

    norm_score = torch.nn.functional.softmax(raw_score, dim=-1)

    return torch.matmul(norm_score, v)

  q_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
  k_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
  v_tensor = torch.rand((4096, 64), dtype=torch.float32).to(device=device)
  output_nki = torch.zeros((4096, 64), dtype=torch.float32).to(device=device)

  nki_func = nki_jit(func=fused_self_attn_for_SD_small_head_size)
  nki_func(q_tensor, k_tensor, v_tensor, output_nki)

  output_torch = cpu_golden_attn(q_tensor, k_tensor, v_tensor)

  allclose = torch.allclose(output_torch, output_nki, atol=1e-5, rtol=1e-3)

  if allclose:
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")

  assert allclose
Download All Source Code#
Click the links to download source code of the kernels and the testing code discussed in this tutorial.
Kernel Definition, accuracy testing and performance benchmark using baremetal mode:
sd_attention_nki_kernels.py
Use the kernel in PyTorch:
sd_attention_torch.py
You can also view the source code in the GitHub repository nki_samples.
Example usage of the scripts:#
Performance mode
Check performance numbers of the attention kernel:
python3 sd_attention_nki_kernels.py --mode perf
Accuracy mode
Run PyTorch reference implementation and check correctness:
python3 sd_attention_torch.py
Run baremetal mode and check correctness:
python3 sd_attention_nki_kernels.py --mode accuracy