From 7d35ac5b6ef83aa15c70e339d0e0a15e4977e9ee Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 23 Feb 2025 17:39:48 +0000 Subject: [PATCH 01/11] Add DeepSeek MLA decode example with Flash Attention implementation --- examples/deepseek_mla/.gitkeep | 0 examples/deepseek_mla/example_mla_decode.py | 267 ++++++++++++++++++++ 2 files changed, 267 insertions(+) delete mode 100644 examples/deepseek_mla/.gitkeep create mode 100644 examples/deepseek_mla/example_mla_decode.py diff --git a/examples/deepseek_mla/.gitkeep b/examples/deepseek_mla/.gitkeep deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py new file mode 100644 index 000000000..91ddd2824 --- /dev/null +++ b/examples/deepseek_mla/example_mla_decode.py @@ -0,0 +1,267 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T + +num_split = 4 + + +def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H): + scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + shape_q = [batch, heads, (dim + pe_dim)] + shape_k = [batch, seqlen_kv, kv_head_num, (dim + pe_dim)] + shape_v = [batch, seqlen_kv, kv_head_num, dim] + shape_o = [batch, heads, dim] + part_shape = [batch, heads, num_split, dim] + dtype = "float16" + accum_dtype = "float" + kv_group_num = heads // kv_head_num + VALID_BLOCK_H = min(block_H, kv_group_num) + assert kv_head_num == 1, "kv_head_num must be 1" + + @T.macro + def flash_attn_split( + Q: T.Buffer(shape_q, dtype), + K: T.Buffer(shape_k, dtype), + V: T.Buffer(shape_v, dtype), + glse: T.Buffer([batch, heads, num_split], dtype), + Output_partial: T.Buffer(part_shape, dtype), + ): + with T.Kernel( + batch, heads // min(block_H, kv_group_num), num_split, threads=128) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, (dim + pe_dim)], dtype) + K_shared = T.alloc_shared([block_N, (dim + pe_dim)], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_H, dim], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // block_H) + + T.annotate_layout({ + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + }) + + T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv((seqlen_kv // num_split), block_N) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy( + K[bid, (seqlen_kv // num_split) * sid + + k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N, + cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, 
block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] *= scores_scale[i] + T.copy( + V[bid, (seqlen_kv // num_split) * sid + + k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N, + cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, dim): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid]) + T.copy(acc_o, O_shared) + T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, + sid, :]) + + @T.macro + def combine( + glse: T.Buffer([batch, heads, num_split], dtype), + Output_partial: T.Buffer(part_shape, dtype), + Output: T.Buffer(shape_o, dtype), + ): + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim], dtype) + o_accum_local = T.alloc_fragment([dim], accum_dtype) + lse_local = T.alloc_fragment([num_split, 1], dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_fragment([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout({ + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + }) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + for k in T.Parallel(num_split): + lse_local[k, 0] = glse[bz, by, k] + T.reduce_max(lse_local, lse_max_local, dim=0, clear=True) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dim): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] = glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dim): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dim): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main( + Q: T.Buffer(shape_q, dtype), + K: T.Buffer(shape_k, dtype), + V: T.Buffer(shape_v, dtype), + glse: T.Buffer([batch, heads, num_split], dtype), + Output_partial: T.Buffer(part_shape, dtype), # [batch, heads, num_split, dim] + Output: T.Buffer(shape_o, dtype), + ): + flash_attn_split(Q, K, V, glse, Output_partial) + combine(glse, Output_partial, Output) + + return main + + +def ref_program(query, key, value, glse, Output_partial): + # """ + # Inputs: + # - query (Tensor): [batch, heads, dim] + # - key (Tensor): [batch, seqlen_kv, kv_head_num, dim] + # - value (Tensor): [batch, seqlen_kv, kv_head_num, dim] + + # Outputs: + # - output (Tensor): [batch, heads, dim] + # """ + from einops import rearrange + batch_size, query_heads, dim = query.shape # [batch_size, query_heads, dim] + _, seqlen_kv, kv_heads, _ = key.shape # [batch_size, seqlen_kv, kv_heads, kv_dim] + dim_v = value.shape[-1] + assert kv_heads == 1, "kv_heads must be 1" + + query_expanded = rearrange(query, 'b h d -> b h 1 d') # [batch_size, query_heads, 1, dim] + key_expanded = key.expand(-1, -1, query_heads, -1) # [batch_size, query_heads, seqlen_kv, dim] + value_expanded = 
value.expand(-1, -1, query_heads, + -1) # [batch_size, query_heads, seqlen_kv, dim] + key_expanded = rearrange(key_expanded, + 'b n h d -> b h n d') # [batch_size, kv_head_num, seqlen_kv, dim] + value_expanded = rearrange(value_expanded, + 'b n h d -> b h n d') # [batch_size, query_heads, seqlen_kv, dim] + + scores = torch.matmul(query_expanded, + key_expanded.transpose(-1, -2)) # [batch_size, query_heads, 1, seqlen_kv] + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + attention_weights = F.softmax(scores, dim=-1) # [batch_size, query_heads, 1, seqlen_kv] + output = torch.matmul(attention_weights, value_expanded) # [batch_size, query_heads, 1, dim] + return output.view(batch_size, query_heads, dim_v) + + +def flash_split_ref(Q, K, V): + dim = 512 + pe_dim = 64 + batch = Q.size(0) + nheads = Q.size(1) + assert Q.size(2) == dim + pe_dim, "dim must be 576=512+64" + block_N = 32 + seqlen_kv = K.size(1) + + scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e) + acc_s = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float) + acc_s_cast = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float16) + acc_o = torch.empty((batch, nheads, dim), device="cuda", dtype=torch.float) + scores_max = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + scores_max_prev = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + scores_scale = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + scores_sum = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + logsum = torch.empty((batch, nheads), device="cuda", dtype=torch.float) + gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float) + glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float) + + Q_ = Q * scale + K_ = K.expand(-1, -1, nheads, -1) + V_ = V.expand(-1, -1, nheads, -1) + + for ks in range(num_split): + acc_o.fill_(0) + logsum.fill_(0) + scores_max.fill_(float('-inf')) + scores_max_prev.fill_(float('-inf')) + for i in range(int((seqlen_kv // num_split) / block_N)): + acc_s.fill_(0) + acc_s = torch.einsum('bhd,bkhd->bhk', Q_, + K_[:, (seqlen_kv // num_split) * ks + + i * block_N:(seqlen_kv // num_split) * ks + + (i + 1) * block_N, :, :]) # [batch, nheads, block_N] + scores_max_prev = scores_max + scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] + scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] + acc_o *= scores_scale[:, :, None] + acc_s = torch.exp2(acc_s - scores_max[:, :, None]) + acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] + acc_o += torch.einsum( + 'bhk,bkhd->bhd', acc_s_cast, + V_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + + (i + 1) * block_N, :, :]) + scores_sum = acc_s.sum(dim=-1, keepdim=False) + logsum = logsum * scores_scale + scores_sum + acc_o /= logsum[:, :, None] + logsum = torch.log2(logsum) + scores_max + gacc_o[ks, :, :, :] = acc_o + glogsum[ks, :, :] = logsum + + return glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3) + + +def reduce_ref(Q, K, V, glse, Output_partial): + o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0) + lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0) + lse_max = glse.max(dim=2, keepdim=False).values + for ks in range(num_split): + lse = glse[:, :, ks] + lse_logsum += torch.exp2(lse - lse_max) + lse_logsum = torch.log2(lse_logsum) + lse_max + for ks in range(num_split): + lse = glse[:, :, ks] + 
scale = torch.exp2(lse - lse_logsum) + o += Output_partial[:, :, ks, :] * scale[:, :, None] + return o.to(torch.float16) + + +if __name__ == "__main__": + BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 64, 128, 1, 8192, 512, 64 + qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE) + pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD + total_flops = qk_flops + pv_flops + BLOCK_N = 32 # if D_HEAD <= 128 else 32 + BLOCK_H = 64 + + program = flashattn(BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE, BLOCK_N, BLOCK_H) + mod, params = tilelang.lower(program) + mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal) + mod.assert_allclose(ref_program, rtol=0.01, atol=0.01) + latency = mod.do_bench(mod.func, warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) \ No newline at end of file From 6948f4a2e2ac72ba29e02f8b110914d02e429ba4 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 23 Feb 2025 18:05:25 +0000 Subject: [PATCH 02/11] Add GEMM SplitK and StreamK example implementations This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques: - `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang - `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing using PyTorch reference implementations. --- .../example_tilelang_gemm_splitk.py | 72 ++++++ .../example_tilelang_gemm_streamk.py | 206 ++++++++++++++++++ 2 files changed, 278 insertions(+) create mode 100644 examples/gemm_splitk/example_tilelang_gemm_splitk.py create mode 100644 examples/gemm_streamk/example_tilelang_gemm_streamk.py diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk.py b/examples/gemm_splitk/example_tilelang_gemm_splitk.py new file mode 100644 index 000000000..e552320eb --- /dev/null +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import tilelang +from tilelang import Profiler +import tilelang.language as T +from tvm import DataType + +def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float"): + + splitK = K // split_k + + @T.prim_func + def main( + A: T.Buffer((M, K), dtype), + B: T.Buffer((N, K), dtype), + C: T.Buffer((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_K), dtype, "shared") + B_shared = T.alloc_shared((block_K, block_N), dtype, "shared") + C_shared = T.alloc_shared((block_M, block_N), dtype, "shared") + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + if bz == 0: + # fuse the zero initialization kernel + for i, j in T.Parallel(block_M, block_N): + m, n = by * block_M + i, bx * block_N + j + C[m, n] = T.cast(0, dtype) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=0): + T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared) + T.copy(B[bz * splitK + ko * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C_shared) + + if DataType(dtype).bits == 16: + for i, j in T.Parallel(block_M, block_N // 2): + m, n = by * block_M + i, bx * block_N + j * 2 + # vectorized atomic + T.atomic_addx2( + C[m, n], C_shared[i, j * 2]) + else: + for i, j in T.Parallel(block_M, block_N): + T.atomic_add( + C[by * block_M + i, bx * block_N + j], C_shared[i, j]) + + + return main + + +program = matmul(1024, 1024, 1024, 128, 128, 32, 4) + +kernel = tilelang.compile(program) + +print(kernel.get_kernel_source()) + +import torch + +a = torch.randn(1024, 1024).cuda().half() +b = torch.randn(1024, 1024).cuda().half() +c = torch.zeros(1024, 1024).cuda().half() +kernel(a, b, c) + +ref_c = a @ b + +print(c) +print(ref_c) + +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py new file mode 100644 index 000000000..df7be58a8 --- /dev/null +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -0,0 +1,206 @@ +import torch +import torch.backends +import tilelang +from tilelang import language as T +import math + + +def cdiv(a, b): + return math.ceil(a / b) + + +# disable tf32 +torch.backends.cuda.matmul.allow_tf32 = False + +m = 256 +n = 1024 +k = 512 + +total_sm = 108 + +torch.random.manual_seed(0) +# uniform distribution from -1 to 1 +A = torch.rand(m, k, device="cuda", dtype=torch.float16) * 2 - 1 +B = torch.rand(n, k, device="cuda", dtype=torch.float16) * 2 - 1 + +streamk_programs = total_sm +BLOCK_SIZE_M = 16 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 32 +two_tiles = False +M, K = A.shape +N, K = B.shape +# accumulator types +# compute grid (work to do per SM on the first wave) +num_block_m = cdiv(M, BLOCK_SIZE_M) +num_block_n = cdiv(N, BLOCK_SIZE_N) +iters_per_tile = cdiv(K, BLOCK_SIZE_K) +total_tiles = num_block_m * num_block_n + +# Two-tile SK + DP +streamk_tiles = total_tiles % streamk_programs +if ( + total_tiles - streamk_tiles > streamk_programs +): # (total_tiles // total_programs > 1) + streamk_tiles += streamk_programs + +blocking_tiles = total_tiles - streamk_tiles +streamk_iters = streamk_tiles * iters_per_tile + +streamk_full_tiles = streamk_iters // streamk_programs +streamk_partial_tiles = streamk_iters % streamk_programs + +print(f"{total_tiles=} ") +print(f"{iters_per_tile=} ") + +sm_patition_factor = max(blocking_tiles // total_sm, 1) + + 
+def tl_matmul_streamk( + M, + N, + K, + streamk_tiles, + block_M, + block_N, + block_K, + trans_A, + trans_B, + dtypeAB, + dtypeC, + accum_dtype, + num_stages, + threads, +): + assert not trans_A + A_shape = (M, K) if not trans_A else (K, M) + B_shape = (K, N) if not trans_B else (N, K) + A_shared_shape = (block_M, block_K) if not trans_A else (block_K, block_M) + B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) + + @T.macro + def compute_first_wave( + pid: T.int32, + A_buf: T.Buffer, + A_buf_shared: T.Buffer, + B_buf: T.Buffer, + B_buf_shared: T.Buffer, + C: T.Buffer, + C_local: T.Buffer, + ): + start_iter = T.alloc_fragment((1,), "int32", "local") + end_iter = T.alloc_fragment((1,), "int32", "local") + + start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles) + last_iter = (pid + 1) * streamk_full_tiles + T.min( + pid + 1, streamk_partial_tiles + ) + + while start_iter[0] < last_iter: + end_iter[0] = T.min( + start_iter[0] + (iters_per_tile - (start_iter[0] % iters_per_tile)), + last_iter, + ) + + tile_id = start_iter[0] // iters_per_tile + remain_iters = start_iter[0] % iters_per_tile + pid_m = tile_id // T.ceildiv(N, block_N) + pid_n = tile_id % T.ceildiv(N, block_N) + + T.clear(C_local) + for k in T.Pipelined(end_iter[0] - start_iter[0], num_stages=num_stages): + T.copy( + A_buf[pid_m * block_M, (k + (start_iter[0] % iters_per_tile)) * block_K], + A_buf_shared, + ) + T.copy( + B_buf[pid_n * block_N, (k + (start_iter[0] % iters_per_tile)) * block_K], + B_buf_shared, + ) + T.gemm(A_buf_shared, B_buf_shared, C_local, transpose_B=trans_B) + + # last iteration of the tile always happens before its start on another SM + if remain_iters == 0 and (end_iter[0] % iters_per_tile == 0): + T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) + else: + for i, j in T.Parallel(block_M, block_N): + T.atomic_add( + C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j] + ) + + start_iter[0] = end_iter[0] + + @T.macro + def compute_full_tiles( + pid: T.int32, + A_buf: T.Buffer, + A_shared: T.Buffer, + B_buf: T.Buffer, + B_shared: T.Buffer, + C: T.Buffer, + C_local: T.Buffer, + ): + + for p in T.serial(sm_patition_factor): + tile_id = pid + streamk_tiles + p * total_sm + pid_m = tile_id // T.ceildiv(N, block_N) + pid_n = tile_id % T.ceildiv(N, block_N) + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1): + T.copy(A_buf[pid_m * block_M, k * block_K], A_shared) + T.copy(B_buf[pid_n * block_N, k * block_K], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=trans_B) + T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) + + @T.prim_func + def main( + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, dtypeAB), + C: T.Buffer((M, N), dtypeC), + ): + with T.Kernel(streamk_programs, threads=threads) as pid: + + A_shared = T.alloc_shared(A_shared_shape, dtypeAB) + B_shared = T.alloc_shared(B_shared_shape, dtypeAB) + A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB) + B_shared_full_tiles = T.alloc_shared(B_shared_shape, dtypeAB) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + compute_first_wave(pid, A, A_shared, B, B_shared, C, C_local) + + if sm_patition_factor > 0: + compute_full_tiles(pid, A, A_shared_full_tiles, B, B_shared_full_tiles, C, C_local) + + return main + + +_tl_matmul_streamk = tl_matmul_streamk( + m, + n, + k, + streamk_tiles, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + False, + True, + "float16", + "float16", + "float32", + 2, + 64, +) + +kernel = 
tilelang.compile(_tl_matmul_streamk) +print(kernel.get_kernel_source()) + +b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16) + +kernel(A, B, b_c) + +C = torch.matmul(A, B.T) + +print(b_c) +print(C) +torch.testing.assert_close(C, b_c, rtol=1e-2, atol=1e-2) From bd76b56e91a6fc6da4b175fe5066bc239a78d675 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 23 Feb 2025 18:05:43 +0000 Subject: [PATCH 03/11] Refactor GEMM SplitK and StreamK example implementations Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts: - Remove unused import (Profiler) in splitk example - Simplify line breaks and improve code readability - Standardize indentation and remove unnecessary whitespace - Optimize atomic add and copy operations for better clarity --- .../example_tilelang_gemm_splitk.py | 16 ++++++-------- .../example_tilelang_gemm_streamk.py | 22 +++++++------------ 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk.py b/examples/gemm_splitk/example_tilelang_gemm_splitk.py index e552320eb..dbf06b7ad 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk.py @@ -2,12 +2,12 @@ # Licensed under the MIT License. import tilelang -from tilelang import Profiler import tilelang.language as T from tvm import DataType + def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float"): - + splitK = K // split_k @T.prim_func @@ -16,7 +16,8 @@ def main( B: T.Buffer((N, K), dtype), C: T.Buffer((M, N), dtype), ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz): A_shared = T.alloc_shared((block_M, block_K), dtype, "shared") B_shared = T.alloc_shared((block_K, block_N), dtype, "shared") C_shared = T.alloc_shared((block_M, block_N), dtype, "shared") @@ -33,20 +34,17 @@ def main( T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared) T.copy(B[bz * splitK + ko * block_K, bx * block_N], B_shared) T.gemm(A_shared, B_shared, C_local) - + T.copy(C_local, C_shared) if DataType(dtype).bits == 16: for i, j in T.Parallel(block_M, block_N // 2): m, n = by * block_M + i, bx * block_N + j * 2 # vectorized atomic - T.atomic_addx2( - C[m, n], C_shared[i, j * 2]) + T.atomic_addx2(C[m, n], C_shared[i, j * 2]) else: for i, j in T.Parallel(block_M, block_N): - T.atomic_add( - C[by * block_M + i, bx * block_N + j], C_shared[i, j]) - + T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j]) return main diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py index df7be58a8..84a3e650a 100644 --- a/examples/gemm_streamk/example_tilelang_gemm_streamk.py +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -39,9 +39,7 @@ def cdiv(a, b): # Two-tile SK + DP streamk_tiles = total_tiles % streamk_programs -if ( - total_tiles - streamk_tiles > streamk_programs -): # (total_tiles // total_programs > 1) +if (total_tiles - streamk_tiles > streamk_programs): # (total_tiles // total_programs > 1) streamk_tiles += streamk_programs blocking_tiles = total_tiles - streamk_tiles @@ -92,9 +90,7 @@ def compute_first_wave( end_iter = T.alloc_fragment((1,), "int32", "local") start_iter[0] = pid * streamk_full_tiles + T.min(pid, streamk_partial_tiles) - last_iter = (pid + 1) * streamk_full_tiles + T.min( - pid + 1, 
streamk_partial_tiles - ) + last_iter = (pid + 1) * streamk_full_tiles + T.min(pid + 1, streamk_partial_tiles) while start_iter[0] < last_iter: end_iter[0] = T.min( @@ -124,9 +120,7 @@ def compute_first_wave( T.copy(C_local, C[pid_m * block_M, pid_n * block_N]) else: for i, j in T.Parallel(block_M, block_N): - T.atomic_add( - C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j] - ) + T.atomic_add(C[pid_m * block_M + i, pid_n * block_N + j], C_local[i, j]) start_iter[0] = end_iter[0] @@ -155,9 +149,9 @@ def compute_full_tiles( @T.prim_func def main( - A: T.Buffer(A_shape, dtypeAB), - B: T.Buffer(B_shape, dtypeAB), - C: T.Buffer((M, N), dtypeC), + A: T.Buffer(A_shape, dtypeAB), + B: T.Buffer(B_shape, dtypeAB), + C: T.Buffer((M, N), dtypeC), ): with T.Kernel(streamk_programs, threads=threads) as pid: @@ -166,9 +160,9 @@ def main( A_shared_full_tiles = T.alloc_shared(A_shared_shape, dtypeAB) B_shared_full_tiles = T.alloc_shared(B_shared_shape, dtypeAB) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - + compute_first_wave(pid, A, A_shared, B, B_shared, C, C_local) - + if sm_patition_factor > 0: compute_full_tiles(pid, A, A_shared_full_tiles, B, B_shared_full_tiles, C, C_local) From c2cffd9afd38c4f0869ef313e591575bcc19a44c Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 24 Feb 2025 09:19:17 +0000 Subject: [PATCH 04/11] Add block sparse attention benchmarks for multiple libraries This commit introduces comprehensive block sparse attention benchmarks for different libraries: - TileLang block sparse FMHA implementation - Triton block sparse FMHA implementation - PyTorch reference block sparse FMHA implementation - FlashAttention dense FMHA reference implementation The benchmarks include: - Configurable benchmark parameters (batch size, heads, sequence length, etc.) - Sparse mask generation using top-k and threshold methods - Performance measurement for different sparse attention configurations - Utility functions for mask generation and benchmarking --- .../benchmark_configs.py | 2 + .../benchmark_library_dense_fmha.py | 57 ++++ .../benchmark_tilelang_block_sparse_fmha.py | 213 ++++++++++++ .../benchmark_torch_block_sparse_fmha.py | 78 +++++ .../benchmark_triton_block_sparse_fmha.py | 308 ++++++++++++++++++ .../blocksparse_attention/requirements.txt | 1 + 6 files changed, 659 insertions(+) create mode 100644 benchmark/blocksparse_attention/benchmark_configs.py create mode 100644 benchmark/blocksparse_attention/benchmark_library_dense_fmha.py create mode 100644 benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py create mode 100644 benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py create mode 100644 benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py create mode 100644 benchmark/blocksparse_attention/requirements.txt diff --git a/benchmark/blocksparse_attention/benchmark_configs.py b/benchmark/blocksparse_attention/benchmark_configs.py new file mode 100644 index 000000000..a23e2136a --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_configs.py @@ -0,0 +1,2 @@ +# BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK +configs = [[4, 2, 256, 64, 2, 64]] diff --git a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py new file mode 100644 index 000000000..b1bbcbed2 --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+import math +import torch + +import torch.nn.functional as F +from tilelang.profiler import do_bench + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], + False, + dtype=torch.bool, + device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def benchmark_topk_sparse_attention(): + from benchmark_configs import configs + torch.manual_seed(0) + + # Config + for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + + import flash_attn + def benchmark_fn(): + flash_attn.flash_attn_func(q, k, v, causal=True) + + ref_latency = do_bench( + benchmark_fn, + warmup=10, + rep=100, + ) + print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + + +if __name__ == "__main__": + benchmark_topk_sparse_attention() diff --git a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py new file mode 100644 index 000000000..64a44e5b5 --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+import math +import torch + +import tilelang +from tilelang import language as T +from tilelang.profiler import do_bench + +def is_hip(): + return False + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], + False, + dtype=torch.bool, + device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal): + block_M = 64 + block_N = 64 + num_stages = 0 + threads = 128 + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + shape = [batch, heads, seq_len, dim] + block_mask_shape = [batch, heads, downsample_len, downsample_len] + + dtype = "float16" + accum_dtype = "float" + block_mask_dtype = "int8" + + def kernel_func(block_M, block_N, num_stages, threads): + + @T.macro + def MMA0( + K: T.Buffer(shape, dtype), + Q_shared: T.Buffer([block_M, dim], dtype), + K_shared: T.Buffer([block_N, dim], dtype), + acc_s: T.Buffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, + -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Buffer(shape, dtype), + V_shared: T.Buffer([block_M, dim], dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + acc_o: T.Buffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.Buffer([block_M, block_N], accum_dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + scores_max: T.Buffer([block_M], accum_dtype), + scores_max_prev: T.Buffer([block_M], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + scores_sum: T.Buffer([block_M], accum_dtype), + logsum: T.Buffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.Buffer([block_M, dim], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Buffer(shape, dtype), + K: T.Buffer(shape, dtype), + V: T.Buffer(shape, dtype), + BlockSparseMask: T.Buffer(block_mask_shape, block_mask_dtype), + Output: T.Buffer(shape, dtype), + ): + with T.Kernel( + T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + block_mask = T.alloc_local([downsample_len], block_mask_dtype) + + T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + for vj in T.serial(downsample_len): + block_mask[vj] = BlockSparseMask[bz, by, bx, vj] + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv( + (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + if block_mask[k] != 0: + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, + scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + + return main + + return kernel_func(block_M, block_N, num_stages, threads) + + +def benchmark_topk_sparse_attention(): + from benchmark_configs import configs + torch.manual_seed(0) + + # Config + for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], + device='cuda', + dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + kernel = tilelang.compile(program, out_idx=4) + def benchmark_fn(): + # Compute reference + # Expand block mask to full attention matrix + kernel(q, k, v, block_mask) + + ref_latency = do_bench( 
+ benchmark_fn, + warmup=10, + rep=100, + ) + print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + + +if __name__ == "__main__": + benchmark_topk_sparse_attention() diff --git a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py new file mode 100644 index 000000000..86dbef000 --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import math +import torch + +import torch.nn.functional as F +from tilelang.profiler import do_bench + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], + False, + dtype=torch.bool, + device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def benchmark_topk_sparse_attention(): + from benchmark_configs import configs + torch.manual_seed(0) + + # Config + for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], + device='cuda', + dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + def benchmark_fn(): + # Compute reference + # Expand block mask to full attention matrix + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')) + full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool() + full_mask = full_mask & torch.tril(torch.ones_like(full_mask)) # Apply causal + + # PyTorch reference implementation + attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale + attn = attn.masked_fill(~full_mask, float('-inf')) + attn = F.softmax(attn, dim=-1) + ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) + return ref_output + + ref_latency = do_bench( + benchmark_fn, + warmup=10, + rep=100, + ) + print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + + +if __name__ == "__main__": + benchmark_topk_sparse_attention() diff --git a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py new file mode 100644 index 000000000..816ede239 --- /dev/null +++ b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py @@ -0,0 +1,308 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+import math +import torch + +import triton +import triton.language as tl +from tilelang.profiler import do_bench + +def is_hip(): + return False + +def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): + bsz, num_head, downsample_len, _ = x.shape + # N_CTX = downsample_len * BLOCK + sparse_index = torch.topk(x, topk, dim=-1).indices + dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], + False, + dtype=torch.bool, + device=x.device) + dense_mask.scatter_(-1, sparse_index, True) + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False): + dense_mask = x > threshold + if use_dense_for_last_block: + dense_mask[:, :, -2:, :] = True + dense_mask.tril_() + return dense_mask + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + k_block_col_idx, + block_mask_ptr, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kt, + stride_vt, + stride_bmask_n, + sm_scale, + seqlen_k, + past_len, + LAST_K_BLOCK: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + + mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n) + + if mask_val == True: + start_n = k_block_col_idx * BLOCK_N + # -- compute qk ---- + + k = tl.load(k_ptrs + start_n * stride_kt) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + if LAST_K_BLOCK: + qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, + float('-inf')) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.exp(qk) + l_ij = tl.sum(p, 1) + alpha = tl.exp(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # update acc + v = tl.load(v_ptrs + start_n * stride_vt) + + p = p.to(v.type.element_ty) + + acc += tl.dot(p, v) + # update m_i and l_i + m_i = m_ij + return acc, l_i, m_i + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + block_mask_ptr, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qd, + stride_kz, + stride_kh, + stride_kn, + stride_kd, + stride_vz, + stride_vh, + stride_vn, + stride_vd, + stride_bmz, + stride_bmh, + stride_bmm, + stride_bmn, + stride_oz, + stride_oh, + stride_om, + stride_od, + H, + N_CTX, + PAST_LEN, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + Q_LEN = N_CTX - PAST_LEN + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_h = off_hz % H + off_z = off_hz // H + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh + block_mask_ptr += off_z * stride_bmz + off_h * stride_bmh + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + # off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd + off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + mask_ptrs = block_mask_ptr + start_m * stride_bmm + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + q 
= tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN) + + k_block_start = 0 + k_block_end = tl.cdiv((start_m + 1) * BLOCK_M, BLOCK_N) + + # loop over k, v and update accumulator + for col_idx in range(k_block_start, k_block_end): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + col_idx, + mask_ptrs, + k_ptrs, + v_ptrs, + offs_m, + offs_n, + stride_kn, + stride_vn, + stride_bmn, + sm_scale, + N_CTX, + PAST_LEN, + col_idx == k_block_end - 1, + BLOCK_M, + BLOCK_N, + ) + + m_i += tl.math.log(l_i) + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + acc = acc.to(Out.dtype.element_ty) + + off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[ + None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < N_CTX) + + +def _forward(ctx, + q, + k, + v, + block_sparse_mask, + sm_scale, + BLOCK_M=64, + BLOCK_N=64, + num_warps=None, + num_stages=1, + out=None): + + assert q.shape[-1] == k.shape[-1] == v.shape[-1] + assert k.shape[2] == v.shape[2] + o = out if out is not None else torch.empty_like(q).contiguous() + grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1]) + + assert q.shape[-1] in [64, 128] + BLOCK_DMODEL = q.shape[-1] + + if is_hip(): + num_warps, num_stages = 8, 1 + else: + num_warps, num_stages = 4, 2 + + N_CTX = k.shape[2] + PAST_LEN = N_CTX - q.shape[2] + + H = q.shape[1] + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + block_sparse_mask, + o, + *q.stride(), + *k.stride(), + *v.stride(), + *block_sparse_mask.stride(), + *o.stride(), + H, + N_CTX, + PAST_LEN, + BLOCK_M, + BLOCK_N, + BLOCK_DMODEL, + num_warps=num_warps, + num_stages=num_stages, + ) + + return o + + +class _sparse_attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, block_sparse_dense, sm_scale): + # shape constraints + return _forward(ctx, q, k, v, block_sparse_dense, sm_scale) + + @staticmethod + def backward(ctx, do): + # No gradient propagation. 
+ raise NotImplementedError("It does not support gradient propagation yet") + return None, None, None, None, None + + +block_sparse_triton_fn = _sparse_attention.apply + + +def benchmark_topk_sparse_attention(): + from benchmark_configs import configs + torch.manual_seed(0) + + # Config + for BATCH, N_HEADS, SEQ_LEN, D_HEAD, TOPK, BLOCK in configs: + + # Create inputs + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + + sm_scale = 1.0 / (D_HEAD**0.5) + + # Create sparse mask (downsampled to block level) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], + device='cuda', + dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + def benchmark_fn(): + # Compute reference + # Expand block mask to full attention matrix + block_sparse_triton_fn(q, k, v, block_mask, sm_scale) + + ref_latency = do_bench( + benchmark_fn, + warmup=10, + rep=100, + ) + print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + + +if __name__ == "__main__": + benchmark_topk_sparse_attention() diff --git a/benchmark/blocksparse_attention/requirements.txt b/benchmark/blocksparse_attention/requirements.txt new file mode 100644 index 000000000..d0f509936 --- /dev/null +++ b/benchmark/blocksparse_attention/requirements.txt @@ -0,0 +1 @@ +flash-attn From 789781cf3143f9cbe69f093941f764bc254fb678 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 24 Feb 2025 09:40:27 +0000 Subject: [PATCH 05/11] Refactor block sparse attention benchmarks with code style improvements - Add Ruff linter ignore comments to benchmark files - Improve code formatting and line breaks - Remove unused imports - Standardize print statement formatting - Enhance code readability across multiple library benchmarks --- .../benchmark_library_dense_fmha.py | 9 +++++---- .../benchmark_tilelang_block_sparse_fmha.py | 15 +++++++++++---- .../benchmark_torch_block_sparse_fmha.py | 12 ++++++++---- .../benchmark_triton_block_sparse_fmha.py | 15 ++++++++++----- 4 files changed, 34 insertions(+), 17 deletions(-) diff --git a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py index b1bbcbed2..f3759a594 100644 --- a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py @@ -1,9 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-import math +# ruff: noqa import torch - -import torch.nn.functional as F from tilelang.profiler import do_bench def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): @@ -42,6 +40,7 @@ def benchmark_topk_sparse_attention(): v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) import flash_attn + def benchmark_fn(): flash_attn.flash_attn_func(q, k, v, causal=True) @@ -50,7 +49,9 @@ def benchmark_fn(): warmup=10, rep=100, ) - print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + print( + f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}" + ) if __name__ == "__main__": diff --git a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py index 64a44e5b5..e3fec73fb 100644 --- a/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# ruff: noqa import math import torch @@ -7,9 +8,11 @@ from tilelang import language as T from tilelang.profiler import do_bench + def is_hip(): return False + def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK @@ -190,12 +193,14 @@ def benchmark_topk_sparse_attention(): downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + device='cuda', + dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) - program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + program = blocksparse_flashattn( + BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) kernel = tilelang.compile(program, out_idx=4) + def benchmark_fn(): # Compute reference # Expand block mask to full attention matrix @@ -206,7 +211,9 @@ def benchmark_fn(): warmup=10, rep=100, ) - print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + print( + f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}" + ) if __name__ == "__main__": diff --git a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py index 86dbef000..66d3ec8f7 100644 --- a/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py @@ -1,11 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+# ruff: noqa import math import torch import torch.nn.functional as F from tilelang.profiler import do_bench + def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK @@ -47,11 +49,11 @@ def benchmark_topk_sparse_attention(): downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + device='cuda', + dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) - + def benchmark_fn(): # Compute reference # Expand block mask to full attention matrix @@ -71,7 +73,9 @@ def benchmark_fn(): warmup=10, rep=100, ) - print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + print( + f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}" + ) if __name__ == "__main__": diff --git a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py index 816ede239..e86590c5e 100644 --- a/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# ruff: noqa import math import torch @@ -7,9 +8,11 @@ import triton.language as tl from tilelang.profiler import do_bench + def is_hip(): return False + def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK @@ -286,22 +289,24 @@ def benchmark_topk_sparse_attention(): downsample_factor = BLOCK downsample_len = math.ceil(SEQ_LEN / downsample_factor) x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device='cuda', - dtype=torch.bfloat16) + device='cuda', + dtype=torch.bfloat16) x_ds[:, :, :, 0] = 100 block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) - + def benchmark_fn(): # Compute reference # Expand block mask to full attention matrix - block_sparse_triton_fn(q, k, v, block_mask, sm_scale) + block_sparse_triton_fn(q, k, v, block_mask, sm_scale) # noqa: B023 ref_latency = do_bench( benchmark_fn, warmup=10, rep=100, ) - print(f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}") + print( + f"BATCH: {BATCH}, N_HEADS: {N_HEADS}, SEQ_LEN: {SEQ_LEN}, D_HEAD: {D_HEAD}, TOPK: {TOPK}, BLOCK: {BLOCK}, ref_latency: {ref_latency}" + ) if __name__ == "__main__": From 166ef78b236d94e365198c5285f0dd223813e9c6 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 24 Feb 2025 09:40:42 +0000 Subject: [PATCH 06/11] lint fix --- benchmark/blocksparse_attention/benchmark_library_dense_fmha.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py index f3759a594..f22a3a910 100644 --- a/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py +++ b/benchmark/blocksparse_attention/benchmark_library_dense_fmha.py @@ -4,6 +4,7 @@ import torch from tilelang.profiler import do_bench + def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False): bsz, 
num_head, downsample_len, _ = x.shape # N_CTX = downsample_len * BLOCK From 83288854a131f2e85843647cc505d5aa142447e8 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 24 Feb 2025 16:38:01 +0000 Subject: [PATCH 07/11] Add CUDA atomic operations for BFLOAT16 and update function naming - Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in CUDA common header - Rename existing atomic add functions to use PascalCase (atomicAdd -> AtomicAdd) - Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values - Update kernel and language customization to use new function names - Add return type annotations in profiler module --- src/tl_templates/cuda/common.h | 37 +++++++++++++++++++++++++++++----- tilelang/jit/kernel.py | 2 +- tilelang/language/customize.py | 4 ++-- tilelang/profiler/__init__.py | 8 ++++++-- 4 files changed, 41 insertions(+), 10 deletions(-) diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index ad6fef1e6..1fa4eb2e3 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -46,6 +46,13 @@ TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) { return (v1 << 16) | v0; } +// Pack two bfloat16_t values. +TL_DEVICE unsigned __pack_nv_bfloat162(const bfloat16_t x, const bfloat16_t y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + // Pack four char values TL_DEVICE int make_int(signed char x0, signed char x1, signed char x2, signed char x3) { @@ -83,27 +90,47 @@ TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) { } // AtomicAdd Functions for FP16 -TL_DEVICE void atomicAdd(half_t *address, half_t val) { +TL_DEVICE void AtomicAdd(half_t *address, half_t val) { // Use atomicCAS with built-in cuda_fp16 support atomicAdd(reinterpret_cast(address), static_cast(val)); } // AtomicAdd Functions for FP16 -TL_DEVICE void atomicAdd(half_t *address, half_t *val) { +TL_DEVICE void AtomicAdd(half_t *address, half_t *val) { atomicAdd(reinterpret_cast(address), static_cast(*val)); } -// AtomicAdd Functions for FP16 -TL_DEVICE void atomicAddx2(half_t *address, half_t *val) { +// AtomicAdd Functions for FP16x2 +TL_DEVICE void AtomicAddx2(half_t *address, half_t *val) { atomicAdd(reinterpret_cast(address), static_cast(*reinterpret_cast(val))); } -TL_DEVICE void atomicAdd(half_t *address, float val) { +// AtomicAdd Functions for FP16 +TL_DEVICE void AtomicAdd(half_t *address, float val) { // Use atomicCAS with built-in cuda_fp16 support atomicAdd(reinterpret_cast(address), __float2half(val)); } +// AtomicAdd Functions for BFLOAT16 +TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t *val) { + atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), static_cast<__nv_bfloat16>(*val)); +} + +TL_DEVICE void AtomicAdd(bfloat16_t *address, float val) { + atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), __float2bfloat16(val)); +} + +TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t val) { + atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), static_cast<__nv_bfloat16>(val)); +} + +// AtomicAdd Functions for BFLOAT16x2 +TL_DEVICE void AtomicAddx2(bfloat16_t *address, bfloat16_t *val) { + atomicAdd(reinterpret_cast<__nv_bfloat162 *>(address), + static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val))); +} + // DP4A template TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) { diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index 09e7b4453..ab9356e0d 100644 --- a/tilelang/jit/kernel.py 
b/tilelang/jit/kernel.py
--- a/tilelang/jit/kernel.py
+++ b/tilelang/jit/kernel.py @@ -206,7 +206,7 @@ def get_kernel_source(self) -> str: str The source code of the compiled kernel function. """ - if self.execution_backend == "ctypes": + if self.execution_backend in {"ctypes", "cython"}: return self.adapter.get_kernel_source() return self.rt_module.imported_modules[0].get_source() diff --git a/tilelang/language/customize.py b/tilelang/language/customize.py index 0be467f0a..40d93a6d3 100644 --- a/tilelang/language/customize.py +++ b/tilelang/language/customize.py @@ -6,11 +6,11 @@ def atomic_add(dst, value): - return T.call_extern("handle", "atomicAdd", T.address_of(dst), value) + return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value) def atomic_addx2(dst, value): - return T.call_extern("handle", "atomicAddx2", T.address_of(dst), T.address_of(value)) + return T.call_extern("handle", "AtomicAddx2", T.address_of(dst), T.address_of(value)) def dp4a(A, B, C): diff --git a/tilelang/profiler/__init__.py b/tilelang/profiler/__init__.py index e1544b961..59f0e90c5 100644 --- a/tilelang/profiler/__init__.py +++ b/tilelang/profiler/__init__.py @@ -112,7 +112,7 @@ def do_bench( n_repeat: int = 1, profiler: Literal["torch", "tvm", "auto"] = "auto", input_tensors: List[torch.Tensor] = None, - ): + ) -> float: profiler = self.determine_profiler(func, profiler) if profiler == "torch": ins = self._get_inputs() if input_tensors is None else input_tensors @@ -156,7 +156,7 @@ def do_bench( quantiles=None, fast_flush=True, return_mode="mean", -): +) -> float: """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -173,6 +173,10 @@ def do_bench( :type quantiles: list[float] :param fast_flush: Use faster kernel to flush L2 between measurements :type fast_flush: bool + + Returns: + float: The median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. 
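+
+    Example (illustrative only; the fp16 matmul below is a stand-in workload,
+    not part of this patch)::
+
+        >>> import torch
+        >>> a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
+        >>> latency = do_bench(lambda: a @ a, warmup=25, rep=100)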
""" assert return_mode in ["min", "max", "mean", "median"] fn() From 151072e9dc81683e2c9862feca35bae4490cc2df Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 24 Feb 2025 16:38:16 +0000 Subject: [PATCH 08/11] lint fix --- src/tl_templates/cuda/common.h | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/tl_templates/cuda/common.h b/src/tl_templates/cuda/common.h index 1fa4eb2e3..139c0d200 100644 --- a/src/tl_templates/cuda/common.h +++ b/src/tl_templates/cuda/common.h @@ -114,7 +114,8 @@ TL_DEVICE void AtomicAdd(half_t *address, float val) { // AtomicAdd Functions for BFLOAT16 TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t *val) { - atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), static_cast<__nv_bfloat16>(*val)); + atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), + static_cast<__nv_bfloat16>(*val)); } TL_DEVICE void AtomicAdd(bfloat16_t *address, float val) { @@ -122,13 +123,15 @@ TL_DEVICE void AtomicAdd(bfloat16_t *address, float val) { } TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t val) { - atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), static_cast<__nv_bfloat16>(val)); + atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), + static_cast<__nv_bfloat16>(val)); } // AtomicAdd Functions for BFLOAT16x2 TL_DEVICE void AtomicAddx2(bfloat16_t *address, bfloat16_t *val) { - atomicAdd(reinterpret_cast<__nv_bfloat162 *>(address), - static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val))); + atomicAdd( + reinterpret_cast<__nv_bfloat162 *>(address), + static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val))); } // DP4A From ebb5a29efaa85b4b49b6bbd63e80ab21642f32fa Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 25 Feb 2025 05:07:45 +0000 Subject: [PATCH 09/11] Add example for Group Query Attention (GQA) forward pass using Flash Attention in TileLang This commit introduces a new example script `example_gqa_fwd_bshd.py` that demonstrates: - Group Query Attention (GQA) implementation - Flash Attention forward pass - Performance benchmarking - Configurable parameters for batch, heads, sequence length, and dimension - Autotuning support - Reference implementation comparison --- .../flash_attention/example_gqa_fwd_bshd.py | 241 ++++++++++++++++++ 1 file changed, 241 insertions(+) create mode 100644 examples/flash_attention/example_gqa_fwd_bshd.py diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py new file mode 100644 index 000000000..845cee648 --- /dev/null +++ b/examples/flash_attention/example_gqa_fwd_bshd.py @@ -0,0 +1,241 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import torch +import torch.nn.functional as F +import tilelang +from tilelang import Profiler +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + block_M = [128] + block_N = [128] + num_stages = [2] + threads = [256] + _configs = list(itertools.product(block_M, block_N, num_stages, threads)) + + configs = [{ + 'block_M': c[0], + 'block_N': c[1], + 'num_stages': c[2], + 'threads': c[3] + } for c in _configs] + return configs + + +def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = "float16" + accum_dtype = "float" + + def kernel_func(block_M, block_N, num_stages, threads): + + @T.macro + def MMA0( + K: T.Buffer(kv_shape, dtype), + Q_shared: T.Buffer([block_M, dim], dtype), + K_shared: T.Buffer([block_N, dim], dtype), + acc_s: T.Buffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, + -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Buffer(kv_shape, dtype), + V_shared: T.Buffer([block_M, dim], dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + acc_o: T.Buffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.Buffer([block_M, block_N], accum_dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + scores_max: T.Buffer([block_M], accum_dtype), + scores_max_prev: T.Buffer([block_M], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + scores_sum: T.Buffer([block_M], accum_dtype), + logsum: T.Buffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
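+                # Worked identity (illustrative note, not executed code): with
+                # m = scores_max[i] and scale = log2(e) / sqrt(dim) as defined above,
+                # exp((x - m) / sqrt(dim)) == exp2(x * scale - m * scale).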
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.Buffer([block_M, dim], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Buffer(q_shape, dtype), + K: T.Buffer(kv_shape, dtype), + V: T.Buffer(kv_shape, dtype), + Output: T.Buffer(q_shape, dtype), + ): + with T.Kernel( + T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv( + (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, + scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + + return main + + if tune: + + @autotune( + configs=get_configs(), + keys=["block_M", "block_N", "num_stages", "threads"], + warmup=10, + rep=10) + @jit( + out_idx=[3], + supply_type=tilelang.TensorSupplyType.Integer, + ref_prog=None, + profiler="auto") + def kernel(block_M=None, block_N=None, num_stages=None, threads=None): + return kernel_func(block_M, block_N, num_stages, threads) + + return kernel() + else: + + def kernel(block_M, block_N, num_stages, threads): + return kernel_func(block_M, block_N, num_stages, threads) + + return kernel + + +def ref_program(Q, K, V, is_causal, groups=1): + # Q: [B, T, HQ, D] + # K: [B, T, HK, D] + # V: [B, T, HV, D] + # HQ = HKV * groups + assert Q.size(2) == K.size( + 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size( + 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + + dim = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float('-inf')) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum('bhqk,bkhd->bqhd', 
attention_weights, V) + return output + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=8, help='batch size') + parser.add_argument('--heads', type=int, default=32, help='heads') + parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') + parser.add_argument('--dim', type=int, default=128, help='dim') + parser.add_argument('--is_causal', action='store_true', help='causal') + parser.add_argument('--tune', action='store_true', help='tune configs') + parser.add_argument('--groups', type=int, default=8, help='groups') + args = parser.parse_args() + batch, heads, seq_len, dim, is_causal, groups = args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if (not args.tune): + program = flashattn( + batch, heads, seq_len, dim, is_causal, tune=args.tune, groups=groups)( + block_M=128, block_N=128, num_stages=1, threads=128) + ref_program = partial(ref_program, is_causal=is_causal, groups=groups) + mod, params = tilelang.lower(program) + mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) + mod.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = mod.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = mod.do_bench(mod.func, warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + best_latency, best_config, _ = flashattn( + batch, heads, seq_len, dim, is_causal, tune=args.tune) + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") From 280cef19b5fdac4b9976a91bf7609d170d86bd67 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 25 Feb 2025 06:19:36 +0000 Subject: [PATCH 10/11] Refactor IR lowering pipeline into modular phases This commit introduces a new module `phase.py` to modularize the IR lowering process by splitting the complex lowering pipeline into two distinct phases: - `LowerAndLegalize`: Handles initial IR legalization and transformation - `OptimizeForTarget`: Applies target-specific optimizations The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability. 
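As a quick orientation (an illustrative sketch, not part of the patch: the wrapper
name `two_phase_lower` is made up, while the imports and the two function
signatures come from the phase.py added below), the refactored pipeline composes as:

    from tvm.target import Target
    from tilelang.engine.phase import LowerAndLegalize, OptimizeForTarget

    def two_phase_lower(mod, target: Target):
        # Phase 1: bind the target and legalize the frontend IR
        # (layout inference, tile-op lowering, safe-memory checks).
        mod = LowerAndLegalize(mod, target)
        # Phase 2: target-specific optimization (software pipelining,
        # shared-memory sync, host/device split, kernel launch lowering).
        mod = OptimizeForTarget(mod, target)
        return mod

Callers such as lower() and get_annotated_device_mod() then filter the result into
host and device modules, as shown in the diffs below.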
--- tilelang/engine/lower.py | 73 ++++-------------------- tilelang/engine/phase.py | 84 ++++++++++++++++++++++++++++ tilelang/jit/adapter/ctypes/utils.py | 59 ++----------------- tilelang/jit/adapter/cython/utils.py | 57 ++----------------- 4 files changed, 106 insertions(+), 167 deletions(-) create mode 100644 tilelang/engine/phase.py diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index 964262cfd..f33c49e9e 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -12,7 +12,10 @@ from tvm.target import Target from tilelang.contrib import hipcc, nvcc from tilelang.utils.target import determine_target - +from tilelang.engine.phase import ( + LowerAndLegalize, + OptimizeForTarget, +) def is_cpu_device_backend(target: Target): return target.kind.name == "c" @@ -152,68 +155,12 @@ def lower( _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target)) _is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target)) - mod = tir.transform.BindTarget(target)(mod) - - mod = tl.transform.FrontendLegalize()(mod) - mod = tir.transform.Simplify()(mod) - mod = tl.transform.LayoutInference()(mod) - mod = tl.transform.LowerTileOp()(mod) - mod = tl.transform.LegalizeVectorizedLoop()(mod) - mod = tl.transform.LegalizeSafeMemoryAccess()(mod) - # Inject Simplify to remove the duplicated conditions - mod = tir.transform.Simplify()(mod) - - # which may be introduced by the LegalizeSafeMemoryAccess - if target.arch == "sm_90": - mod = tl.transform.MultiVersionBuffer()(mod) - mod = tl.transform.WarpSpecialized()(mod) - mod = tl.transform.InjectSoftwarePipeline()(mod) - mod = tir.transform.LowerOpaqueBlock()(mod) - # mod = tl.transform.WarpSpecializedPipeline()(mod) - mod = tl.transform.InjectFenceProxy()(mod) - else: - mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) - mod = tl.transform.PipelinePlanning()(mod) - mod = tl.transform.InjectSoftwarePipeline()(mod) - - mod = tir.transform.LowerOpaqueBlock()(mod) - mod = tir.transform.FlattenBuffer()(mod) - mod = tir.transform.NarrowDataType(32)(mod) - mod = tir.transform.Simplify()(mod) - mod = tl.transform.VectorizeLoop()(mod) - mod = tir.transform.StorageRewrite()(mod) - mod = tir.transform.UnrollLoop()(mod) - mod = tir.transform.RenormalizeSplitPattern()(mod) - mod = tir.transform.Simplify()(mod) - mod = tir.transform.RemoveNoOp()(mod) - mod = tir.transform.RewriteUnsafeSelect()(mod) - mod = tir.transform.HoistIfThenElse()(mod) - - mod = tir.transform.VerifyMemory()(mod) - mod = tir.transform.AnnotateEntryFunc()(mod) - # TODO(lei): This is a hack to make sure the - # thread level allreduce pass can be applied - # in TL. As Tl only use one thread dimension - # the var binding information will be lost - # in the lowering process with Legalization - # and Simplify pass. - # We can find a way better to create var instead - # of putting the LowerThreadAllreduce before - # the Legalization. 
- mod = tl.transform.ThreadPartialSync("shared.dyn")(mod) - mod = tir.transform.InferFragment()(mod) - mod = tir.transform.LowerThreadAllreduce()(mod) - mod = tl.transform.LowerHopperIntrin()(mod) - mod = tl.transform.ThreadSync("shared")(mod) - mod = tl.transform.ThreadSync("shared.dyn")(mod) - mod = tir.transform.InjectPTXAsyncCopy()(mod) - - mod = tl.transform.AnnotateDeviceRegions()(mod) - mod = tir.transform.SplitHostDevice()(mod) - mod = tir.transform.MergeSharedMemoryAllocations()(mod) - - mod = tl.transform.MakePackedAPI()(mod) - mod = tir.transform.LowerDeviceKernelLaunch()(mod) + # Phase 1: Lower and legalize the IR + mod = LowerAndLegalize(mod, target) + + # Phase 2: Optimize the IR for the target + mod = OptimizeForTarget(mod, target) + host_mod = tir.transform.Filter(_is_host_call)(mod) host_mod = tir.transform.BindTarget(target_host)(host_mod) host_mod = tir.transform.FP8StorageLegalize()(host_mod) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py new file mode 100644 index 000000000..3a94b5680 --- /dev/null +++ b/tilelang/engine/phase.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from tvm import tir, IRModule +from tvm.target import Target +import tilelang as tl + +def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: + # Bind the target device information to the module + mod = tir.transform.BindTarget(target)(mod) + + # Legalize the frontend IR to make it compatible with TVM + mod = tl.transform.FrontendLegalize()(mod) + # Simplify the IR expressions + mod = tir.transform.Simplify()(mod) + # Infer memory layouts for fragments and shared memory + mod = tl.transform.LayoutInference()(mod) + # Lower high-level tile operations to low-level operations + mod = tl.transform.LowerTileOp()(mod) + # Legalize vectorized loops to ensure they are valid + mod = tl.transform.LegalizeVectorizedLoop()(mod) + # Add safety checks for memory accesses + mod = tl.transform.LegalizeSafeMemoryAccess()(mod) + # Simplify again to clean up any duplicated conditions + # that may have been introduced by safety checks + mod = tir.transform.Simplify()(mod) + + return mod + + +def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: + # which may be introduced by the LegalizeSafeMemoryAccess + if target.arch == "sm_90": + mod = tl.transform.MultiVersionBuffer()(mod) + mod = tl.transform.WarpSpecialized()(mod) + mod = tl.transform.InjectSoftwarePipeline()(mod) + mod = tir.transform.LowerOpaqueBlock()(mod) + # mod = tl.transform.WarpSpecializedPipeline()(mod) + mod = tl.transform.InjectFenceProxy()(mod) + else: + mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) + mod = tl.transform.PipelinePlanning()(mod) + mod = tl.transform.InjectSoftwarePipeline()(mod) + + mod = tir.transform.LowerOpaqueBlock()(mod) + mod = tir.transform.FlattenBuffer()(mod) + mod = tir.transform.NarrowDataType(32)(mod) + mod = tir.transform.Simplify()(mod) + mod = tl.transform.VectorizeLoop()(mod) + mod = tir.transform.StorageRewrite()(mod) + mod = tir.transform.UnrollLoop()(mod) + mod = tir.transform.RenormalizeSplitPattern()(mod) + mod = tir.transform.Simplify()(mod) + mod = tir.transform.RemoveNoOp()(mod) + mod = tir.transform.RewriteUnsafeSelect()(mod) + mod = tir.transform.HoistIfThenElse()(mod) + + mod = tir.transform.VerifyMemory()(mod) + mod = tir.transform.AnnotateEntryFunc()(mod) + # TODO(lei): This is a hack to make sure the + # thread level allreduce pass can be applied + # in TL. 
As Tl only use one thread dimension + # the var binding information will be lost + # in the lowering process with Legalization + # and Simplify pass. + # We can find a way better to create var instead + # of putting the LowerThreadAllreduce before + # the Legalization. + mod = tl.transform.ThreadPartialSync("shared.dyn")(mod) + mod = tir.transform.InferFragment()(mod) + mod = tir.transform.LowerThreadAllreduce()(mod) + mod = tl.transform.LowerHopperIntrin()(mod) + mod = tl.transform.ThreadSync("shared")(mod) + mod = tl.transform.ThreadSync("shared.dyn")(mod) + mod = tir.transform.InjectPTXAsyncCopy()(mod) + + mod = tl.transform.AnnotateDeviceRegions()(mod) + mod = tir.transform.SplitHostDevice()(mod) + mod = tir.transform.MergeSharedMemoryAllocations()(mod) + + mod = tl.transform.MakePackedAPI()(mod) + mod = tir.transform.LowerDeviceKernelLaunch()(mod) + + return mod diff --git a/tilelang/jit/adapter/ctypes/utils.py b/tilelang/jit/adapter/ctypes/utils.py index 287d63c72..f2be9527a 100644 --- a/tilelang/jit/adapter/ctypes/utils.py +++ b/tilelang/jit/adapter/ctypes/utils.py @@ -5,12 +5,15 @@ from tilelang import tvm as tvm from tvm import IRModule, tir from tvm.target import Target -import tilelang.transform from tilelang.engine.lower import ( is_device_call, determine_target, canon_target_host, ) +from tilelang.engine.phase import ( + LowerAndLegalize, + OptimizeForTarget, +) def match_global_kernel(source: str) -> int: @@ -47,58 +50,8 @@ def get_annotated_device_mod( target_host = tvm.target.Target.canon_target(target_host) target = tvm.target.Target(target, target_host) - mod = tir.transform.BindTarget(target)(mod) - - mod = tilelang.transform.FrontendLegalize()(mod) - mod = tir.transform.Simplify()(mod) - mod = tilelang.transform.LayoutInference()(mod) - mod = tilelang.transform.LowerTileOp()(mod) - mod = tir.transform.Simplify()(mod) - - if target.arch == "sm_90": - mod = tilelang.transform.WarpSpecializedPipeline()(mod) - else: - mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) - mod = tilelang.transform.PipelinePlanning()(mod) - mod = tilelang.transform.InjectSoftwarePipeline()(mod) - - mod = tir.transform.LowerOpaqueBlock()(mod) - mod = tir.transform.FlattenBuffer()(mod) - mod = tir.transform.NarrowDataType(32)(mod) - mod = tir.transform.Simplify()(mod) - - mod = tir.transform.VectorizeLoop()(mod) - mod = tir.transform.StorageRewrite()(mod) - mod = tir.transform.UnrollLoop()(mod) - mod = tir.transform.RenormalizeSplitPattern()(mod) - mod = tir.transform.Simplify()(mod) - mod = tir.transform.RemoveNoOp()(mod) - mod = tir.transform.RewriteUnsafeSelect()(mod) - mod = tir.transform.HoistIfThenElse()(mod) - - mod = tir.transform.VerifyMemory()(mod) - mod = tir.transform.AnnotateEntryFunc()(mod) - mod = tir.transform.ThreadSync("shared")(mod) - # TODO(lei): This is a hack to make sure the - # thread level allreduce pass can be applied - # in TL. As Tl only use one thread dimension - # the var binding information will be lost - # in the lowering process with Legalization - # and Simplify pass. - # We can find a way better to create var instead - # of putting the LowerThreadAllreduce before - # the Legalization. 
- mod = tir.transform.LowerThreadAllreduce()(mod) - mod = tir.transform.ThreadSync("shared.dyn")(mod) - mod = tilelang.transform.LowerHopperIntrin()(mod) - mod = tir.transform.InjectPTXAsyncCopy()(mod) - - mod = tir.transform.AnnotateDeviceRegions()(mod) - mod = tir.transform.SplitHostDevice()(mod) - mod = tir.transform.MergeSharedMemoryAllocations()(mod) - mod = tir.transform.MakePackedAPI()(mod) - mod = tir.transform.LowerDeviceKernelLaunch()(mod) - + mod = LowerAndLegalize(mod, target) + mod = OptimizeForTarget(mod, target) device_mod = tir.transform.Filter(is_device_call)(mod) return device_mod diff --git a/tilelang/jit/adapter/cython/utils.py b/tilelang/jit/adapter/cython/utils.py index 287d63c72..c9a3384e5 100644 --- a/tilelang/jit/adapter/cython/utils.py +++ b/tilelang/jit/adapter/cython/utils.py @@ -11,6 +11,10 @@ determine_target, canon_target_host, ) +from tilelang.engine.phase import ( + LowerAndLegalize, + OptimizeForTarget, +) def match_global_kernel(source: str) -> int: @@ -47,57 +51,8 @@ def get_annotated_device_mod( target_host = tvm.target.Target.canon_target(target_host) target = tvm.target.Target(target, target_host) - mod = tir.transform.BindTarget(target)(mod) - - mod = tilelang.transform.FrontendLegalize()(mod) - mod = tir.transform.Simplify()(mod) - mod = tilelang.transform.LayoutInference()(mod) - mod = tilelang.transform.LowerTileOp()(mod) - mod = tir.transform.Simplify()(mod) - - if target.arch == "sm_90": - mod = tilelang.transform.WarpSpecializedPipeline()(mod) - else: - mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) - mod = tilelang.transform.PipelinePlanning()(mod) - mod = tilelang.transform.InjectSoftwarePipeline()(mod) - - mod = tir.transform.LowerOpaqueBlock()(mod) - mod = tir.transform.FlattenBuffer()(mod) - mod = tir.transform.NarrowDataType(32)(mod) - mod = tir.transform.Simplify()(mod) - - mod = tir.transform.VectorizeLoop()(mod) - mod = tir.transform.StorageRewrite()(mod) - mod = tir.transform.UnrollLoop()(mod) - mod = tir.transform.RenormalizeSplitPattern()(mod) - mod = tir.transform.Simplify()(mod) - mod = tir.transform.RemoveNoOp()(mod) - mod = tir.transform.RewriteUnsafeSelect()(mod) - mod = tir.transform.HoistIfThenElse()(mod) - - mod = tir.transform.VerifyMemory()(mod) - mod = tir.transform.AnnotateEntryFunc()(mod) - mod = tir.transform.ThreadSync("shared")(mod) - # TODO(lei): This is a hack to make sure the - # thread level allreduce pass can be applied - # in TL. As Tl only use one thread dimension - # the var binding information will be lost - # in the lowering process with Legalization - # and Simplify pass. - # We can find a way better to create var instead - # of putting the LowerThreadAllreduce before - # the Legalization. 
- mod = tir.transform.LowerThreadAllreduce()(mod) - mod = tir.transform.ThreadSync("shared.dyn")(mod) - mod = tilelang.transform.LowerHopperIntrin()(mod) - mod = tir.transform.InjectPTXAsyncCopy()(mod) - - mod = tir.transform.AnnotateDeviceRegions()(mod) - mod = tir.transform.SplitHostDevice()(mod) - mod = tir.transform.MergeSharedMemoryAllocations()(mod) - mod = tir.transform.MakePackedAPI()(mod) - mod = tir.transform.LowerDeviceKernelLaunch()(mod) + mod = LowerAndLegalize(mod, target) + mod = OptimizeForTarget(mod, target) device_mod = tir.transform.Filter(is_device_call)(mod) From 633a51ad564cf9f630d15f786dee7f1ac77504cc Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 25 Feb 2025 06:19:49 +0000 Subject: [PATCH 11/11] lintfix --- tilelang/engine/lower.py | 2 +- tilelang/engine/phase.py | 1 + tilelang/jit/adapter/cython/utils.py | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index f33c49e9e..d9791f869 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. """The compiler for TL programs.""" -import tilelang as tl import os import os.path as osp from typing import Union, Optional, Callable @@ -17,6 +16,7 @@ OptimizeForTarget, ) + def is_cpu_device_backend(target: Target): return target.kind.name == "c" diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 3a94b5680..2ac15215c 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -5,6 +5,7 @@ from tvm.target import Target import tilelang as tl + def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: # Bind the target device information to the module mod = tir.transform.BindTarget(target)(mod) diff --git a/tilelang/jit/adapter/cython/utils.py b/tilelang/jit/adapter/cython/utils.py index c9a3384e5..c03c231e3 100644 --- a/tilelang/jit/adapter/cython/utils.py +++ b/tilelang/jit/adapter/cython/utils.py @@ -5,7 +5,6 @@ from tilelang import tvm as tvm from tvm import IRModule, tir from tvm.target import Target -import tilelang.transform from tilelang.engine.lower import ( is_device_call, determine_target,