
Commit 912ead3

yyttt6 and LeiWang1999 authored

[CI] Add Analyzer and blocksparse_attention examples to CI (#472)

* yes
* [Bugfix] fix the unexpected keyword error of autotune
* format
* test
* [CI] Add Analyzer and blocksparse_attention examples to CI
* format
* try
* try
* try
* try
* t
* format
* d

---------

Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>

1 parent 7bea218 commit 912ead3

12 files changed: +211 −95 lines
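Every example touched below follows the same refactor so that CI can import and run it: the inner @T.prim_func gets a descriptive name instead of main, the top-level driver code moves into a module-level main() with keyword defaults, and command-line parsing (where present) stays under the if __name__ == "__main__": guard. A minimal sketch of the resulting module shape, with placeholder names (kernel, example_kernel, --m/--n) that are illustrative rather than taken from the diff:

import argparse


def kernel(M, N):
    # Factory that builds the function to run; in the real examples this wraps a
    # @T.prim_func whose name now describes the op (conv, matmul, ...).
    def example_kernel():
        return M * N  # placeholder body

    return example_kernel


def main(M=128, N=128):
    # All driver logic lives in main() so CI can `import module; module.main()`.
    func = kernel(M, N)
    print("result:", func())


if __name__ == "__main__":
    # CLI parsing only happens when the file is executed directly, never on import.
    parser = argparse.ArgumentParser()
    parser.add_argument("--m", type=int, default=128)
    parser.add_argument("--n", type=int, default=128)
    args = parser.parse_args()
    main(args.m, args.n)

With this shape, a test can simply import the module and call main(), while running the file directly still accepts CLI flags.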

examples/analyze/example_conv_analyze.py

Lines changed: 12 additions & 7 deletions

@@ -49,7 +49,7 @@ def kernel(N,
     is_hopper = check_hopper()

     @T.prim_func
-    def main(
+    def conv(
             data: T.Tensor((N, H, W, C), dtype),
             kernel: T.Tensor((KH, KW, C, F), dtype),
             out: T.Tensor((N, OH, OW, F), dtype),
@@ -91,11 +91,16 @@ def main(
             T.copy(out_local, out_shared)
             T.copy(out_shared, out_flat[by * block_M, bx * block_N])

-    return main
+    return conv


-my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
-cuda_device = CUDA("cuda")
-result = Analyzer.analysis(my_func, cuda_device)
-print(result)
-print(f"Analyzed FLOPs: {result.total_flops}")
+def main():
+    my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
+    cuda_device = CUDA("cuda")
+    result = Analyzer.analysis(my_func, cuda_device)
+    print(result)
+    print(f"Analyzed FLOPs: {result.total_flops}")
+
+
+if __name__ == "__main__":
+    main()

examples/analyze/example_gemm_analyze.py

Lines changed: 12 additions & 7 deletions

@@ -19,7 +19,7 @@ def kernel(
     accum_dtype = "float"

     @T.prim_func
-    def main(
+    def matmul(
             A: T.Tensor((M, K), dtype),
             B: T.Tensor((N, K), dtype),
             C: T.Tensor((M, N), dtype),
@@ -43,13 +43,18 @@ def main(
             T.copy(C_local, C_shared)
             T.copy(C_shared, C[by * block_M, bx * block_N])

-    return main
+    return matmul


-my_func = kernel(128, 128, 32, 3, 128, True)
+def main():
+    my_func = kernel(128, 128, 32, 3, 128, True)

-cuda_device = CUDA("cuda")
-result = Analyzer.analysis(my_func, cuda_device)
+    cuda_device = CUDA("cuda")
+    result = Analyzer.analysis(my_func, cuda_device)

-print(f"Analyzed FLOPs: {result.total_flops}")
-print(f"Expected FLOPs: {2 * M * N * K}")
+    print(f"Analyzed FLOPs: {result.total_flops}")
+    print(f"Expected FLOPs: {2 * M * N * K}")
+
+
+if __name__ == "__main__":
+    main()
Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+# Copyright (c) Tile-AI Corporation.
+# Licensed under the MIT License.
+import tilelang.testing
+import example_gemm_analyze
+import example_conv_analyze
+
+
+def test_example_gemm_analyze():
+    example_gemm_analyze.main()
+
+
+def test_example_conv_analyze():
+    example_conv_analyze.main()
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
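The new test module above drives the two analyzer examples through their main() entry points and hands control to tilelang.testing.main() when executed directly (which presumably delegates to pytest for this file). The commit title also adds the blocksparse_attention examples to CI; the corresponding test module is not shown in this excerpt, but following the same pattern it would presumably look like the sketch below. The module names come from the files changed in this commit; the test function names, and the existence of such a file, are assumptions:

# Hypothetical companion test module for the blocksparse_attention examples
# (not shown in this excerpt); it mirrors the analyzer test file above.
import tilelang.testing

import example_tilelang_block_sparse_attn
import example_tilelang_sparse_gqa_decode_varlen_indice


def test_example_tilelang_block_sparse_attn():
    example_tilelang_block_sparse_attn.main()


def test_example_tilelang_sparse_gqa_decode_varlen_indice():
    # main() now has keyword defaults, so CI can run it without CLI arguments.
    example_tilelang_sparse_gqa_decode_varlen_indice.main()


if __name__ == "__main__":
    tilelang.testing.main()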

examples/blocksparse_attention/block_sparse_attn_triton.py

Lines changed: 5 additions & 1 deletion

@@ -379,6 +379,10 @@ def test_topk_sparse_attention_qlt_kl():
     print("Pass topk sparse attention test with qlen < klen")


-if __name__ == "__main__":
+def main():
     test_topk_sparse_attention()
     test_topk_sparse_attention_qlt_kl()
+
+
+if __name__ == "__main__":
+    main()

examples/blocksparse_attention/example_tilelang_block_sparse_attn.py

Lines changed: 7 additions & 3 deletions

@@ -118,7 +118,7 @@ def Rescale(
                acc_o[i, j] *= scores_scale[i]

        @T.prim_func
-        def main(
+        def blocksparse_flashattn(
                Q: T.Tensor(shape, dtype),
                K: T.Tensor(shape, dtype),
                V: T.Tensor(shape, dtype),
@@ -165,7 +165,7 @@ def main(
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

-        return main
+        return blocksparse_flashattn

    return kernel_func(block_M, block_N, num_stages, threads)

@@ -219,5 +219,9 @@ def test_topk_sparse_attention():
     print("Pass topk sparse attention test with qlen == klen")


-if __name__ == "__main__":
+def main():
     test_topk_sparse_attention()
+
+
+if __name__ == "__main__":
+    main()

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py

Lines changed: 28 additions & 21 deletions

@@ -223,6 +223,7 @@ def forward(self, query, key, value, block_indices, cache_seqlens):
         heads = self.heads
         heads_kv = self.heads_kv
         dim_v = self.dim_v
+        dim = self.dim
         block_size = self.block_size
         max_selected_blocks = block_indices.shape[-1]
@@ -397,30 +398,20 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
     print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
-    parser.add_argument(
-        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
-    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
-    parser.add_argument('--block_size', type=int, default=32, help='block_size')
-    args = parser.parse_args()
-
-    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v
-    sparse_ratio = args.sparse_ratio
-    block_size = args.block_size
-    qk_flops = 2 * batch * heads * max_cache_seqlen * dim
-    pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v
-    total_flops = qk_flops + pv_flops
-
+def main(batch=8,
+         heads=32,
+         heads_kv=8,
+         max_cache_seqlen=8192,
+         dim=128,
+         dim_v=128,
+         sparse_ratio=0.8,
+         block_size=32):
+    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
+    sparse_ratio = sparse_ratio
+    block_size = block_size
     max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
     print("max_selected_blocks: ", max_selected_blocks)
     dtype = torch.float16
-    block_H = 64

     Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
     K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
@@ -494,3 +485,19 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
     out = sparse_kernel(Q, K, V, block_indices, cache_seqlens)
     torch.cuda.synchronize()
     print("sparse time: ", (time.time() - start) / 100 * 1000)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch', type=int, default=8, help='batch size')
+    parser.add_argument('--heads', type=int, default=32, help='heads')
+    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
+    parser.add_argument(
+        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
+    parser.add_argument('--dim', type=int, default=128, help='dim')
+    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
+    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
+    parser.add_argument('--block_size', type=int, default=32, help='block_size')
+    args = parser.parse_args()
+    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
+         args.sparse_ratio, args.block_size)
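Because the driver is now main(batch=8, heads=32, ...) with keyword defaults, a CI test can import this module and run a reduced workload instead of the full default configuration. A hedged sketch of such a call; the test name and the smaller sizes are illustrative assumptions, not values from this commit:

import example_tilelang_sparse_gqa_decode_varlen_indice as sparse_decode_example


def test_sparse_gqa_decode_varlen_indice_small():
    # Smaller-than-default sizes keep the example fast enough for CI;
    # these particular values are illustrative, not taken from the commit.
    sparse_decode_example.main(batch=1, heads=8, heads_kv=2, max_cache_seqlen=1024, block_size=32)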

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py

Lines changed: 28 additions & 21 deletions

@@ -209,6 +209,7 @@ def forward(self, query, key, value, block_mask, cache_seqlens):
         heads = self.heads
         heads_kv = self.heads_kv
         dim_v = self.dim_v
+        dim = self.dim
         block_size = self.block_size
         block_H = self.block_H
         max_cache_seqlen = key.shape[1]
@@ -370,30 +371,20 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
     print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=8, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
-    parser.add_argument(
-        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
-    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
-    parser.add_argument('--block_size', type=int, default=32, help='block_size')
-    args = parser.parse_args()
-
-    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v
-    sparse_ratio = args.sparse_ratio
-    block_size = args.block_size
-    qk_flops = 2 * batch * heads * max_cache_seqlen * dim
-    pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v
-    total_flops = qk_flops + pv_flops
-
+def main(batch=8,
+         heads=32,
+         heads_kv=8,
+         max_cache_seqlen=8192,
+         dim=128,
+         dim_v=128,
+         sparse_ratio=0.8,
+         block_size=32):
+    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
+    sparse_ratio = sparse_ratio
+    block_size = block_size
     max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
     print("max_selected_blocks: ", max_selected_blocks)
     dtype = torch.float16
-    block_H = 64

     Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
     K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
@@ -457,3 +448,19 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
     out = model(Q, K, V, block_mask, cache_seqlens)
     torch.cuda.synchronize()
     print("sparse time: ", (time.time() - start) / 100 * 1000)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch', type=int, default=8, help='batch size')
+    parser.add_argument('--heads', type=int, default=32, help='heads')
+    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
+    parser.add_argument(
+        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
+    parser.add_argument('--dim', type=int, default=128, help='dim')
+    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
+    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
+    parser.add_argument('--block_size', type=int, default=32, help='block_size')
+    args = parser.parse_args()
+    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
+         args.sparse_ratio, args.block_size)

examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py

Lines changed: 28 additions & 16 deletions

@@ -350,22 +350,18 @@ def ref_program_fa(query, key, value, cache_seqlens):
     return output


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=64, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
-    parser.add_argument(
-        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
-    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
-    parser.add_argument('--block_size', type=int, default=32, help='block_size')
-    args = parser.parse_args()
-
-    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v
-    sparse_ratio = args.sparse_ratio
-    block_size = args.block_size
+def main(batch=64,
+         heads=32,
+         heads_kv=8,
+         max_cache_seqlen=8192,
+         dim=128,
+         dim_v=128,
+         sparse_ratio=0.8,
+         block_size=32):
+
+    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
+    sparse_ratio = sparse_ratio
+    block_size = block_size
     qk_flops = 2 * batch * heads * max_cache_seqlen * dim
     pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v
     total_flops = qk_flops + pv_flops
@@ -466,3 +462,19 @@ def ref_program_fa(query, key, value, cache_seqlens):
     print(f"Average time of ref: {avg_time_ref:.6f} seconds")

     print(f"Speedup: {avg_time_ref / avg_time:.2f}x")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch', type=int, default=64, help='batch size')
+    parser.add_argument('--heads', type=int, default=32, help='heads')
+    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
+    parser.add_argument(
+        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
+    parser.add_argument('--dim', type=int, default=128, help='dim')
+    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
+    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
+    parser.add_argument('--block_size', type=int, default=32, help='block_size')
+    args = parser.parse_args()
+    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
+         args.sparse_ratio, args.block_size)

examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py

Lines changed: 30 additions & 17 deletions

@@ -348,28 +348,23 @@ def ref_program_fa(query, key, value, cache_seqlens):
     return output


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=64, help='batch size')
-    parser.add_argument('--heads', type=int, default=32, help='heads')
-    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
-    parser.add_argument(
-        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
-    parser.add_argument('--dim', type=int, default=128, help='dim')
-    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
-    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
-    parser.add_argument('--block_size', type=int, default=32, help='block_size')
-    args = parser.parse_args()
-
-    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v
-    block_size = args.block_size
-    sparse_ratio = args.sparse_ratio
+def main(batch=64,
+         heads=32,
+         heads_kv=8,
+         max_cache_seqlen=8192,
+         dim=128,
+         dim_v=128,
+         sparse_ratio=0.8,
+         block_size=32):
+
+    batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v
+    block_size = block_size
+    sparse_ratio = sparse_ratio
     qk_flops = 2 * batch * heads * max_cache_seqlen * dim
     pv_flops = 2 * batch * heads * max_cache_seqlen * dim_v
     total_flops = qk_flops + pv_flops

     dtype = torch.float16
-    block_H = 64

     Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda')
     K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda')
@@ -435,6 +430,7 @@ def ref_program_fa(query, key, value, cache_seqlens):
     avg_time = elapsed_time / 1000
     avg_flops = total_flops / avg_time
     print(f"Average time: {avg_time:.6f} seconds")
+    print(f"Average flops: {avg_flops:.2f} GFLOPS")

     # Measure performance of reference implementation
     start = time.time()
@@ -446,5 +442,22 @@ def ref_program_fa(query, key, value, cache_seqlens):
     avg_time_ref = elapsed_time_ref / 1000
     avg_flops_ref = total_flops / avg_time_ref
     print(f"Average time of ref: {avg_time_ref:.6f} seconds")
+    print(f"Average flops of ref: {avg_flops_ref:.2f} GFLOPS")

     print(f"Speedup: {avg_time_ref / avg_time:.2f}x")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch', type=int, default=64, help='batch size')
+    parser.add_argument('--heads', type=int, default=32, help='heads')
+    parser.add_argument('--heads_kv', type=int, default=8, help='heads_kv')
+    parser.add_argument(
+        '--max_cache_seqlen', type=int, default=8192, help='kvcache sequence length')
+    parser.add_argument('--dim', type=int, default=128, help='dim')
+    parser.add_argument('--dim_v', type=int, default=128, help='dim_v')
+    parser.add_argument('--sparse_ratio', type=float, default=0.8, help='sparse ratio')
+    parser.add_argument('--block_size', type=int, default=32, help='block_size')
+    args = parser.parse_args()
+    main(args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v,
+         args.sparse_ratio, args.block_size)
