|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn.functional as F |
| 6 | +import tilelang |
| 7 | +from tilelang import Profiler |
| 8 | +from tilelang.autotuner import * |
| 9 | +import tilelang.language as T |
| 10 | +import itertools |
| 11 | +import argparse |
| 12 | +from functools import partial |
| 13 | + |
| 14 | + |
| 15 | +def get_configs(): |
| 16 | + block_M = [128] |
| 17 | + block_N = [128] |
| 18 | + num_stages = [2] |
| 19 | + threads = [256] |
| 20 | + _configs = list(itertools.product(block_M, block_N, num_stages, threads)) |
| 21 | + |
| 22 | + configs = [{ |
| 23 | + 'block_M': c[0], |
| 24 | + 'block_N': c[1], |
| 25 | + 'num_stages': c[2], |
| 26 | + 'threads': c[3] |
| 27 | + } for c in _configs] |
| 28 | + return configs |
| 29 | + |
| 30 | + |
| 31 | +def flashattn(batch, heads, seq_len, dim, is_causal, tune=False): |
| 32 | + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) |
| 33 | + shape = [batch, seq_len, heads, dim] |
| 34 | + dtype = "float16" |
| 35 | + accum_dtype = "float" |
| 36 | + |
| 37 | + def kernel_func(block_M, block_N, num_stages, threads): |
| 38 | + |
| 39 | + @T.macro |
| 40 | + def MMA0( |
| 41 | + K: T.Buffer(shape, dtype), |
| 42 | + Q_shared: T.Buffer([block_M, dim], dtype), |
| 43 | + K_shared: T.Buffer([block_N, dim], dtype), |
| 44 | + acc_s: T.Buffer([block_M, block_N], accum_dtype), |
| 45 | + k: T.int32, |
| 46 | + bx: T.int32, |
| 47 | + by: T.int32, |
| 48 | + bz: T.int32, |
| 49 | + ): |
| 50 | + T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) |
| 51 | + if is_causal: |
| 52 | + for i, j in T.Parallel(block_M, block_N): |
| 53 | + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, |
| 54 | + -T.infinity(acc_s.dtype)) |
| 55 | + else: |
| 56 | + T.clear(acc_s) |
| 57 | + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) |
| 58 | + |
| 59 | + @T.macro |
| 60 | + def MMA1( |
| 61 | + V: T.Buffer(shape, dtype), |
| 62 | + V_shared: T.Buffer([block_M, dim], dtype), |
| 63 | + acc_s_cast: T.Buffer([block_M, block_N], dtype), |
| 64 | + acc_o: T.Buffer([block_M, dim], accum_dtype), |
| 65 | + k: T.int32, |
| 66 | + by: T.int32, |
| 67 | + bz: T.int32, |
| 68 | + ): |
| 69 | + T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) |
| 70 | + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) |
| 71 | + |
| 72 | + @T.macro |
| 73 | + def Softmax( |
| 74 | + acc_s: T.Buffer([block_M, block_N], accum_dtype), |
| 75 | + acc_s_cast: T.Buffer([block_M, block_N], dtype), |
| 76 | + scores_max: T.Buffer([block_M], accum_dtype), |
| 77 | + scores_max_prev: T.Buffer([block_M], accum_dtype), |
| 78 | + scores_scale: T.Buffer([block_M], accum_dtype), |
| 79 | + scores_sum: T.Buffer([block_M], accum_dtype), |
| 80 | + logsum: T.Buffer([block_M], accum_dtype), |
| 81 | + ): |
| 82 | + T.copy(scores_max, scores_max_prev) |
| 83 | + T.fill(scores_max, -T.infinity(accum_dtype)) |
| 84 | + T.reduce_max(acc_s, scores_max, dim=1, clear=False) |
| 85 | + # To do causal softmax, we need to set the scores_max to 0 if it is -inf |
| 86 | + # This process is called Check_inf in FlashAttention3 code, and it only need to be done |
| 87 | + # in the first ceil_div(kBlockM, kBlockN) steps. |
| 88 | + # for i in T.Parallel(block_M): |
| 89 | + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) |
| 90 | + for i in T.Parallel(block_M): |
| 91 | + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) |
| 92 | + for i, j in T.Parallel(block_M, block_N): |
| 93 | + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - |
| 94 | + # max * log_2(e)) This allows the compiler to use the ffma |
| 95 | + # instruction instead of fadd and fmul separately. |
| 96 | + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) |
| 97 | + T.reduce_sum(acc_s, scores_sum, dim=1) |
| 98 | + for i in T.Parallel(block_M): |
| 99 | + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] |
| 100 | + T.copy(acc_s, acc_s_cast) |
| 101 | + |
| 102 | + @T.macro |
| 103 | + def Rescale( |
| 104 | + acc_o: T.Buffer([block_M, dim], accum_dtype), |
| 105 | + scores_scale: T.Buffer([block_M], accum_dtype), |
| 106 | + ): |
| 107 | + for i, j in T.Parallel(block_M, dim): |
| 108 | + acc_o[i, j] *= scores_scale[i] |
| 109 | + |
| 110 | + @T.prim_func |
| 111 | + def main( |
| 112 | + Q: T.Buffer(shape, dtype), |
| 113 | + K: T.Buffer(shape, dtype), |
| 114 | + V: T.Buffer(shape, dtype), |
| 115 | + Output: T.Buffer(shape, dtype), |
| 116 | + ): |
| 117 | + with T.Kernel( |
| 118 | + T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): |
| 119 | + Q_shared = T.alloc_shared([block_M, dim], dtype) |
| 120 | + K_shared = T.alloc_shared([block_N, dim], dtype) |
| 121 | + V_shared = T.alloc_shared([block_N, dim], dtype) |
| 122 | + O_shared = T.alloc_shared([block_M, dim], dtype) |
| 123 | + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) |
| 124 | + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) |
| 125 | + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) |
| 126 | + scores_max = T.alloc_fragment([block_M], accum_dtype) |
| 127 | + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) |
| 128 | + scores_scale = T.alloc_fragment([block_M], accum_dtype) |
| 129 | + scores_sum = T.alloc_fragment([block_M], accum_dtype) |
| 130 | + logsum = T.alloc_fragment([block_M], accum_dtype) |
| 131 | + |
| 132 | + T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) |
| 133 | + T.fill(acc_o, 0) |
| 134 | + T.fill(logsum, 0) |
| 135 | + T.fill(scores_max, -T.infinity(accum_dtype)) |
| 136 | + |
| 137 | + loop_range = ( |
| 138 | + T.min(T.ceildiv(seq_len, block_N), T.ceildiv( |
| 139 | + (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) |
| 140 | + |
| 141 | + for k in T.Pipelined( |
| 142 | + loop_range, |
| 143 | + num_stages=num_stages, |
| 144 | + order=[-1, 0, 3, 1, -1, 2], |
| 145 | + stage=[-1, 0, 0, 1, -1, 1], |
| 146 | + group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]): |
| 147 | + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) |
| 148 | + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, |
| 149 | + scores_sum, logsum) |
| 150 | + Rescale(acc_o, scores_scale) |
| 151 | + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) |
| 152 | + for i, j in T.Parallel(block_M, dim): |
| 153 | + acc_o[i, j] /= logsum[i] |
| 154 | + T.copy(acc_o, O_shared) |
| 155 | + T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) |
| 156 | + |
| 157 | + return main |
| 158 | + |
| 159 | + if tune: |
| 160 | + |
| 161 | + @autotune( |
| 162 | + configs=get_configs(), |
| 163 | + keys=["block_M", "block_N", "num_stages", "threads"], |
| 164 | + warmup=10, |
| 165 | + rep=10) |
| 166 | + @jit( |
| 167 | + out_idx=[3], |
| 168 | + supply_type=tilelang.TensorSupplyType.Integer, |
| 169 | + ref_prog=None, |
| 170 | + profiler="auto") |
| 171 | + def kernel(block_M=None, block_N=None, num_stages=None, threads=None): |
| 172 | + return kernel_func(block_M, block_N, num_stages, threads) |
| 173 | + |
| 174 | + return kernel() |
| 175 | + else: |
| 176 | + |
| 177 | + def kernel(block_M, block_N, num_stages, threads): |
| 178 | + return kernel_func(block_M, block_N, num_stages, threads) |
| 179 | + |
| 180 | + return kernel |
| 181 | + |
| 182 | + |
| 183 | +def ref_program(Q, K, V, is_causal): |
| 184 | + dim = Q.size(-1) |
| 185 | + scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) |
| 186 | + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) |
| 187 | + if is_causal: |
| 188 | + seq_len = Q.size(1) |
| 189 | + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) |
| 190 | + mask = mask.unsqueeze(0).unsqueeze(0) |
| 191 | + scores = scores.masked_fill(mask == 0, float('-inf')) |
| 192 | + attention_weights = F.softmax(scores, dim=-1) |
| 193 | + output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) |
| 194 | + return output |
| 195 | + |
| 196 | + |
| 197 | +if __name__ == "__main__": |
| 198 | + parser = argparse.ArgumentParser() |
| 199 | + parser.add_argument('--batch', type=int, default=8, help='batch size') |
| 200 | + parser.add_argument('--heads', type=int, default=32, help='heads') |
| 201 | + parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') |
| 202 | + parser.add_argument('--dim', type=int, default=128, help='dim') |
| 203 | + parser.add_argument('--is_causal', action='store_true', help='causal') |
| 204 | + parser.add_argument('--tune', action='store_true', help='tune configs') |
| 205 | + args = parser.parse_args() |
| 206 | + batch, heads, seq_len, dim, is_causal = args.batch, args.heads, args.seq_len, args.dim, args.is_causal |
| 207 | + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim |
| 208 | + total_flops = 2 * flops_per_matmul |
| 209 | + if is_causal: |
| 210 | + total_flops *= 0.5 |
| 211 | + |
| 212 | + if (not args.tune): |
| 213 | + program = flashattn( |
| 214 | + batch, heads, seq_len, dim, is_causal, tune=args.tune)( |
| 215 | + block_M=128, block_N=128, num_stages=2, threads=256) |
| 216 | + ref_program = partial(ref_program, is_causal=is_causal) |
| 217 | + mod, params = tilelang.lower(program) |
| 218 | + mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) |
| 219 | + mod.assert_allclose(ref_program, rtol=0.01, atol=0.01) |
| 220 | + print("All checks pass.") |
| 221 | + latency = mod.do_bench(ref_program, warmup=500) |
| 222 | + print("Ref: {:.2f} ms".format(latency)) |
| 223 | + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) |
| 224 | + latency = mod.do_bench(mod.func, warmup=500) |
| 225 | + print("Tile-lang: {:.2f} ms".format(latency)) |
| 226 | + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) |
| 227 | + else: |
| 228 | + best_latency, best_config, _ = flashattn( |
| 229 | + batch, heads, seq_len, dim, is_causal, tune=args.tune) |
| 230 | + print(f"Best latency: {best_latency}") |
| 231 | + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") |
| 232 | + print(f"Best config: {best_config}") |
0 commit comments