From 7d35ac5b6ef83aa15c70e339d0e0a15e4977e9ee Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 23 Feb 2025 17:39:48 +0000 Subject: [PATCH 1/6] Add DeepSeek MLA decode example with Flash Attention implementation --- examples/deepseek_mla/.gitkeep | 0 examples/deepseek_mla/example_mla_decode.py | 267 ++++++++++++++++++++ 2 files changed, 267 insertions(+) delete mode 100644 examples/deepseek_mla/.gitkeep create mode 100644 examples/deepseek_mla/example_mla_decode.py diff --git a/examples/deepseek_mla/.gitkeep b/examples/deepseek_mla/.gitkeep deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py new file mode 100644 index 000000000..91ddd2824 --- /dev/null +++ b/examples/deepseek_mla/example_mla_decode.py @@ -0,0 +1,267 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T + +num_split = 4 + + +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H): + scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + shape_q = [batch, heads, (dim + pe_dim)] + shape_k = [batch, seqlen_kv, kv_head_num, (dim + pe_dim)] + shape_v = [batch, seqlen_kv, kv_head_num, dim] + shape_o = [batch, heads, dim] + part_shape = [batch, heads, num_split, dim] + dtype = "float16" + accum_dtype = "float" + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1" + + @T.macro + def flash_attn_split( + Q: T.Buffer(shape_q, dtype), + K: T.Buffer(shape_k, dtype), + V: T.Buffer(shape_v, dtype), + glse: T.Buffer([batch, heads, num_split], dtype), + Output_partial: T.Buffer(part_shape, dtype), + ): + with T.Kernel( + batch, heads // min(block_H, kv_group_num), num_split, threads=128) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, (dim + pe_dim)], dtype) + K_shared = T.alloc_shared([block_N, (dim + pe_dim)], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // block_H) + + T.annotate_layout({ + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + }) + + T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy( + K[bid, (seqlen_kv // num_split) * sid + + k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N, + cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, 
block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.copy( + V[bid, (seqlen_kv // num_split) * sid + + k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N, + cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid]) + T.copy(acc_o, O_shared) + T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, + sid, :]) + + @T.macro + def combine( + glse: T.Buffer([batch, heads, num_split], dtype), + Output_partial: T.Buffer(part_shape, dtype), + Output: T.Buffer(shape_o, dtype), + ): + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local = T.alloc_fragment([num_split, 1], dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_fragment([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout({ + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + }) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + for k in T.Parallel(num_split): + lse_local[k, 0] = glse[bz, by, k] + T.reduce_max(lse_local, lse_max_local, dim=0, clear=True) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] = glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main( + Q: T.Buffer(shape_q, dtype), + K: T.Buffer(shape_k, dtype), + V: T.Buffer(shape_v, dtype), + glse: T.Buffer([batch, heads, num_split], dtype), + Output_partial: T.Buffer(part_shape, dtype), # [batch, heads, num_split, dim] + Output: T.Buffer(shape_o, dtype), + ): + flash_attn_split(Q, K, V, glse, Output_partial) + combine(glse, Output_partial, Output) + + return main + + +def ref_program(query, key, value, glse, Output_partial): + # """ + # Inputs: + # - query (Tensor): [batch, heads, dim] + # - key (Tensor): [batch, seqlen_kv, kv_head_num, dim] + # - value (Tensor): [batch, seqlen_kv, kv_head_num, dim] + + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + from einops import rearrange + batch_size, query_heads, dim = query.shape # [batch_size, query_heads, dim] + _, seqlen_kv, kv_heads, _ = key.shape # [batch_size, seqlen_kv, kv_heads, kv_dim] + dim_v = value.shape[-1] + assert kv_heads == 1, "kv_heads must be 1" + + query_expanded = rearrange(query, 'b h d -> b h 1 d') # [batch_size, query_heads, 1, dim] + key_expanded = key.expand(-1, -1, query_heads, -1) # [batch_size, query_heads, seqlen_kv, dim] + value_expanded = 
value.expand(-1, -1, query_heads, + -1) # [batch_size, query_heads, seqlen_kv, dim] + key_expanded = rearrange(key_expanded, + 'b n h d -> b h n d') # [batch_size, kv_head_num, seqlen_kv, dim] + value_expanded = rearrange(value_expanded, + 'b n h d -> b h n d') # [batch_size, query_heads, seqlen_kv, dim] + + scores = torch.matmul(query_expanded, + key_expanded.transpose(-1, -2)) # [batch_size, query_heads, 1, seqlen_kv] + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + attention_weights = F.softmax(scores, dim=-1) # [batch_size, query_heads, 1, seqlen_kv] + output = torch.matmul(attention_weights, value_expanded) # [batch_size, query_heads, 1, dim] + return output.view(batch_size, query_heads, dim_v) + + +def flash_split_ref(Q, K, V): + dim = 512 + pe_dim = 64 + batch = Q.size(0) + nheads = Q.size(1) + assert Q.size(2) == dim + pe_dim, "dim must be 576=512+64" + block_N = 32 + seqlen_kv = K.size(1) + + scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + acc_s = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float) + acc_s_cast = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float16) + acc_o = torch.empty((batch, nheads, dim), device="cuda", dtype=torch.float) + scores_max = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + scores_max_prev = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + scores_scale = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + scores_sum = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + logsum = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float) + glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float) + + Q_ = Q * scale + K_ = K.expand(-1, -1, nheads, -1) + V_ = V.expand(-1, -1, nheads, -1) + + for ks in range(num_split): + acc_o.fill_(0) + logsum.fill_(0) + scores_max.fill_(float('-inf')) + scores_max_prev.fill_(float('-inf')) + for i in range(int((seqlen_kv // num_split) / block_N)): + acc_s.fill_(0) + acc_s = torch.einsum('bhd,bkhd->bhk', Q_, + K_[:, (seqlen_kv // num_split) * ks + + i * block_N:(seqlen_kv // num_split) * ks + + (i + 1) * block_N, :, :]) # [batch, nheads, block_N] + scores_max_prev = scores_max + scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] + scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] + acc_o *= scores_scale[:, :, None] + acc_s = torch.exp2(acc_s - scores_max[:, :, None]) + acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] + acc_o += torch.einsum( + 'bhk,bkhd->bhd', acc_s_cast, + V_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + + (i + 1) * block_N, :, :]) + scores_sum = acc_s.sum(dim=-1, keepdim=False) + logsum = logsum * scores_scale + scores_sum + acc_o /= logsum[:, :, None] + logsum = torch.log2(logsum) + scores_max + gacc_o[ks, :, :, :] = acc_o + glogsum[ks, :, :] = logsum + + return glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3) + + +def reduce_ref(Q, K, V, glse, Output_partial): + o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0) + lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0) + lse_max = glse.max(dim=2, keepdim=False).values + for ks in range(num_split): + lse = glse[:, :, ks] + lse_logsum += torch.exp2(lse - lse_max) + lse_logsum = torch.log2(lse_logsum) + lse_max + for ks in range(num_split): + lse = glse[:, :, ks] + 
scale = torch.exp2(lse - lse_logsum) + o += Output_partial[:, :, ks, :] * scale[:, :, None] + return o.to(torch.float16) + + +if __name__ == "__main__": + BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 64, 128, 1, 8192, 512, 64 + qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE) + pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD + total_flops = qk_flops + pv_flops + BLOCK_N = 32 # if D_HEAD <= 128 else 32 + BLOCK_H = 64 + + program = flashattn(BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE, BLOCK_N, BLOCK_H) + mod, params = tilelang.lower(program) + mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal) + mod.assert_allclose(ref_program, rtol=0.01, atol=0.01) + latency = mod.do_bench(mod.func, warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) \ No newline at end of file From 6948f4a2e2ac72ba29e02f8b110914d02e429ba4 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 23 Feb 2025 18:05:25 +0000 Subject: [PATCH 2/6] Add GEMM SplitK and StreamK example implementations This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques: - `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang - `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing using PyTorch reference implementations. --- .../example_tilelang_gemm_splitk.py | 72 ++++++ .../example_tilelang_gemm_streamk.py | 206 ++++++++++++++++++ 2 files changed, 278 insertions(+) create mode 100644 examples/gemm_splitk/example_tilelang_gemm_splitk.py create mode 100644 examples/gemm_streamk/example_tilelang_gemm_streamk.py diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk.py b/examples/gemm_splitk/example_tilelang_gemm_splitk.py new file mode 100644 index 000000000..e552320eb --- /dev/null +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
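The Split-K kernel that follows partitions the reduction (K) dimension across `split_k` grid slices, each computing a partial product that is merged into C with (vectorized) atomic adds after a fused zero-initialization pass. A minimal PyTorch sketch of that decomposition, for intuition only; the helper name `split_k_matmul_ref` and the shapes are illustrative, not part of this example:

import torch

def split_k_matmul_ref(A, B, split_k):
    # Partial GEMMs over disjoint K slices, accumulated into one output,
    # mirroring what the kernel below does with per-split atomic adds.
    M, K = A.shape
    _, N = B.shape
    assert K % split_k == 0
    slice_k = K // split_k
    C = torch.zeros(M, N, dtype=torch.float32, device=A.device)
    for s in range(split_k):
        ks, ke = s * slice_k, (s + 1) * slice_k
        C += A[:, ks:ke].float() @ B[ks:ke, :].float()
    return C.to(A.dtype)

# Example (fp32 accumulation, so it agrees closely with a fused matmul):
# A = torch.randn(1024, 512, device="cuda", dtype=torch.float16)
# B = torch.randn(512, 1024, device="cuda", dtype=torch.float16)
# torch.testing.assert_close(split_k_matmul_ref(A, B, 4),
#                            (A.float() @ B.float()).half(), rtol=1e-2, atol=1e-2)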
+ +import tilelang +from tilelang import Profiler +import tilelang.language as T +from tvm import DataType + +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float"): + + splitK = K // split_k + + @T.prim_func + def main( + A: T.Buffer((M, K), dtype), + B: T.Buffer((N, K), dtype), + C: T.Buffer((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_K), dtype, "shared") + B_shared = T.alloc_shared((block_K, block_N), dtype, "shared") + C_shared = T.alloc_shared((block_M, block_N), dtype, "shared") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + if bz == 0: + # fuse the zero initialization kernel + for i, j in T.Parallel(block_M, block_N): + m, n = by * block_M + i, bx * block_N + j + C[m, n] = T.cast(0, dtype) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=0): + T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared) + T.copy(B[bz * splitK + ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + + if DataType(dtype).bits == 16: + for i, j in T.Parallel(block_M, block_N // 2): + m, n = by * block_M + i, bx * block_N + j * 2 + # vectorized atomic + T.atomic_addx2( + C[m, n], C_shared[i, j * 2]) + else: + for i, j in T.Parallel(block_M, block_N): + T.atomic_add( + C[by * block_M + i, bx * block_N + j], C_shared[i, j]) + + + return main + + +program = matmul(1024, 1024, 1024, 128, 128, 32, 4) + +kernel = tilelang.compile(program) + +print(kernel.get_kernel_source()) + +import torch + +a = torch.randn(1024, 1024).cuda().half() +b = torch.randn(1024, 1024).cuda().half() +c = torch.zeros(1024, 1024).cuda().half() +kernel(a, b, c) + +ref_c = a @ b + +print(c) +print(ref_c) + +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py new file mode 100644 index 000000000..df7be58a8 --- /dev/null +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -0,0 +1,206 @@ +import torch +import torch.backends +import tilelang +from tilelang import language as T +import math + + +def cdiv(a, b): + return math.ceil(a / b) + + +# disable tf32 +torch.backends.cuda.matmul.allow_tf32 = False + +m = 256 +n = 1024 +k = 512 + +total_sm = 108 + +torch.random.manual_seed(0) +# uniform distribution from -1 to 1 +A = torch.rand(m, k, device="cuda", dtype=torch.float16) * 2 - 1 +B = torch.rand(n, k, device="cuda", dtype=torch.float16) * 2 - 1 + +streamk_programs = total_sm +BLOCK_SIZE_M = 16 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 32 +two_tiles = False +M, K = A.shape +N, K = B.shape +# accumulator types +# compute grid (work to do per SM on the first wave) +num_block_m = cdiv(M, BLOCK_SIZE_M) +num_block_n = cdiv(N, BLOCK_SIZE_N) +iters_per_tile = cdiv(K, BLOCK_SIZE_K) +total_tiles = num_block_m * num_block_n + +# Two-tile SK + DP +streamk_tiles = total_tiles % streamk_programs +if ( + total_tiles - streamk_tiles > streamk_programs +): # (total_tiles // total_programs > 1) + streamk_tiles += streamk_programs + +blocking_tiles = total_tiles - streamk_tiles +streamk_iters = streamk_tiles * iters_per_tile + +streamk_full_tiles = streamk_iters // streamk_programs +streamk_partial_tiles = streamk_iters % streamk_programs + +print(f"{total_tiles=} ") +print(f"{iters_per_tile=} ") + +sm_patition_factor = max(blocking_tiles // total_sm, 1) + + 
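Before the kernel itself, note how the host-side arithmetic above maps onto `compute_first_wave`: the first `streamk_tiles` output tiles form a flat space of `streamk_tiles * iters_per_tile` K-iterations that is split contiguously across the `streamk_programs` persistent programs, while the remaining `blocking_tiles` are processed data-parallel in `compute_full_tiles`. A standalone sketch of that partitioning, using the same formulas as the kernel (the helper name is illustrative):

def streamk_iter_ranges(streamk_tiles, iters_per_tile, streamk_programs):
    # Reproduce start_iter / last_iter from compute_first_wave for every pid.
    total_iters = streamk_tiles * iters_per_tile
    full = total_iters // streamk_programs      # streamk_full_tiles
    partial = total_iters % streamk_programs    # streamk_partial_tiles
    ranges = []
    for pid in range(streamk_programs):
        start = pid * full + min(pid, partial)
        last = (pid + 1) * full + min(pid + 1, partial)
        ranges.append((start, last))
    # The slices are contiguous, disjoint, and cover [0, total_iters) exactly.
    assert ranges[0][0] == 0 and ranges[-1][1] == total_iters
    return ranges

# e.g. streamk_iter_ranges(streamk_tiles, iters_per_tile, streamk_programs)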
+def tl_matmul_streamk( + M, + N, + K, + streamk_tiles, + block_M, + block_N, + block_K, + trans_A, + trans_B, + dtypeAB, + dtypeC, + accum_dtype, + num_stages, + threads, +): + assert not trans_A + A_shape = (M, K) if not trans_A else (K, M) + B_shape = (K, N) if not trans_B else (N, K) + A_shared_shape = (block_M, block_K) if not trans_A else (block_K, block_M) + B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) + + @T.macro + def compute_first_wave( + pid: T.int32, + A_buf: T.Buffer, + A_buf_shared: T.Buffer, + B_buf: T.Buffer, + B_buf_shared: T.Buffer, + C: T.Buffer, + C_local: T.Buffer, + ): + start_iter = T.alloc_fragment((1,), "int32", "local") + end_iter = T.alloc_fragment((1,), "int32", "local") + + start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles) + last_iter = (pid + 1) * streamk_full_tiles + T.min( + pid + 1, streamk_partial_tiles + ) + + while start_iter[0] < last_iter: + end_iter[0] = T.min( + start_iter[0] + (iters_per_tile - (start_iter[0] % iters_per_tile)), + last_iter, + ) + + tile_id = start_iter[0] // iters_per_tile + remain_iters = start_iter[0] % iters_per_tile + pid_m = tile_id // T.ceildiv(N, block_N) + pid_n = tile_id % T.ceildiv(N, block_N) + + T.clear(C_local) + for k in T.Pipelined(end_iter[0] - start_iter[0], num_stages=num_stages): + T.copy( + A_buf[pid_m * block_M, (k + (start_iter[0] % iters_per_tile)) * block_K], + A_buf_shared, + ) + T.copy( + B_buf[pid_n * block_N, (k + (start_iter[0] % iters_per_tile)) * block_K], + B_buf_shared, + ) + T.gemm(A_buf_shared, B_buf_shared, C_local, transpose_B=trans_B) + + # last iteration of the tile always happens before its start on another SM + if remain_iters == 0 and (end_iter[0] % iters_per_tile == 0): + T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) + else: + for i, j in T.Parallel(block_M, block_N): + T.atomic_add( + C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j] + ) + + start_iter[0] = end_iter[0] + + @T.macro + def compute_full_tiles( + pid: T.int32, + A_buf: T.Buffer, + A_shared: T.Buffer, + B_buf: T.Buffer, + B_shared: T.Buffer, + C: T.Buffer, + C_local: T.Buffer, + ): + + for p in T.serial(sm_patition_factor): + tile_id = pid + streamk_tiles + p * total_sm + pid_m = tile_id // T.ceildiv(N, block_N) + pid_n = tile_id % T.ceildiv(N, block_N) + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + T.copy(A_buf[pid_m * block_M, k * block_K], A_shared) + T.copy(B_buf[pid_n * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=trans_B) + T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) + + @T.prim_func + def main( + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, dtypeAB), + C: T.Buffer((M, N), dtypeC), + ): + with T.Kernel(streamk_programs, threads=threads) as pid: + + A_shared = T.alloc_shared(A_shared_shape, dtypeAB) + B_shared = T.alloc_shared(B_shared_shape, dtypeAB) + A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB) + B_shared_full_tiles = T.alloc_shared(B_shared_shape, dtypeAB) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + compute_first_wave(pid, A, A_shared, B, B_shared, C, C_local) + + if sm_patition_factor > 0: + compute_full_tiles(pid, A, A_shared_full_tiles, B, B_shared_full_tiles, C, C_local) + + return main + + +_tl_matmul_streamk = tl_matmul_streamk( + m, + n, + k, + streamk_tiles, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + False, + True, + "float16", + "float16", + "float32", + 2, + 64, +) + +kernel = 
tilelang.compile(_tl_matmul_streamk) +print(kernel.get_kernel_source()) + +b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16) + +kernel(A, B, b_c) + +C = torch.matmul(A, B.T) + +print(b_c) +print(C) +torch.testing.assert_close(C, b_c, rtol=1e-2, atol=1e-2) From bd76b56e91a6fc6da4b175fe5066bc239a78d675 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 23 Feb 2025 18:05:43 +0000 Subject: [PATCH 3/6] Refactor GEMM SplitK and StreamK example implementations Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts: - Remove unused import (Profiler) in splitk example - Simplify line breaks and improve code readability - Standardize indentation and remove unnecessary whitespace - Optimize atomic add and copy operations for better clarity --- .../example_tilelang_gemm_splitk.py | 16 ++++++-------- .../example_tilelang_gemm_streamk.py | 22 +++++++------------ 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk.py b/examples/gemm_splitk/example_tilelang_gemm_splitk.py index e552320eb..dbf06b7ad 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk.py @@ -2,12 +2,12 @@ # Licensed under the MIT License. import tilelang -from tilelang import Profiler import tilelang.language as T from tvm import DataType + def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float"): - + splitK = K // split_k @T.prim_func @@ -16,7 +16,8 @@ def main( B: T.Buffer((N, K), dtype), C: T.Buffer((M, N), dtype), ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_K), dtype, "shared") B_shared = T.alloc_shared((block_K, block_N), dtype, "shared") C_shared = T.alloc_shared((block_M, block_N), dtype, "shared") @@ -33,20 +34,17 @@ def main( T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared) T.copy(B[bz * splitK + ko * block_K, bx * block_N], B_shared) T.gemm(A_shared, B_shared, C_local) - + T.copy(C_local, C_shared) if DataType(dtype).bits == 16: for i, j in T.Parallel(block_M, block_N // 2): m, n = by * block_M + i, bx * block_N + j * 2 # vectorized atomic - T.atomic_addx2( - C[m, n], C_shared[i, j * 2]) + T.atomic_addx2(C[m, n], C_shared[i, j * 2]) else: for i, j in T.Parallel(block_M, block_N): - T.atomic_add( - C[by * block_M + i, bx * block_N + j], C_shared[i, j]) - + T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j]) return main diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py index df7be58a8..84a3e650a 100644 --- a/examples/gemm_streamk/example_tilelang_gemm_streamk.py +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -39,9 +39,7 @@ def cdiv(a, b): # Two-tile SK + DP streamk_tiles = total_tiles % streamk_programs -if ( - total_tiles - streamk_tiles > streamk_programs -): # (total_tiles // total_programs > 1) +if (total_tiles - streamk_tiles > streamk_programs): # (total_tiles // total_programs > 1) streamk_tiles += streamk_programs blocking_tiles = total_tiles - streamk_tiles @@ -92,9 +90,7 @@ def compute_first_wave( end_iter = T.alloc_fragment((1,), "int32", "local") start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles) - last_iter = (pid + 1) * streamk_full_tiles + T.min( - pid + 1, 
streamk_partial_tiles - ) + last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles) while start_iter[0] < last_iter: end_iter[0] = T.min( @@ -124,9 +120,7 @@ def compute_first_wave( T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) else: for i, j in T.Parallel(block_M, block_N): - T.atomic_add( - C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j] - ) + T.atomic_add(C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j]) start_iter[0] = end_iter[0] @@ -155,9 +149,9 @@ def compute_full_tiles( @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), - B: T.Buffer(B_shape, dtypeAB), - C: T.Buffer((M, N), dtypeC), + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, dtypeAB), + C: T.Buffer((M, N), dtypeC), ): with T.Kernel(streamk_programs, threads=threads) as pid: @@ -166,9 +160,9 @@ def main( A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB) B_shared_full_tiles = T.alloc_shared(B_shared_shape, dtypeAB) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - + compute_first_wave(pid, A, A_shared, B, B_shared, C, C_local) - + if sm_patition_factor > 0: compute_full_tiles(pid, A, A_shared_full_tiles, B, B_shared_full_tiles, C, C_local) From c2cffd9afd38c4f0869ef313e591575bcc19a44c Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 24 Feb 2025 09:19:17 +0000 Subject: [PATCH 4/6] Add block sparse attention benchmarks for multiple libraries This commit introduces comprehensive block sparse attention benchmarks for different libraries: - TileLang block sparse FMHA implementation - Triton block sparse FMHA implementation - PyTorch reference block sparse FMHA implementation - FlashAttention dense FMHA reference implementation The benchmarks include: - Configurable benchmark parameters (batch size, heads, sequence length, etc.) - Sparse mask generation using top-k and threshold methods - Performance measurement for different sparse attention configurations - Utility functions for mask generation and benchmarking --- .../benchmark_configs.py | 2 + .../benchmark_library_dense_fmha.py | 57 ++++ .../benchmark_tilelang_block_sparse_fmha.py | 213 ++++++++++++ .../benchmark_torch_block_sparse_fmha.py | 78 +++++ .../benchmark_triton_block_sparse_fmha.py | 308 ++++++++++++++++++ .../blocksparse_attention/requirements.txt | 1 + 6 files changed, 659 insertions(+) create mode 100644 benchmark/blocksparse_attention/benchmark_configs.py create mode 100644 benchmark/blocksparse_attention/benchmark_library_dense_fmha.py create mode 100644 benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py create mode 100644 benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py create mode 100644 benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py create mode 100644 benchmark/blocksparse_attention/requirements.txt diff --git a/benchmark/blocksparse_attention/benchmark_configs.py b/benchmark/blocksparse_attention/benchmark_configs.py new file mode 100644 index 000000000..a23e2136a --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_configs.py @@ -0,0 +1,2 @@ +# BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK +configs = [[4, 2, 256, 64, 2, 64]] diff --git a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py new file mode 100644 index 000000000..b1bbcbed2 --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
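Every benchmark in this commit derives its sparsity pattern from `get_sparse_attn_mask_from_topk` (defined in each file): a [batch, heads, num_blocks, num_blocks] boolean mask in which each query block-row keeps its `topk` highest-scoring key blocks and is then lower-triangularized for causality. A toy sketch of those semantics; the sizes here are illustrative only:

import torch

bsz, heads, num_blocks, topk = 1, 1, 4, 2
scores = torch.randn(bsz, heads, num_blocks, num_blocks)
idx = torch.topk(scores, topk, dim=-1).indices
mask = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, idx, True).tril_()
# mask[b, h, i, j] == True means query block i may attend to key block j;
# at most `topk` blocks survive per row, and j > i is always masked out.
print(mask[0, 0].int())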
+import math +import torch + +import torch.nn.functional as F +from tilelang.profiler import do_bench + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], + False, + dtype=torch.bool, + device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def benchmark_topk_sparse_attention(): + from benchmark_configs import configs + torch.manual_seed(0) + + # Config + for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + + import flash_attn + def benchmark_fn(): + flash_attn.flash_attn_func(q, k, v, causal=True) + + ref_latency = do_bench( + benchmark_fn, + warmup=10, + rep=100, + ) + print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + + +if __name__ == "__main__": + benchmark_topk_sparse_attention() diff --git a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py new file mode 100644 index 000000000..64a44e5b5 --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
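The dense FlashAttention baseline above reports only the raw `do_bench` latency. If a throughput figure is wanted alongside it, it can be derived the same way the MLA decode example does (total FLOPs divided by measured latency); a sketch, under the assumption of causal dense attention so only the lower triangle of the score matrix is computed:

def causal_attention_tflops(batch, heads, seq_len, d_head, latency_ms):
    # Two GEMMs (Q @ K^T and P @ V), 2 FLOPs per multiply-add,
    # halved because the causal mask skips the upper triangle.
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * d_head
    total_flops = 2 * flops_per_matmul * 0.5
    return total_flops / (latency_ms * 1e-3) / 1e12

# e.g. print(f"{causal_attention_tflops(BATCH, N_HEADS, SEQ_LEN, D_HEAD, ref_latency):.2f} TFLOPs")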
+import math +import torch + +import tilelang +from tilelang import language as T +from tilelang.profiler import do_bench + +def is_hip(): + return False + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], + False, + dtype=torch.bool, + device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal): + block_M = 64 + block_N = 64 + num_stages = 0 + threads = 128 + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + shape = [batch, heads, seq_len, dim] + block_mask_shape = [batch, heads, downsample_len, downsample_len] + + dtype = "float16" + accum_dtype = "float" + block_mask_dtype = "int8" + + def kernel_func(block_M, block_N, num_stages, threads): + + @T.macro + def MMA0( + K: T.Buffer(shape, dtype), + Q_shared: T.Buffer([block_M, dim], dtype), + K_shared: T.Buffer([block_N, dim], dtype), + acc_s: T.Buffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, + -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Buffer(shape, dtype), + V_shared: T.Buffer([block_M, dim], dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + acc_o: T.Buffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.Buffer([block_M, block_N], accum_dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + scores_max: T.Buffer([block_M], accum_dtype), + scores_max_prev: T.Buffer([block_M], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + scores_sum: T.Buffer([block_M], accum_dtype), + logsum: T.Buffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
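+                # Note: `scale` already folds in log2(e)
+                # (scale = (1.0 / dim)**0.5 * 1.44269504 above), so
+                # exp2(x * scale - max * scale) == exp((x - max) / sqrt(dim))
+                # and the softmax result is numerically unchanged.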
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.Buffer([block_M, dim], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Buffer(shape, dtype), + K: T.Buffer(shape, dtype), + V: T.Buffer(shape, dtype), + BlockSparseMask: T.Buffer(block_mask_shape, block_mask_dtype), + Output: T.Buffer(shape, dtype), + ): + with T.Kernel( + T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + block_mask = T.alloc_local([downsample_len], block_mask_dtype) + + T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for vj in T.serial(downsample_len): + block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv( + (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + if block_mask[k] != 0: + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, + scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + + return main + + return kernel_func(block_M, block_N, num_stages, threads) + + +def benchmark_topk_sparse_attention(): + from benchmark_configs import configs + torch.manual_seed(0) + + # Config + for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], + device='cuda', + dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + kernel = tilelang.compile(program, out_idx=4) + def benchmark_fn(): + # Compute reference + # Expand block mask to full attention matrix + kernel(q, k, v, block_mask) + + ref_latency = do_bench( 
+ benchmark_fn, + warmup=10, + rep=100, + ) + print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + + +if __name__ == "__main__": + benchmark_topk_sparse_attention() diff --git a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py new file mode 100644 index 000000000..86dbef000 --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import math +import torch + +import torch.nn.functional as F +from tilelang.profiler import do_bench + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], + False, + dtype=torch.bool, + device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def benchmark_topk_sparse_attention(): + from benchmark_configs import configs + torch.manual_seed(0) + + # Config + for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], + device='cuda', + dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + def benchmark_fn(): + # Compute reference + # Expand block mask to full attention matrix + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() + full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal + + # PyTorch reference implementation + attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float('-inf')) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + return ref_output + + ref_latency = do_bench( + benchmark_fn, + warmup=10, + rep=100, + ) + print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + + +if __name__ == "__main__": + benchmark_topk_sparse_attention() diff --git a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py new file mode 100644 index 000000000..816ede239 --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py @@ -0,0 +1,308 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
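Both the PyTorch reference above and the Triton kernel below consume the same block-level mask; the reference expands it to token granularity with `torch.kron` before applying the causal mask. A toy sketch of that expansion (sizes here are illustrative):

import torch

BLOCK, SEQ_LEN = 2, 6
num_blocks = (SEQ_LEN + BLOCK - 1) // BLOCK
block_mask = torch.tril(torch.ones(num_blocks, num_blocks, dtype=torch.bool))
# Every True block becomes a BLOCK x BLOCK patch of True in the full mask.
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK)).bool()
full_mask = full_mask[:SEQ_LEN, :SEQ_LEN] & torch.tril(torch.ones(SEQ_LEN, SEQ_LEN, dtype=torch.bool))
print(full_mask.int())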
+import math +import torch + +import triton +import triton.language as tl +from tilelang.profiler import do_bench + +def is_hip(): + return False + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], + False, + dtype=torch.bool, + device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + k_block_col_idx, + block_mask_ptr, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kt, + stride_vt, + stride_bmask_n, + sm_scale, + seqlen_k, + past_len, + LAST_K_BLOCK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + + mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) + + if mask_val == True: + start_n = k_block_col_idx * BLOCK_N + # -- compute qk ---- + + k = tl.load(k_ptrs + start_n * stride_kt) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + if LAST_K_BLOCK: + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, + float('-inf')) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + start_n * stride_vt) + + p = p.to(v.type.element_ty) + + acc += tl.dot(p, v) + # update m_i and l_i + m_i = m_ij + return acc, l_i, m_i + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + block_mask_ptr, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qd, + stride_kz, + stride_kh, + stride_kn, + stride_kd, + stride_vz, + stride_vh, + stride_vn, + stride_vd, + stride_bmz, + stride_bmh, + stride_bmm, + stride_bmn, + stride_oz, + stride_oh, + stride_om, + stride_od, + H, + N_CTX, + PAST_LEN, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + Q_LEN = N_CTX - PAST_LEN + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_h = off_hz % H + off_z = off_hz // H + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh + block_mask_ptr += off_z * stride_bmz + off_h * stride_bmh + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + # off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd + off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + mask_ptrs = block_mask_ptr + start_m * stride_bmm + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + q 
= tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN) + + k_block_start = 0 + k_block_end = tl.cdiv((start_m + 1) * BLOCK_M, BLOCK_N) + + # loop over k, v and update accumulator + for col_idx in range(k_block_start, k_block_end): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + col_idx, + mask_ptrs, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kn, + stride_vn, + stride_bmn, + sm_scale, + N_CTX, + PAST_LEN, + col_idx == k_block_end - 1, + BLOCK_M, + BLOCK_N, + ) + + m_i += tl.math.log(l_i) + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + acc = acc.to(Out.dtype.element_ty) + + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ + None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) + + +def _forward(ctx, + q, + k, + v, + block_sparse_mask, + sm_scale, + BLOCK_M=64, + BLOCK_N=64, + num_warps=None, + num_stages=1, + out=None): + + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert k.shape[2] == v.shape[2] + o = out if out is not None else torch.empty_like(q).contiguous() + grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1]) + + assert q.shape[-1] in [64, 128] + BLOCK_DMODEL = q.shape[-1] + + if is_hip(): + num_warps, num_stages = 8, 1 + else: + num_warps, num_stages = 4, 2 + + N_CTX = k.shape[2] + PAST_LEN = N_CTX - q.shape[2] + + H = q.shape[1] + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + block_sparse_mask, + o, + *q.stride(), + *k.stride(), + *v.stride(), + *block_sparse_mask.stride(), + *o.stride(), + H, + N_CTX, + PAST_LEN, + BLOCK_M, + BLOCK_N, + BLOCK_DMODEL, + num_warps=num_warps, + num_stages=num_stages, + ) + + return o + + +class _sparse_attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, block_sparse_dense, sm_scale): + # shape constraints + return _forward(ctx, q, k, v, block_sparse_dense, sm_scale) + + @staticmethod + def backward(ctx, do): + # No gradient propagation. 
+ raise NotImplementedError("It does not support gradient propagation yet") + return None, None, None, None, None + + +block_sparse_triton_fn = _sparse_attention.apply + + +def benchmark_topk_sparse_attention(): + from benchmark_configs import configs + torch.manual_seed(0) + + # Config + for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], + device='cuda', + dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + def benchmark_fn(): + # Compute reference + # Expand block mask to full attention matrix + block_sparse_triton_fn(q, k, v, block_mask, sm_scale) + + ref_latency = do_bench( + benchmark_fn, + warmup=10, + rep=100, + ) + print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + + +if __name__ == "__main__": + benchmark_topk_sparse_attention() diff --git a/benchmark/blocksparse_attention/requirements.txt b/benchmark/blocksparse_attention/requirements.txt new file mode 100644 index 000000000..d0f509936 --- /dev/null +++ b/benchmark/blocksparse_attention/requirements.txt @@ -0,0 +1 @@ +flash-attn From 789781cf3143f9cbe69f093941f764bc254fb678 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 24 Feb 2025 09:40:27 +0000 Subject: [PATCH 5/6] Refactor block sparse attention benchmarks with code style improvements - Add Ruff linter ignore comments to benchmark files - Improve code formatting and line breaks - Remove unused imports - Standardize print statement formatting - Enhance code readability across multiple library benchmarks --- .../benchmark_library_dense_fmha.py | 9 +++++---- .../benchmark_tilelang_block_sparse_fmha.py | 15 +++++++++++---- .../benchmark_torch_block_sparse_fmha.py | 12 ++++++++---- .../benchmark_triton_block_sparse_fmha.py | 15 ++++++++++----- 4 files changed, 34 insertions(+), 17 deletions(-) diff --git a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py index b1bbcbed2..f3759a594 100644 --- a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py @@ -1,9 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-import math +# ruff: noqa import torch - -import torch.nn.functional as F from tilelang.profiler import do_bench def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): @@ -42,6 +40,7 @@ def benchmark_topk_sparse_attention(): v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) import flash_attn + def benchmark_fn(): flash_attn.flash_attn_func(q, k, v, causal=True) @@ -50,7 +49,9 @@ def benchmark_fn(): warmup=10, rep=100, ) - print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + print( + f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}" + ) if __name__ == "__main__": diff --git a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py index 64a44e5b5..e3fec73fb 100644 --- a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# ruff: noqa import math import torch @@ -7,9 +8,11 @@ from tilelang import language as T from tilelang.profiler import do_bench + def is_hip(): return False + def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK @@ -190,12 +193,14 @@ def benchmark_topk_sparse_attention(): downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + device='cuda', + dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) - program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + program = blocksparse_flashattn( + BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) kernel = tilelang.compile(program, out_idx=4) + def benchmark_fn(): # Compute reference # Expand block mask to full attention matrix @@ -206,7 +211,9 @@ def benchmark_fn(): warmup=10, rep=100, ) - print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + print( + f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}" + ) if __name__ == "__main__": diff --git a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py index 86dbef000..66d3ec8f7 100644 --- a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py @@ -1,11 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+# ruff: noqa import math import torch import torch.nn.functional as F from tilelang.profiler import do_bench + def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK @@ -47,11 +49,11 @@ def benchmark_topk_sparse_attention(): downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + device='cuda', + dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) - + def benchmark_fn(): # Compute reference # Expand block mask to full attention matrix @@ -71,7 +73,9 @@ def benchmark_fn(): warmup=10, rep=100, ) - print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + print( + f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}" + ) if __name__ == "__main__": diff --git a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py index 816ede239..e86590c5e 100644 --- a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# ruff: noqa import math import torch @@ -7,9 +8,11 @@ import triton.language as tl from tilelang.profiler import do_bench + def is_hip(): return False + def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK @@ -286,22 +289,24 @@ def benchmark_topk_sparse_attention(): downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + device='cuda', + dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) - + def benchmark_fn(): # Compute reference # Expand block mask to full attention matrix - block_sparse_triton_fn(q, k, v, block_mask, sm_scale) + block_sparse_triton_fn(q, k, v, block_mask, sm_scale) # noqa: B023 ref_latency = do_bench( benchmark_fn, warmup=10, rep=100, ) - print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + print( + f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}" + ) if __name__ == "__main__": From 166ef78b236d94e365198c5285f0dd223813e9c6 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 24 Feb 2025 09:40:42 +0000 Subject: [PATCH 6/6] lint fix --- benchmark/blocksparse_attention/benchmark_library_dense_fmha.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py index f3759a594..f22a3a910 100644 --- a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py @@ -4,6 +4,7 @@ import torch from tilelang.profiler import do_bench + def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, 
num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK
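The Triton block-sparse path in patch 4 is only timed, never checked against the PyTorch reference from `benchmark_torch_block_sparse_fmha.py`. A spot-check can be sketched as follows; it assumes the script is run from `benchmark/blocksparse_attention/` so the module imports resolve, and it prints the maximum deviation rather than asserting a particular tolerance:

import math
import torch
import torch.nn.functional as F
from benchmark_triton_block_sparse_fmha import block_sparse_triton_fn, get_sparse_attn_mask_from_topk

BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK = 4, 2, 256, 64, 2, 64
torch.manual_seed(0)
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
sm_scale = 1.0 / (D_HEAD**0.5)

downsample_len = math.ceil(SEQ_LEN / BLOCK)
x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, device="cuda", dtype=torch.bfloat16)
x_ds[:, :, :, 0] = 100
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)

# PyTorch reference, same construction as benchmark_torch_block_sparse_fmha.py.
full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda"))
full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
full_mask = full_mask & torch.tril(torch.ones_like(full_mask))
attn = torch.einsum("bhsd,bhtd->bhst", q, k) * sm_scale
attn = attn.masked_fill(~full_mask, float("-inf"))
ref = torch.einsum("bhst,bhtd->bhsd", F.softmax(attn, dim=-1), v)

out = block_sparse_triton_fn(q, k, v, block_mask, sm_scale)
print("max abs diff vs. reference:", (out - ref.to(out.dtype)).abs().max().item())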