|
| 1 | +# Copyright (c) Tile-AI Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | +# |
| 4 | +# Modified to implement FlashAttention-2 forward pass principles. |
| 5 | +# Corrected loop implementation using T.while_loop. |
| 6 | + |
| 7 | +import torch |
| 8 | +import torch.nn.functional as F |
| 9 | +import tilelang |
| 10 | +import tilelang.language as T |
| 11 | +import itertools |
| 12 | +import argparse |
| 13 | +from functools import partial |
| 14 | + |
| 15 | + |
| 16 | +# PyTorch 参考实现保持不变 |
| 17 | +def ref_program(Q, K, V, is_causal, groups=1): |
| 18 | + assert Q.size( |
| 19 | + 2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" |
| 20 | + assert Q.size( |
| 21 | + 2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" |
| 22 | + dim = Q.size(-1) |
| 23 | + K = K.repeat_interleave(groups, dim=2) |
| 24 | + V = V.repeat_interleave(groups, dim=2) |
| 25 | + scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) |
| 26 | + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) |
| 27 | + if is_causal: |
| 28 | + seq_len = Q.size(1) |
| 29 | + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) |
| 30 | + mask = mask.unsqueeze(0).unsqueeze(0) |
| 31 | + scores = scores.masked_fill(mask == 0, float('-inf')) |
| 32 | + attention_weights = F.softmax(scores, dim=-1) |
| 33 | + output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) |
| 34 | + return output |
| 35 | + |
| 36 | + |
| 37 | +def get_v2_configs(): |
| 38 | + """Generates configurations for the autotuner, tailored for FA-2 style parallelism.""" |
| 39 | + block_M = [64, 128, 256] |
| 40 | + block_N = [32, 64, 128] |
| 41 | + threads = [128, 256, 512] |
| 42 | + num_split_q = [32, 64, 128] |
| 43 | + num_stages = [1, 2, 3] |
| 44 | + enable_rasterization = [True] |
| 45 | + k_pack = [2] |
| 46 | + |
| 47 | + valid_configs = [] |
| 48 | + |
| 49 | + for m, n, s, t, stages, r, k in itertools.product(block_M, block_N, num_split_q, threads, |
| 50 | + num_stages, enable_rasterization, k_pack): |
| 51 | + valid_configs.append({ |
| 52 | + "block_M": m, |
| 53 | + "block_N": n, |
| 54 | + "num_split_q": s, |
| 55 | + "threads": t, |
| 56 | + "num_stages": stages, |
| 57 | + "enable_rasterization": r, |
| 58 | + "k_pack": k |
| 59 | + }) |
| 60 | + if not valid_configs: |
| 61 | + valid_configs.append({ |
| 62 | + 'block_M': 64, |
| 63 | + 'block_N': 64, |
| 64 | + 'num_split_q': 64, |
| 65 | + 'threads': 256, |
| 66 | + 'num_stages': 1, |
| 67 | + 'enable_rasterization': True, |
| 68 | + 'k_pack': 2 |
| 69 | + }) |
| 70 | + return valid_configs |
| 71 | + |
| 72 | + |
| 73 | +@tilelang.autotune(configs=get_v2_configs(), cache_input_tensors=True) |
| 74 | +@tilelang.jit(out_idx=[3]) |
| 75 | +def fast_flashattn_v2( |
| 76 | + batch, |
| 77 | + heads, |
| 78 | + seq_len, |
| 79 | + dim, |
| 80 | + is_causal, |
| 81 | + groups, |
| 82 | + block_M: int, |
| 83 | + block_N: int, |
| 84 | + num_split_q: int, |
| 85 | + threads: int, |
| 86 | + num_stages: int, |
| 87 | + enable_rasterization: bool, |
| 88 | + k_pack: int, |
| 89 | +): |
| 90 | + scale = (1.0 / dim)**0.5 * 1.44269504 |
| 91 | + head_kv = heads // groups |
| 92 | + q_shape = [batch, seq_len, heads, dim] |
| 93 | + kv_shape = [batch, seq_len, head_kv, dim] |
| 94 | + dtype = "float16" |
| 95 | + accum_dtype = "float" |
| 96 | + |
| 97 | + v_vec_size = 4 |
| 98 | + |
| 99 | + vec_size = 4 * k_pack |
| 100 | + |
| 101 | + @T.macro |
| 102 | + def compute_block( |
| 103 | + bz, |
| 104 | + by, |
| 105 | + bx, |
| 106 | + Q: T.Tensor(q_shape, dtype), |
| 107 | + K: T.Tensor(kv_shape, dtype), |
| 108 | + V: T.Tensor(kv_shape, dtype), |
| 109 | + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), |
| 110 | + m_i: T.FragmentBuffer([block_M], accum_dtype), |
| 111 | + l_i: T.FragmentBuffer([block_M], accum_dtype), |
| 112 | + ): |
| 113 | + Q_shared = T.alloc_shared([block_M, dim], dtype) |
| 114 | + K_shared = T.alloc_shared([block_N, dim], dtype) |
| 115 | + V_shared = T.alloc_shared([block_N, dim], dtype) |
| 116 | + P_shared = T.alloc_shared([block_M, block_N], dtype) |
| 117 | + |
| 118 | + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) |
| 119 | + m_prev = T.alloc_fragment([block_M], accum_dtype) |
| 120 | + scale_factor = T.alloc_fragment([block_M], accum_dtype) |
| 121 | + |
| 122 | + q_block_offset = bx * block_M |
| 123 | + T.copy( |
| 124 | + Q[bz, q_block_offset:q_block_offset + block_M, by, :], |
| 125 | + Q_shared, |
| 126 | + coalesced_width=vec_size) |
| 127 | + |
| 128 | + loop_end_k = T.ceildiv(q_block_offset + |
| 129 | + block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) |
| 130 | + for k in T.Pipelined(loop_end_k, num_stages=num_stages): |
| 131 | + kv_idx = k * block_N |
| 132 | + T.copy( |
| 133 | + K[bz, kv_idx:kv_idx + block_N, by // groups, :], K_shared, coalesced_width=vec_size) |
| 134 | + T.copy( |
| 135 | + V[bz, kv_idx:kv_idx + block_N, by // groups, :], |
| 136 | + V_shared, |
| 137 | + coalesced_width=v_vec_size) |
| 138 | + |
| 139 | + T.clear(acc_s) |
| 140 | + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, k_pack=k_pack) |
| 141 | + |
| 142 | + if is_causal: |
| 143 | + for i, j in T.Parallel(block_M, block_N): |
| 144 | + acc_s[i, j] = T.if_then_else(q_block_offset + i >= kv_idx + j, acc_s[i, j], |
| 145 | + -T.infinity(acc_s.dtype)) |
| 146 | + |
| 147 | + T.copy(m_i, m_prev) |
| 148 | + T.reduce_max(acc_s, m_i, dim=1, clear=False) |
| 149 | + |
| 150 | + for i in T.Parallel(block_M): |
| 151 | + sf = T.exp2(m_prev[i] * scale - m_i[i] * scale) |
| 152 | + l_i[i] *= sf |
| 153 | + scale_factor[i] = sf |
| 154 | + |
| 155 | + for i, j in T.Parallel(block_M, dim): |
| 156 | + acc_o[i, j] *= scale_factor[i] |
| 157 | + |
| 158 | + for i, j in T.Parallel(block_M, block_N): |
| 159 | + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - m_i[i] * scale) |
| 160 | + |
| 161 | + row_sum = T.alloc_fragment([block_M], accum_dtype) |
| 162 | + T.reduce_sum(acc_s, row_sum, dim=1) |
| 163 | + for i in T.Parallel(block_M): |
| 164 | + l_i[i] += row_sum[i] |
| 165 | + |
| 166 | + T.copy(acc_s, P_shared) |
| 167 | + T.sync_threads() |
| 168 | + |
| 169 | + T.gemm(P_shared, V_shared, acc_o) |
| 170 | + |
| 171 | + # 修复:将宏移至内核外部,以实现清晰的代码结构。 |
| 172 | + @T.macro |
| 173 | + def scale_and_write_back(src_buffer, scale_vector, dest_tensor, bz, by, q_block_offset): |
| 174 | + # 此宏执行融合的缩放和写回操作,这对性能至关重要。 |
| 175 | + for i, j in T.Parallel(block_M, dim): |
| 176 | + dest_tensor[bz, q_block_offset + i, by, j] = src_buffer[i, j] * scale_vector[i] |
| 177 | + |
| 178 | + @T.macro |
| 179 | + def flash_attn_forward_kernel(Q: T.Tensor(q_shape, dtype), K: T.Tensor(kv_shape, dtype), |
| 180 | + V: T.Tensor(kv_shape, dtype), Output: T.Tensor(q_shape, dtype)): |
| 181 | + with T.Kernel(num_split_q, batch * heads, threads=threads) as (b_split, byz_combined): |
| 182 | + T.use_swizzle(10, enable=enable_rasterization) |
| 183 | + |
| 184 | + bz = byz_combined // heads |
| 185 | + by = byz_combined % heads |
| 186 | + |
| 187 | + num_q_blocks = T.ceildiv(seq_len, block_M) |
| 188 | + |
| 189 | + bx = T.alloc_var("int32") |
| 190 | + bx[0] = b_split |
| 191 | + |
| 192 | + with T.While(bx[0] < num_q_blocks): |
| 193 | + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) |
| 194 | + m_i = T.alloc_fragment([block_M], accum_dtype) |
| 195 | + l_i = T.alloc_fragment([block_M], accum_dtype) |
| 196 | + T.fill(acc_o, 0) |
| 197 | + T.fill(m_i, -T.infinity(accum_dtype)) |
| 198 | + T.fill(l_i, 0) |
| 199 | + |
| 200 | + current_bx = bx[0] |
| 201 | + |
| 202 | + compute_block(bz, by, current_bx, Q, K, V, acc_o, m_i, l_i) |
| 203 | + |
| 204 | + l_inv = T.alloc_fragment([block_M], accum_dtype) |
| 205 | + for i in T.Parallel(block_M): |
| 206 | + safe_l = T.if_then_else(l_i[i] > 1e-6, l_i[i], 1.0) |
| 207 | + l_inv[i] = 1.0 / safe_l |
| 208 | + |
| 209 | + # 修复:现在对宏的调用对编译器来说更清晰。 |
| 210 | + q_block_offset = current_bx * block_M |
| 211 | + scale_and_write_back(acc_o, l_inv, Output, bz, by, q_block_offset) |
| 212 | + |
| 213 | + bx[0] = current_bx + num_split_q |
| 214 | + |
| 215 | + @T.prim_func |
| 216 | + def main( |
| 217 | + Q: T.Tensor(q_shape, dtype), |
| 218 | + K: T.Tensor(kv_shape, dtype), |
| 219 | + V: T.Tensor(kv_shape, dtype), |
| 220 | + Output: T.Tensor(q_shape, dtype), |
| 221 | + ): |
| 222 | + flash_attn_forward_kernel(Q, K, V, Output) |
| 223 | + |
| 224 | + return main |
| 225 | + |
| 226 | + |
| 227 | +# main 函数保持不变 |
| 228 | +def main_v2(batch: int = 1, |
| 229 | + heads: int = 8, |
| 230 | + seq_len: int = 4096, |
| 231 | + dim: int = 128, |
| 232 | + is_causal: bool = False, |
| 233 | + groups: int = 1): |
| 234 | + |
| 235 | + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim |
| 236 | + total_flops = 2 * flops_per_matmul |
| 237 | + if is_causal: |
| 238 | + total_flops *= 0.5 |
| 239 | + |
| 240 | + print("Starting autotuning for FlashAttention-V2...") |
| 241 | + kernel = fast_flashattn_v2(batch, heads, seq_len, dim, is_causal, groups=groups) |
| 242 | + print(f"Autotuning finished. Best Configuration: {kernel.config}") |
| 243 | + |
| 244 | + ref_program_processed = partial(ref_program, is_causal=is_causal, groups=groups) |
| 245 | + |
| 246 | + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) |
| 247 | + |
| 248 | + print("Verifying correctness...") |
| 249 | + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) |
| 250 | + print("All checks pass.") |
| 251 | + |
| 252 | + latency = profiler.do_bench(ref_program_processed, warmup=100) |
| 253 | + print(f"Reference (PyTorch): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops") |
| 254 | + |
| 255 | + latency = profiler.do_bench(warmup=100) |
| 256 | + print( |
| 257 | + f"Fast Flash Attention V2 (Tile-lang): {latency:.2f} ms | {total_flops / latency * 1e-9:.2f} TFlops" |
| 258 | + ) |
| 259 | + |
| 260 | + |
| 261 | +if __name__ == "__main__": |
| 262 | + parser = argparse.ArgumentParser() |
| 263 | + parser.add_argument('--batch', type=int, default=1, help='batch size') |
| 264 | + parser.add_argument('--heads', type=int, default=8, help='heads') |
| 265 | + parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') |
| 266 | + parser.add_argument('--dim', type=int, default=128, help='dim') |
| 267 | + parser.add_argument('--is_causal', action='store_true', help='causal') |
| 268 | + parser.add_argument('--groups', type=int, default=1, help='groups') |
| 269 | + args = parser.parse_args() |
| 270 | + main_v2(args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups) |
0 commit comments