diff --git a/examples/analyze/example_conv_analyze.py b/examples/analyze/example_conv_analyze.py index 8b7c323eb..414183a70 100644 --- a/examples/analyze/example_conv_analyze.py +++ b/examples/analyze/example_conv_analyze.py @@ -49,7 +49,7 @@ def kernel(N, is_hopper = check_hopper() @T.prim_func - def main( + def conv( data: T.Tensor((N, H, W, C), dtype), kernel: T.Tensor((KH, KW, C, F), dtype), out: T.Tensor((N, OH, OW, F), dtype), @@ -91,11 +91,16 @@ def main( T.copy(out_local, out_shared) T.copy(out_shared, out_flat[by * block_M, bx * block_N]) - return main + return conv -my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256) -cuda_device = CUDA("cuda") -result = Analyzer.analysis(my_func, cuda_device) -print(result) -print(f"Analyzed FLOPs: {result.total_flops}") +def main(): + my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256) + cuda_device = CUDA("cuda") + result = Analyzer.analysis(my_func, cuda_device) + print(result) + print(f"Analyzed FLOPs: {result.total_flops}") + + +if __name__ == "__main__": + main() diff --git a/examples/analyze/example_gemm_analyze.py b/examples/analyze/example_gemm_analyze.py index 7ea272ff7..772913db2 100644 --- a/examples/analyze/example_gemm_analyze.py +++ b/examples/analyze/example_gemm_analyze.py @@ -19,7 +19,7 @@ def kernel( accum_dtype = "float" @T.prim_func - def main( + def matmul( A: T.Tensor((M, K), dtype), B: T.Tensor((N, K), dtype), C: T.Tensor((M, N), dtype), @@ -43,13 +43,18 @@ def main( T.copy(C_local, C_shared) T.copy(C_shared, C[by * block_M, bx * block_N]) - return main + return matmul -my_func = kernel(128, 128, 32, 3, 128, True) +def main(): + my_func = kernel(128, 128, 32, 3, 128, True) -cuda_device = CUDA("cuda") -result = Analyzer.analysis(my_func, cuda_device) + cuda_device = CUDA("cuda") + result = Analyzer.analysis(my_func, cuda_device) -print(f"Analyzed FLOPs: {result.total_flops}") -print(f"Expected FLOPs: {2 * M * N * K}") + print(f"Analyzed FLOPs: {result.total_flops}") + print(f"Expected FLOPs: {2 * M * N * K}") + + +if __name__ == "__main__": + main() diff --git a/examples/analyze/test_example_analyze.py b/examples/analyze/test_example_analyze.py new file mode 100644 index 000000000..67af27056 --- /dev/null +++ b/examples/analyze/test_example_analyze.py @@ -0,0 +1,17 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. 
+import tilelang.testing +import example_gemm_analyze +import example_conv_analyze + + +def test_example_gemm_analyze(): + example_gemm_analyze.main() + + +def test_example_conv_analyze(): + example_conv_analyze.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/blocksparse_attention/block_sparse_attn_triton.py b/examples/blocksparse_attention/block_sparse_attn_triton.py index 70f20af26..0c41c99e4 100644 --- a/examples/blocksparse_attention/block_sparse_attn_triton.py +++ b/examples/blocksparse_attention/block_sparse_attn_triton.py @@ -379,6 +379,10 @@ def test_topk_sparse_attention_qlt_kl(): print("Pass topk sparse attention test with qlen < klen") -if __name__ == "__main__": +def main(): test_topk_sparse_attention() test_topk_sparse_attention_qlt_kl() + + +if __name__ == "__main__": + main() diff --git a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py index 48512e68d..169e02d07 100644 --- a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py +++ b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py @@ -118,7 +118,7 @@ def Rescale( acc_o[i, j] *= scores_scale[i] @T.prim_func - def main( + def blocksparse_flashattn( Q: T.Tensor(shape, dtype), K: T.Tensor(shape, dtype), V: T.Tensor(shape, dtype), @@ -165,7 +165,7 @@ def main( T.copy(acc_o, O_shared) T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) - return main + return blocksparse_flashattn return kernel_func(block_M, block_N, num_stages, threads) @@ -219,5 +219,9 @@ def test_topk_sparse_attention(): print("Pass topk sparse attention test with qlen == klen") -if __name__ == "__main__": +def main(): test_topk_sparse_attention() + + +if __name__ == "__main__": + main() diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py index a8623ca60..03854b338 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -223,6 +223,7 @@ def forward(self, query, key, value, block_indices, cache_seqlens): heads = self.heads heads_kv = self.heads_kv dim_v = self.dim_v + dim = self.dim block_size = self.block_size max_selected_blocks = block_indices.shape[-1] @@ -397,30 +398,20 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') - args = parser.parse_args() - - batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v - sparse_ratio = args.sparse_ratio - block_size = args.block_size 
- qk_flops = 2 * batch * heads * max_cache_seqlen * dim - pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v - total_flops = qk_flops + pv_flops - +def main(batch=8, + heads=32, + heads_kv=8, + max_cache_seqlen=8192, + dim=128, + dim_v=128, + sparse_ratio=0.8, + block_size=32): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) print("max_selected_blocks: ", max_selected_blocks) dtype = torch.float16 - block_H = 64 Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') @@ -494,3 +485,19 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) torch.cuda.synchronize() print("sparse time: ", (time.time() - start) / 100 * 1000) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=8, help='batch size') + parser.add_argument('--heads', type=int, default=32, help='heads') + parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') + parser.add_argument( + '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') + parser.add_argument('--dim', type=int, default=128, help='dim') + parser.add_argument('--dim_v', type=int, default=128, help='dim_v') + parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') + parser.add_argument('--block_size', type=int, default=32, help='block_size') + args = parser.parse_args() + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, + args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index e430cbf2c..0c43889f8 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -209,6 +209,7 @@ def forward(self, query, key, value, block_mask, cache_seqlens): heads = self.heads heads_kv = self.heads_kv dim_v = self.dim_v + dim = self.dim block_size = self.block_size block_H = self.block_H max_cache_seqlen = key.shape[1] @@ -370,30 +371,20 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}") -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=8, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') - args = parser.parse_args() - - batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v - sparse_ratio = args.sparse_ratio - block_size = 
args.block_size - qk_flops = 2 * batch * heads * max_cache_seqlen * dim - pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v - total_flops = qk_flops + pv_flops - +def main(batch=8, + heads=32, + heads_kv=8, + max_cache_seqlen=8192, + dim=128, + dim_v=128, + sparse_ratio=0.8, + block_size=32): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) print("max_selected_blocks: ", max_selected_blocks) dtype = torch.float16 - block_H = 64 Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') @@ -457,3 +448,19 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): out = model(Q, K, V, block_mask, cache_seqlens) torch.cuda.synchronize() print("sparse time: ", (time.time() - start) / 100 * 1000) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=8, help='batch size') + parser.add_argument('--heads', type=int, default=32, help='heads') + parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') + parser.add_argument( + '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') + parser.add_argument('--dim', type=int, default=128, help='dim') + parser.add_argument('--dim_v', type=int, default=128, help='dim_v') + parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') + parser.add_argument('--block_size', type=int, default=32, help='block_size') + args = parser.parse_args() + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, + args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py index 887567e4f..316db6b2c 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py @@ -350,22 +350,18 @@ def ref_program_fa(query, key, value, cache_seqlens): return output -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') - args = parser.parse_args() - - batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v - sparse_ratio = args.sparse_ratio - block_size = args.block_size +def main(batch=64, + heads=32, + heads_kv=8, + max_cache_seqlen=8192, + dim=128, + dim_v=128, + sparse_ratio=0.8, + block_size=32): + + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size 
qk_flops = 2 * batch * heads * max_cache_seqlen * dim pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v total_flops = qk_flops + pv_flops @@ -466,3 +462,19 @@ def ref_program_fa(query, key, value, cache_seqlens): print(f"Average time of ref: {avg_time_ref:.6f} seconds") print(f"Speedup: {avg_time_ref / avg_time:.2f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=64, help='batch size') + parser.add_argument('--heads', type=int, default=32, help='heads') + parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') + parser.add_argument( + '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') + parser.add_argument('--dim', type=int, default=128, help='dim') + parser.add_argument('--dim_v', type=int, default=128, help='dim_v') + parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') + parser.add_argument('--block_size', type=int, default=32, help='block_size') + args = parser.parse_args() + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, + args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py index d5011532a..15288398f 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py @@ -348,28 +348,23 @@ def ref_program_fa(query, key, value, cache_seqlens): return output -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=64, help='batch size') - parser.add_argument('--heads', type=int, default=32, help='heads') - parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') - parser.add_argument( - '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') - parser.add_argument('--dim', type=int, default=128, help='dim') - parser.add_argument('--dim_v', type=int, default=128, help='dim_v') - parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') - parser.add_argument('--block_size', type=int, default=32, help='block_size') - args = parser.parse_args() - - batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v - block_size = args.block_size - sparse_ratio = args.sparse_ratio +def main(batch=64, + heads=32, + heads_kv=8, + max_cache_seqlen=8192, + dim=128, + dim_v=128, + sparse_ratio=0.8, + block_size=32): + + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + block_size = block_size + sparse_ratio = sparse_ratio qk_flops = 2 * batch * heads * max_cache_seqlen * dim pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v total_flops = qk_flops + pv_flops dtype = torch.float16 - block_H = 64 Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') @@ -435,6 +430,7 @@ def ref_program_fa(query, key, value, cache_seqlens): avg_time = elapsed_time / 1000 avg_flops = total_flops / avg_time print(f"Average time: {avg_time:.6f} seconds") + print(f"Average flops: {avg_flops:.2f} GFLOPS") # Measure performance of reference implementation start = time.time() @@ -446,5 +442,22 @@ def ref_program_fa(query, 
key, value, cache_seqlens): avg_time_ref = elapsed_time_ref / 1000 avg_flops_ref = total_flops / avg_time_ref print(f"Average time of ref: {avg_time_ref:.6f} seconds") + print(f"Average flops of ref: {avg_flops_ref:.2f} GFLOPS") print(f"Speedup: {avg_time_ref / avg_time:.2f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=64, help='batch size') + parser.add_argument('--heads', type=int, default=32, help='heads') + parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv') + parser.add_argument( + '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length') + parser.add_argument('--dim', type=int, default=128, help='dim') + parser.add_argument('--dim_v', type=int, default=128, help='dim_v') + parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio') + parser.add_argument('--block_size', type=int, default=32, help='block_size') + args = parser.parse_args() + main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v, + args.sparse_ratio, args.block_size) diff --git a/examples/blocksparse_attention/test_example_blocksparse_attention.py b/examples/blocksparse_attention/test_example_blocksparse_attention.py new file mode 100644 index 000000000..7ba33ed2e --- /dev/null +++ b/examples/blocksparse_attention/test_example_blocksparse_attention.py @@ -0,0 +1,37 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. +import tilelang.testing +import block_sparse_attn_triton +import example_tilelang_block_sparse_attn +import example_tilelang_sparse_gqa_decode_varlen_indice +import example_tilelang_sparse_gqa_decode_varlen_mask +import example_triton_sparse_gqa_decode_varlen_indice +import example_triton_sparse_gqa_decode_varlen_mask + + +def test_block_sparse_attn_triton(): + block_sparse_attn_triton.main() + + +def test_example_tilelang_block_sparse_attn(): + example_tilelang_block_sparse_attn.main() + + +def test_example_tilelang_sparse_gqa_decode_varlen_indice(): + example_tilelang_sparse_gqa_decode_varlen_indice.main() + + +def test_example_tilelang_sparse_gqa_decode_varlen_mask(): + example_tilelang_sparse_gqa_decode_varlen_mask.main() + + +def test_example_triton_sparse_gqa_decode_varlen_indice(): + example_triton_sparse_gqa_decode_varlen_indice.main() + + +def test_example_triton_sparse_gqa_decode_varlen_mask(): + example_triton_sparse_gqa_decode_varlen_mask.main() + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/convolution/example_convolution.py b/examples/convolution/example_convolution.py index 87612289f..b4d286eda 100644 --- a/examples/convolution/example_convolution.py +++ b/examples/convolution/example_convolution.py @@ -274,4 +274,4 @@ def main(argv=None): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index 76f6eca57..f06cbb604 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -24,4 +24,9 @@ thefuzz tabulate wheel setuptools -einops \ No newline at end of file +einops +attrs +decorator +flash-attn +scipy +tornado \ No newline at end of file
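The recurring refactor in these diffs is the same for every example script: module-level driver code moves into a main() function behind an "if __name__ == '__main__'" guard, argparse parsing stays inside the guard, and a sibling test_example_*.py imports the module and calls main() so the example doubles as a CI test. A minimal sketch of that pattern follows; example_foo.py, its --batch/--dim flags, and the kernel stub are hypothetical, while tilelang.testing.main() is the entry point the new test files above actually use.

# example_foo.py -- hypothetical example script, restructured the same way as the scripts above.
import argparse


def main(batch=8, dim=128):
    # Build inputs and launch the kernel here. Importing this module no longer
    # runs anything, so a test harness can call main() with the defaults.
    print(f"running hypothetical kernel with batch={batch}, dim={dim}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="batch size")
    parser.add_argument("--dim", type=int, default=128, help="head dim")
    args = parser.parse_args()
    main(args.batch, args.dim)

The companion test module then only needs to forward to main(), exactly as test_example_analyze.py and test_example_blocksparse_attention.py do:

# test_example_foo.py -- hypothetical companion test module.
import tilelang.testing

import example_foo


def test_example_foo():
    example_foo.main()


if __name__ == "__main__":
    tilelang.testing.main()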
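The FLOP numbers these scripts report can be checked with a few lines of standalone arithmetic. The formulas below are the ones the diffs use (2 * M * N * K for the GEMM analyze check, qk_flops + pv_flops for the decode benchmarks); the concrete sizes are either illustrative or taken from the argparse defaults of the tilelang decode scripts, and the GFLOP/s conversion at the end is the generic throughput formula rather than code from this patch.

# Dense-equivalent FLOP counts, using the same formulas as the examples above.

# GEMM: C[M, N] = A[M, K] @ B[K, N] costs one multiply and one add per (m, n, k) triple.
M, N, K = 1024, 1024, 512            # illustrative sizes, not the ones hard-coded in the example
gemm_flops = 2 * M * N * K           # 1_073_741_824

# One GQA decode step (tilelang decode defaults: batch=8, heads=32, max_cache_seqlen=8192, dim=dim_v=128).
batch, heads, max_cache_seqlen, dim, dim_v = 8, 32, 8192, 128, 128
qk_flops = 2 * batch * heads * max_cache_seqlen * dim      # Q @ K^T
pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v    # softmax(Q @ K^T) @ V
total_flops = qk_flops + pv_flops                          # also 1_073_741_824 for these sizes

# Turning a measured wall-clock time into throughput: divide by 1e9 to report GFLOP/s.
elapsed_seconds = 1e-3               # hypothetical measurement
gflops_per_s = total_flops / elapsed_seconds / 1e9
print(gemm_flops, total_flops, f"{gflops_per_s:.2f} GFLOP/s")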
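The max_selected_blocks value the decode examples print is likewise one line of arithmetic: keep roughly a (1 - sparse_ratio) fraction of the KV cache, rounded up to whole blocks. Restated here with the default --sparse_ratio 0.8 and --block_size 32:

import math

max_cache_seqlen, sparse_ratio, block_size = 8192, 0.8, 32
max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
print(max_selected_blocks)  # 52 blocks, i.e. up to 52 * 32 = 1664 selected cache positions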