|
| 1 | +import torch |
| 2 | +import torch.nn.functional as F |
| 3 | +import tilelang |
| 4 | +from tilelang.autotuner import * |
| 5 | +import tilelang.language as T |
| 6 | +from functools import partial |
| 7 | + |
| 8 | +num_split = 4 |
| 9 | + |
| 10 | + |
| 11 | +def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_casual, block_M, block_N): |
| 12 | + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) |
| 13 | + shape_q = [batch, seqlen_q, heads, dim] |
| 14 | + shape_kv = [batch, seqlen_kv, heads, dim] |
| 15 | + part_shape = [batch, seqlen_q, heads, num_split, dim] |
| 16 | + dtype = "float16" |
| 17 | + accum_dtype = "float" |
| 18 | + |
| 19 | + @T.macro |
| 20 | + def MMA0( |
| 21 | + K: T.Buffer(shape_kv, dtype), |
| 22 | + Q_shared: T.Buffer([block_M, dim], dtype), |
| 23 | + K_shared: T.Buffer([block_N, dim], dtype), |
| 24 | + acc_s: T.Buffer([block_M, block_N], accum_dtype), |
| 25 | + k: T.int32, |
| 26 | + mid: T.int32, |
| 27 | + hid: T.int32, |
| 28 | + bid: T.int32, |
| 29 | + sid: T.int32, |
| 30 | + ): |
| 31 | + T.copy( |
| 32 | + K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid + |
| 33 | + (k + 1) * block_N, hid, :], K_shared) |
| 34 | + # TODO: Handle casual split case |
| 35 | + if is_casual: |
| 36 | + for i, j in T.Parallel(block_M, block_N): |
| 37 | + acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, |
| 38 | + -T.infinity(acc_s.dtype)) |
| 39 | + else: |
| 40 | + T.clear(acc_s) |
| 41 | + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) |
| 42 | + |
| 43 | + @T.macro |
| 44 | + def MMA1( |
| 45 | + V: T.Buffer(shape_kv, dtype), |
| 46 | + V_shared: T.Buffer([block_M, dim], dtype), |
| 47 | + acc_s_cast: T.Buffer([block_M, block_N], dtype), |
| 48 | + acc_o: T.Buffer([block_M, dim], accum_dtype), |
| 49 | + k: T.int32, |
| 50 | + hid: T.int32, |
| 51 | + bid: T.int32, |
| 52 | + sid: T.int32, |
| 53 | + ): |
| 54 | + T.copy( |
| 55 | + V[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid + |
| 56 | + (k + 1) * block_N, hid, :], V_shared) |
| 57 | + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) |
| 58 | + |
| 59 | + @T.macro |
| 60 | + def Softmax( |
| 61 | + acc_s: T.Buffer([block_M, block_N], accum_dtype), |
| 62 | + acc_s_cast: T.Buffer([block_M, block_N], dtype), |
| 63 | + scores_max: T.Buffer([block_M], accum_dtype), |
| 64 | + scores_max_prev: T.Buffer([block_M], accum_dtype), |
| 65 | + scores_scale: T.Buffer([block_M], accum_dtype), |
| 66 | + scores_sum: T.Buffer([block_M], accum_dtype), |
| 67 | + logsum: T.Buffer([block_M], accum_dtype), |
| 68 | + ): |
| 69 | + T.copy(scores_max, scores_max_prev) |
| 70 | + T.fill(scores_max, -T.infinity(accum_dtype)) |
| 71 | + T.reduce_max(acc_s, scores_max, dim=1, clear=False) |
| 72 | + # To do causal softmax, we need to set the scores_max to 0 if it is -inf |
| 73 | + # This process is called Check_inf in FlashAttention3 code, and it only need to be done |
| 74 | + # in the first ceil_div(kBlockM, kBlockN) steps. |
| 75 | + # for i in T.Parallel(block_M): |
| 76 | + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) |
| 77 | + for i in T.Parallel(block_M): |
| 78 | + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) |
| 79 | + for i, j in T.Parallel(block_M, block_N): |
| 80 | + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - |
| 81 | + # max * log_2(e)) This allows the compiler to use the ffma |
| 82 | + # instruction instead of fadd and fmul separately. |
| 83 | + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) |
| 84 | + T.reduce_sum(acc_s, scores_sum, dim=1) |
| 85 | + for i in T.Parallel(block_M): |
| 86 | + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] |
| 87 | + T.copy(acc_s, acc_s_cast) |
| 88 | + |
| 89 | + @T.macro |
| 90 | + def Rescale( |
| 91 | + acc_o: T.Buffer([block_M, dim], accum_dtype), |
| 92 | + scores_scale: T.Buffer([block_M], accum_dtype), |
| 93 | + ): |
| 94 | + for i, j in T.Parallel(block_M, dim): |
| 95 | + acc_o[i, j] *= scores_scale[i] |
| 96 | + |
| 97 | + @T.macro |
| 98 | + def flash_attn_split( |
| 99 | + Q: T.Buffer(shape_q, dtype), |
| 100 | + K: T.Buffer(shape_kv, dtype), |
| 101 | + V: T.Buffer(shape_kv, dtype), |
| 102 | + glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype), |
| 103 | + Output_partial: T.Buffer(part_shape, dtype), |
| 104 | + ): |
| 105 | + with T.Kernel( |
| 106 | + T.ceildiv(seqlen_q, block_M), heads * batch, num_split, |
| 107 | + threads=128) as (bx, by, bz): |
| 108 | + Q_shared = T.alloc_shared([block_M, dim], dtype) |
| 109 | + K_shared = T.alloc_shared([block_N, dim], dtype) |
| 110 | + V_shared = T.alloc_shared([block_N, dim], dtype) |
| 111 | + O_shared = T.alloc_shared([block_M, dim], dtype) |
| 112 | + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) |
| 113 | + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) |
| 114 | + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) |
| 115 | + scores_max = T.alloc_fragment([block_M], accum_dtype) |
| 116 | + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) |
| 117 | + scores_scale = T.alloc_fragment([block_M], accum_dtype) |
| 118 | + scores_sum = T.alloc_fragment([block_M], accum_dtype) |
| 119 | + logsum = T.alloc_fragment([block_M], accum_dtype) |
| 120 | + |
| 121 | + mid = bx |
| 122 | + hid = by % heads |
| 123 | + bid = by // heads |
| 124 | + sid = bz |
| 125 | + |
| 126 | + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) |
| 127 | + T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared) |
| 128 | + T.fill(acc_o, 0) |
| 129 | + T.fill(logsum, 0) |
| 130 | + T.fill(scores_max, -T.infinity(accum_dtype)) |
| 131 | + |
| 132 | + # TODO: Handle casual split case |
| 133 | + loop_range = ( |
| 134 | + T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv( |
| 135 | + (mid + 1) * block_M, block_N)) if is_casual else T.ceildiv( |
| 136 | + (seqlen_kv // num_split), block_N)) |
| 137 | + |
| 138 | + for k in T.Pipelined(loop_range, num_stages=2): |
| 139 | + MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid) |
| 140 | + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, |
| 141 | + logsum) |
| 142 | + Rescale(acc_o, scores_scale) |
| 143 | + MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid) |
| 144 | + for i, j in T.Parallel(block_M, dim): |
| 145 | + acc_o[i, j] /= logsum[i] |
| 146 | + for i in T.Parallel(block_M): |
| 147 | + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale |
| 148 | + T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M]) |
| 149 | + T.copy(acc_o, O_shared) |
| 150 | + T.copy(O_shared, Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :]) |
| 151 | + |
| 152 | + @T.macro |
| 153 | + def combine( |
| 154 | + glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype), |
| 155 | + Output_partial: T.Buffer(part_shape, dtype), |
| 156 | + Output: T.Buffer(shape_q, dtype), |
| 157 | + ): |
| 158 | + with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz): |
| 159 | + po_local = T.alloc_fragment([block_M, dim], dtype) |
| 160 | + po_shared = T.alloc_shared([block_M, dim], dtype) |
| 161 | + o_accum_local = T.alloc_fragment([block_M, dim], accum_dtype) |
| 162 | + o_shared = T.alloc_shared([block_M, dim], dtype) |
| 163 | + lse_local = T.alloc_fragment([num_split, block_M], dtype) |
| 164 | + lse_local_split = T.alloc_fragment([block_M], accum_dtype) |
| 165 | + lse_logsum_local = T.alloc_fragment([block_M], accum_dtype) |
| 166 | + lse_max_local = T.alloc_fragment([block_M], accum_dtype) |
| 167 | + scale_local = T.alloc_fragment([block_M], accum_dtype) |
| 168 | + |
| 169 | + T.annotate_layout({ |
| 170 | + o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i), |
| 171 | + lse_local_split: T.Fragment(lse_local_split.shape, forward_thread_fn=lambda i: i), |
| 172 | + o_shared: tilelang.layout.make_swizzled_layout(o_shared), |
| 173 | + po_shared: tilelang.layout.make_swizzled_layout(po_shared), |
| 174 | + }) |
| 175 | + |
| 176 | + T.clear(lse_logsum_local) |
| 177 | + T.clear(o_accum_local) |
| 178 | + T.copy(glse[ |
| 179 | + bz, |
| 180 | + by, |
| 181 | + :, |
| 182 | + bx * block_M:(bx + 1) * block_M, |
| 183 | + ], lse_local) |
| 184 | + T.reduce_max(lse_local, lse_max_local, dim=0, clear=False) |
| 185 | + for k in T.Pipelined(num_split): |
| 186 | + T.copy(lse_local[k, :], lse_local_split) |
| 187 | + for i in T.Parallel(block_M): |
| 188 | + lse_logsum_local[i] += T.exp2(lse_local_split[i] - lse_max_local[i]) |
| 189 | + for i in T.Parallel(block_M): |
| 190 | + lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i] |
| 191 | + for k in T.Pipelined(num_split, num_stages=2): |
| 192 | + T.copy(Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :], po_shared) |
| 193 | + T.copy(po_shared, po_local) |
| 194 | + T.copy(lse_local[k, :], lse_local_split) |
| 195 | + for i in T.Parallel(block_M): |
| 196 | + scale_local[i] = T.exp2(lse_local_split[i] - lse_logsum_local[i]) |
| 197 | + for i, j in T.Parallel(block_M, dim): |
| 198 | + o_accum_local[i, j] += po_local[i, j] * scale_local[i] |
| 199 | + T.copy(o_accum_local, o_shared) |
| 200 | + T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) |
| 201 | + |
| 202 | + @T.prim_func |
| 203 | + def main( |
| 204 | + Q: T.Buffer(shape_q, dtype), |
| 205 | + K: T.Buffer(shape_kv, dtype), |
| 206 | + V: T.Buffer(shape_kv, dtype), |
| 207 | + glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype), |
| 208 | + Output_partial: T.Buffer(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim] |
| 209 | + Output: T.Buffer(shape_q, dtype), |
| 210 | + ): |
| 211 | + flash_attn_split(Q, K, V, glse, Output_partial) |
| 212 | + combine(glse, Output_partial, Output) |
| 213 | + |
| 214 | + return main |
| 215 | + |
| 216 | + |
| 217 | +def ref_program(Q, K, V, glse, Output_partial, casual): |
| 218 | + assert casual is False |
| 219 | + dim = Q.size(-1) |
| 220 | + scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) |
| 221 | + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) |
| 222 | + attention_weights = F.softmax(scores, dim=-1) |
| 223 | + output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) |
| 224 | + return output |
| 225 | + |
| 226 | + |
| 227 | +def reduce_ref(Q, K, V, glse, Output_partial, casual): |
| 228 | + o = torch.empty_like(Output_partial[:, :, :, 0, :]).fill_(0) |
| 229 | + lse_logsum = torch.empty_like(glse[:, :, 0, :]).fill_(0) # [batch, seqlen_q, heads] |
| 230 | + lse_max = glse.max(dim=2, keepdim=False).values |
| 231 | + for ks in range(num_split): |
| 232 | + lse = glse[:, :, ks, :] |
| 233 | + lse_logsum += torch.exp2(lse - lse_max) |
| 234 | + lse_logsum = torch.log2(lse_logsum) + lse_max |
| 235 | + for ks in range(num_split): |
| 236 | + lse = glse[:, :, ks, :] |
| 237 | + scale = torch.exp2(lse - lse_logsum) # [batch, heads, seqlen_q] |
| 238 | + o += Output_partial[:, :, :, ks, :] * scale[:, :, :, None].transpose(1, 2) |
| 239 | + return o.to(torch.float16) |
| 240 | + |
| 241 | + |
| 242 | +def flash_split_ref(Q, K, V, casual): |
| 243 | + # [batch, seqlen_q, heads, dim] |
| 244 | + batch = Q.size(0) |
| 245 | + block_M = Q.size(1) |
| 246 | + nheads = Q.size(2) |
| 247 | + dim = Q.size(3) |
| 248 | + block_N = 128 |
| 249 | + seqlen_kv = K.size(1) |
| 250 | + |
| 251 | + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) |
| 252 | + acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float) |
| 253 | + acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16) |
| 254 | + acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float) |
| 255 | + scores_max = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float) |
| 256 | + scores_max_prev = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float) |
| 257 | + scores_scale = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float) |
| 258 | + scores_sum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float) |
| 259 | + logsum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float) |
| 260 | + gacc_o = torch.empty((num_split, batch, block_M, nheads, dim), device="cuda", dtype=torch.float) |
| 261 | + glogsum = torch.empty((num_split, batch, nheads, block_M), device="cuda", dtype=torch.float) |
| 262 | + |
| 263 | + Q_ = Q * scale |
| 264 | + |
| 265 | + for ks in range(num_split): |
| 266 | + acc_o.fill_(0) |
| 267 | + logsum.fill_(0) |
| 268 | + scores_max.fill_(float('-inf')) |
| 269 | + scores_max_prev.fill_(float('-inf')) |
| 270 | + for i in range(int((seqlen_kv // num_split) / block_N)): |
| 271 | + acc_s.fill_(0) |
| 272 | + acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_, |
| 273 | + K[:, (seqlen_kv // num_split) * ks + |
| 274 | + i * block_N:(seqlen_kv // num_split) * ks + |
| 275 | + (i + 1) * block_N, :, :]) # [batch, seqlen, nheads, block_N] |
| 276 | + scores_max_prev = scores_max |
| 277 | + scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] |
| 278 | + scores_scale = torch.exp2(scores_max_prev - scores_max) |
| 279 | + acc_o *= scores_scale[:, :, :, None].transpose(1, 2) |
| 280 | + acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) |
| 281 | + acc_s_cast = acc_s.to(torch.float16) |
| 282 | + acc_o += torch.einsum( |
| 283 | + 'bhqk,bkhd->bqhd', acc_s_cast, |
| 284 | + V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + |
| 285 | + (i + 1) * block_N, :, :]) |
| 286 | + scores_sum = acc_s.sum(dim=-1, keepdim=False) |
| 287 | + logsum = logsum * scores_scale + scores_sum |
| 288 | + acc_o /= logsum[:, :, :, None].transpose(1, 2) |
| 289 | + logsum = torch.log2(logsum) + scores_max |
| 290 | + gacc_o[ks, :, :, :, :] = acc_o |
| 291 | + glogsum[ks, :, :, :] = logsum |
| 292 | + |
| 293 | + return glogsum.to(torch.float16).permute(1, 2, 0, |
| 294 | + 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4) |
| 295 | + |
| 296 | + |
| 297 | +if __name__ == "__main__": |
| 298 | + BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128 |
| 299 | + casual = False |
| 300 | + flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD |
| 301 | + total_flops = 2 * flops_per_matmul |
| 302 | + if casual: |
| 303 | + total_flops *= 0.5 |
| 304 | + BLOCK_M = 128 |
| 305 | + BLOCK_N = 64 # if D_HEAD <= 128 else 32 |
| 306 | + program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, casual, BLOCK_M, BLOCK_N) |
| 307 | + ref_program = partial(ref_program, casual=casual) |
| 308 | + mod, params = tilelang.lower(program) |
| 309 | + mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal) |
| 310 | + mod.assert_allclose(ref_program, rtol=0.01, atol=0.01) |
| 311 | + print("All checks passed!") |
| 312 | + |
| 313 | + latency = mod.do_bench(ref_program, warmup=500) |
| 314 | + print("{:.2f} ms".format(latency)) |
| 315 | + print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) |
| 316 | + latency = mod.do_bench(mod, n_warmup=10, n_repeat=10, profiler="tvm") |
| 317 | + print("{:.4f} ms".format(latency)) |
| 318 | + print("{:.2f} TFlops".format(total_flops / latency * 1e-9)) |
0 commit comments