|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn.functional as F |
| 6 | +import tilelang |
| 7 | +from tilelang.profiler import cached |
| 8 | +from tilelang.autotuner import * |
| 9 | +import tilelang.language as T |
| 10 | +import argparse |
| 11 | + |
| 12 | + |
| 13 | +def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N): |
| 14 | + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) |
| 15 | + shape = [batch, seq_len, heads, dim] |
| 16 | + dtype = "float16" |
| 17 | + accum_dtype = "float" |
| 18 | + |
| 19 | + @T.prim_func |
| 20 | + def flash_fwd( |
| 21 | + Q: T.Buffer(shape, dtype), # type: ignore |
| 22 | + K: T.Buffer(shape, dtype), # type: ignore |
| 23 | + V: T.Buffer(shape, dtype), # type: ignore |
| 24 | + Output: T.Buffer(shape, dtype), # type: ignore |
| 25 | + lse: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore |
| 26 | + ): |
| 27 | + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): |
| 28 | + Q_shared = T.alloc_shared([block_M, dim], dtype) |
| 29 | + # Q_local = T.alloc_fragment([block_M, dim], dtype) |
| 30 | + K_shared = T.alloc_shared([block_N, dim], dtype) |
| 31 | + V_shared = T.alloc_shared([block_N, dim], dtype) |
| 32 | + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) |
| 33 | + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) |
| 34 | + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) |
| 35 | + scores_max = T.alloc_fragment([block_M], accum_dtype) |
| 36 | + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) |
| 37 | + scores_scale = T.alloc_fragment([block_M], accum_dtype) |
| 38 | + scores_sum = T.alloc_fragment([block_M], accum_dtype) |
| 39 | + logsum = T.alloc_fragment([block_M], accum_dtype) |
| 40 | + |
| 41 | + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) |
| 42 | + T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) |
| 43 | + T.fill(acc_o, 0) |
| 44 | + T.fill(logsum, 0) |
| 45 | + T.fill(scores_max, -T.infinity(accum_dtype)) |
| 46 | + # T.copy(Q_shared, Q_local) |
| 47 | + # for i, j in T.Parallel(block_M, dim): |
| 48 | + # Q_local[i, j] *= scale |
| 49 | + loop_range = ( |
| 50 | + T.ceildiv( |
| 51 | + (bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N)) |
| 52 | + for k in T.Pipelined(loop_range, num_stages=1): |
| 53 | + T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared) |
| 54 | + if is_casual: |
| 55 | + for i, j in T.Parallel(block_M, block_N): |
| 56 | + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, |
| 57 | + -T.infinity(acc_s.dtype)) |
| 58 | + else: |
| 59 | + T.clear(acc_s) |
| 60 | + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) |
| 61 | + T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared) |
| 62 | + T.copy(scores_max, scores_max_prev) |
| 63 | + T.reduce_max(acc_s, scores_max, dim=1, clear=False) |
| 64 | + for i in T.Parallel(block_M): |
| 65 | + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) |
| 66 | + for i, j in T.Parallel(block_M, dim): |
| 67 | + acc_o[i, j] *= scores_scale[i] |
| 68 | + for i, j in T.Parallel(block_M, block_N): |
| 69 | + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) |
| 70 | + T.copy(acc_s, acc_s_cast) |
| 71 | + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) |
| 72 | + T.reduce_sum(acc_s, scores_sum, dim=1) |
| 73 | + for i in T.Parallel(block_M): |
| 74 | + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] |
| 75 | + for i, j in T.Parallel(block_M, dim): |
| 76 | + acc_o[i, j] /= logsum[i] |
| 77 | + T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) |
| 78 | + for i in T.Parallel(block_M): |
| 79 | + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale |
| 80 | + T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M]) |
| 81 | + |
| 82 | + return flash_fwd |
| 83 | + |
| 84 | + |
| 85 | +def flashattn_bwd_preprocess(batch, heads, seq_len, dim): |
| 86 | + dtype = "float16" |
| 87 | + accum_dtype = "float" |
| 88 | + shape = [batch, seq_len, heads, dim] |
| 89 | + blk = 32 |
| 90 | + |
| 91 | + @T.prim_func |
| 92 | + def flash_bwd_prep( |
| 93 | + O: T.Buffer(shape, dtype), # type: ignore |
| 94 | + dO: T.Buffer(shape, dtype), # type: ignore |
| 95 | + Delta: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore |
| 96 | + ): |
| 97 | + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): |
| 98 | + o = T.alloc_fragment([blk, blk], dtype) |
| 99 | + do = T.alloc_fragment([blk, blk], dtype) |
| 100 | + acc = T.alloc_fragment([blk, blk], accum_dtype) |
| 101 | + delta = T.alloc_fragment([blk], accum_dtype) |
| 102 | + T.clear(acc) |
| 103 | + for k in range(T.ceildiv(dim, blk)): |
| 104 | + T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o) |
| 105 | + T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do) |
| 106 | + for i, j in T.Parallel(blk, blk): |
| 107 | + acc[i, j] += o[i, j] * do[i, j] |
| 108 | + T.reduce_sum(acc, delta, 1) |
| 109 | + T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk]) |
| 110 | + |
| 111 | + return flash_bwd_prep |
| 112 | + |
| 113 | + |
| 114 | +def make_dq_layout(dQ): |
| 115 | + # atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment |
| 116 | + return T.Layout(dQ.shape, |
| 117 | + lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2]) |
| 118 | + |
| 119 | + |
| 120 | +def flashattn_bwd_postprocess(batch, heads, seq_len, dim): |
| 121 | + dtype = "float16" |
| 122 | + accum_dtype = "float" |
| 123 | + shape = [batch, seq_len, heads, dim] |
| 124 | + blk = 64 |
| 125 | + |
| 126 | + @T.prim_func |
| 127 | + def flash_bwd_post( |
| 128 | + dQ: T.Buffer(shape, accum_dtype), # type: ignore |
| 129 | + dQ_out: T.Buffer(shape, dtype), # type: ignore |
| 130 | + ): |
| 131 | + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): |
| 132 | + T.annotate_layout({dQ: make_dq_layout(dQ)}) |
| 133 | + T.copy( |
| 134 | + dQ[bz, bx * blk:(bx + 1) * blk, by, :], |
| 135 | + dQ_out[bz, bx * blk:(bx + 1) * blk, by, :], |
| 136 | + ) |
| 137 | + |
| 138 | + return flash_bwd_post |
| 139 | + |
| 140 | + |
| 141 | +def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N): |
| 142 | + sm_scale = (1.0 / dim)**0.5 |
| 143 | + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) |
| 144 | + shape = [batch, seq_len, heads, dim] |
| 145 | + dtype = "float16" |
| 146 | + accum_dtype = "float" |
| 147 | + |
| 148 | + @T.prim_func |
| 149 | + def flash_bwd( |
| 150 | + Q: T.Buffer(shape, dtype), # type: ignore |
| 151 | + K: T.Buffer(shape, dtype), # type: ignore |
| 152 | + V: T.Buffer(shape, dtype), # type: ignore |
| 153 | + dO: T.Buffer(shape, dtype), # type: ignore |
| 154 | + lse: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore |
| 155 | + Delta: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore |
| 156 | + dQ: T.Buffer(shape, accum_dtype), # type: ignore |
| 157 | + dK: T.Buffer(shape, dtype), # type: ignore |
| 158 | + dV: T.Buffer(shape, dtype), # type: ignore |
| 159 | + ): |
| 160 | + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz): |
| 161 | + K_shared = T.alloc_shared([block_M, dim], dtype) |
| 162 | + dsT_shared = T.alloc_shared([block_M, block_N], dtype) |
| 163 | + # should not store K to local if dim is large |
| 164 | + # K_local = T.alloc_fragment([block_M, dim], dtype) |
| 165 | + # K_local_T = T.alloc_fragment([block_M, dim], dtype) |
| 166 | + # V_local = T.alloc_fragment([block_M, dim], dtype) |
| 167 | + q = T.alloc_shared([block_N, dim], dtype) |
| 168 | + V_shared = T.alloc_shared([block_M, dim], dtype) |
| 169 | + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) |
| 170 | + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) |
| 171 | + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) |
| 172 | + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) |
| 173 | + lse_shared = T.alloc_shared([block_N], accum_dtype) |
| 174 | + delta = T.alloc_shared([block_N], accum_dtype) |
| 175 | + do = T.alloc_shared([block_N, dim], dtype) |
| 176 | + dv = T.alloc_fragment([block_M, dim], accum_dtype) |
| 177 | + dk = T.alloc_fragment([block_M, dim], accum_dtype) |
| 178 | + dq = T.alloc_fragment([block_N, dim], accum_dtype) |
| 179 | + dv_shared = T.alloc_shared([block_N, dim], dtype) |
| 180 | + dk_shared = T.alloc_shared([block_N, dim], dtype) |
| 181 | + |
| 182 | + T.annotate_layout({ |
| 183 | + dQ: make_dq_layout(dQ), |
| 184 | + K_shared: tilelang.layout.make_swizzled_layout(K_shared), |
| 185 | + dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), |
| 186 | + dk_shared: tilelang.layout.make_swizzled_layout(dk_shared), |
| 187 | + }) |
| 188 | + |
| 189 | + T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared) |
| 190 | + T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared) |
| 191 | + T.clear(dv) |
| 192 | + T.clear(dk) |
| 193 | + loop_st = T.floordiv(by * block_M, block_N) if is_casual else 0 |
| 194 | + loop_ed = T.ceildiv(seq_len, block_N) |
| 195 | + for k in T.Pipelined(loop_st, loop_ed, num_stages=2): |
| 196 | + T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q) |
| 197 | + T.clear(qkT) |
| 198 | + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) |
| 199 | + T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared) |
| 200 | + for i, j in T.Parallel(block_M, block_N): |
| 201 | + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) |
| 202 | + if is_casual: |
| 203 | + for i, j in T.Parallel(block_M, block_N): |
| 204 | + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], |
| 205 | + 0) |
| 206 | + T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do) |
| 207 | + T.clear(dsT) |
| 208 | + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) |
| 209 | + T.copy(qkT, qkT_cast) |
| 210 | + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) |
| 211 | + |
| 212 | + T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta) |
| 213 | + |
| 214 | + for i, j in T.Parallel(block_M, block_N): |
| 215 | + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale |
| 216 | + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) |
| 217 | + |
| 218 | + T.copy(dsT_cast, dsT_shared) |
| 219 | + T.clear(dq) |
| 220 | + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) |
| 221 | + for i, j in T.Parallel(block_N, dim): |
| 222 | + if k * block_N + i < seq_len: |
| 223 | + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) |
| 224 | + T.copy(dv, dv_shared) |
| 225 | + T.copy(dk, dk_shared) |
| 226 | + T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :]) |
| 227 | + T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :]) |
| 228 | + |
| 229 | + return flash_bwd |
| 230 | + |
| 231 | + |
| 232 | +class _attention(torch.autograd.Function): |
| 233 | + |
| 234 | + @staticmethod |
| 235 | + def forward(ctx, q, k, v, causal): |
| 236 | + BATCH, N_CTX, H, D_HEAD = q.shape |
| 237 | + block_M = 64 |
| 238 | + block_N = 64 if D_HEAD <= 128 else 32 |
| 239 | + mod = cached(flashattn_fwd, [3, 4], BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) |
| 240 | + o, lse = mod(q, k, v) |
| 241 | + ctx.save_for_backward(q, k, v, o, lse) |
| 242 | + ctx.causal = causal |
| 243 | + return o |
| 244 | + |
| 245 | + @staticmethod |
| 246 | + def backward(ctx, do): |
| 247 | + q, k, v, o, lse = ctx.saved_tensors |
| 248 | + |
| 249 | + def maybe_contiguous(x): |
| 250 | + if x.stride(-1) != 1: |
| 251 | + return x.contiguous() |
| 252 | + return x |
| 253 | + |
| 254 | + do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] |
| 255 | + block_M = 128 |
| 256 | + block_N = 128 if D_HEAD <= 64 else 32 |
| 257 | + mod_prep = cached(flashattn_bwd_preprocess, [2], BATCH, H, N_CTX, D_HEAD) |
| 258 | + mod_post = cached(flashattn_bwd_postprocess, [1], BATCH, H, N_CTX, D_HEAD) |
| 259 | + delta = mod_prep(o, do) |
| 260 | + mod = cached(flashattn_bwd, [6, 7, 8], BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, |
| 261 | + block_N) |
| 262 | + dq, dk, dv = mod(q, k, v, do, lse, delta) |
| 263 | + dq = mod_post(dq) |
| 264 | + return dq, dk, dv, None |
| 265 | + |
| 266 | + |
| 267 | +attention = _attention.apply |
| 268 | + |
| 269 | + |
| 270 | +def ref_program(Q, K, V, is_causal): |
| 271 | + dim = Q.size(-1) |
| 272 | + scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) |
| 273 | + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) |
| 274 | + if is_causal: |
| 275 | + seq_len = Q.size(1) |
| 276 | + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) |
| 277 | + mask = mask.unsqueeze(0).unsqueeze(0) |
| 278 | + scores = scores.masked_fill(mask == 0, float('-inf')) |
| 279 | + attention_weights = F.softmax(scores, dim=-1) |
| 280 | + output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) |
| 281 | + return output |
| 282 | + |
| 283 | + |
| 284 | +if __name__ == "__main__": |
| 285 | + parser = argparse.ArgumentParser() |
| 286 | + parser.add_argument('--batch', type=int, default=8, help='Batch size') |
| 287 | + parser.add_argument('--h', type=int, default=32, help='Number of heads') |
| 288 | + parser.add_argument('--n_ctx', type=int, default=1024, help='Context size') |
| 289 | + parser.add_argument('--d_head', type=int, default=64, help='Head dimension') |
| 290 | + parser.add_argument('--casual', type=bool, default=False, help='Casual flag') |
| 291 | + args = parser.parse_args() |
| 292 | + BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head |
| 293 | + casual = args.casual |
| 294 | + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD |
| 295 | + total_flops = 5 * flops_per_matmul |
| 296 | + if casual: |
| 297 | + total_flops *= 0.5 |
| 298 | + Q = ( |
| 299 | + torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, |
| 300 | + device="cuda").normal_().requires_grad_()) |
| 301 | + K = torch.empty_like(Q).normal_().requires_grad_() |
| 302 | + V = torch.empty_like(Q).normal_().requires_grad_() |
| 303 | + dO = torch.randn_like(Q) |
| 304 | + O = attention(Q, K, V, casual) |
| 305 | + O.backward(dO, retain_graph=True) |
| 306 | + dQ, Q.grad = Q.grad.clone(), None |
| 307 | + dK, K.grad = K.grad.clone(), None |
| 308 | + dV, V.grad = V.grad.clone(), None |
| 309 | + |
| 310 | + O_ref = ref_program(Q, K, V, casual) |
| 311 | + O_ref.backward(dO, retain_graph=True) |
| 312 | + dQ_ref, Q.grad = Q.grad.clone(), None |
| 313 | + dK_ref, K.grad = K.grad.clone(), None |
| 314 | + dV_ref, V.grad = V.grad.clone(), None |
| 315 | + |
| 316 | + assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) |
| 317 | + assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) |
| 318 | + assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2) |
| 319 | + assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) |
| 320 | + |
| 321 | + def run(): |
| 322 | + O_ref.backward(dO, retain_graph=True) |
| 323 | + |
| 324 | + def run1(): |
| 325 | + O.backward(dO, retain_graph=True) |
| 326 | + |
| 327 | + from tilelang.profiler import do_bench |
| 328 | + |
| 329 | + latency = do_bench(run, warmup=500) |
| 330 | + print("torch: {:.2f} ms".format(latency)) |
| 331 | + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) |
| 332 | + latency = do_bench(run1, warmup=500) |
| 333 | + print("tilelang: {:.2f} ms".format(latency)) |
| 334 | + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) |
0 commit comments