Skip to content

Commit 1624981

Browse files
committed
[Refactor] Update example_mla_decode.py and add tests for block_sparse_attn_tilelang
* Refactor example_mla_decode.py to define a main function for better structure and clarity. * Introduce test_example_mla_decode.py to validate the functionality of example_mla_decode. * Refactor block_sparse_attn_tilelang.py to define a main function and add test_block_sparse_attn_tilelang.py for testing. * Ensure all new test files are integrated with tilelang testing framework.
1 parent 9cfa724 commit 1624981

File tree

4 files changed

+38
-2
lines changed

4 files changed

+38
-2
lines changed

examples/deepseek_mla/example_mla_decode.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
272272
return out
273273

274274

275-
if __name__ == "__main__":
275+
def main():
276276
parser = argparse.ArgumentParser()
277277
parser.add_argument('--batch', type=int, default=128, help='batch size')
278278
parser.add_argument('--heads', type=int, default=128, help='q heads number')
@@ -296,3 +296,7 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
296296
latency = profiler.do_bench(warmup=500)
297297
print(f"Latency: {latency} ms")
298298
print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
299+
300+
301+
if __name__ == "__main__":
302+
main()
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) Tile-AI Corporation.
2+
# Licensed under the MIT License.
3+
import tilelang.testing
4+
5+
import example_mla_decode
6+
7+
8+
@tilelang.testing.requires_cuda
9+
def test_example_mla_decode():
10+
example_mla_decode.main()
11+
12+
13+
if __name__ == "__main__":
14+
tilelang.testing.main()

examples/seer_attention/block_sparse_attn_tilelang.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,10 @@ def test_topk_sparse_attention_qlen_lt_klen():
261261
print("Pass topk sparse attention test with qlen < klen")
262262

263263

264-
if __name__ == "__main__":
264+
def main():
265265
test_topk_sparse_attention()
266266
test_topk_sparse_attention_qlen_lt_klen()
267+
268+
269+
if __name__ == "__main__":
270+
main()
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) Tile-AI Corporation.
2+
# Licensed under the MIT License.
3+
import tilelang.testing
4+
5+
import block_sparse_attn_tilelang
6+
7+
8+
@tilelang.testing.requires_cuda
9+
def test_block_sparse_attn_tilelang():
10+
block_sparse_attn_tilelang.main()
11+
12+
13+
if __name__ == "__main__":
14+
tilelang.testing.main()

0 commit comments

Comments
 (0)