Skip to content

Commit a29649b

Browse files
authored
[CI] Add elementwise and gemv examples to CI. (#458)
* [CI] Add elementwise and gemv examples to CI. * fix lint * test * fix gemv lint * fix lint
1 parent 620b8e0 commit a29649b

File tree

4 files changed

+41
-9
lines changed

4 files changed

+41
-9
lines changed

examples/elementwise/elementwise_add.py renamed to examples/elementwise/example_elementwise_add.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# Copyright (c) Tile-AI Corporation.
2+
# Licensed under the MIT License.
13
import argparse
24
import itertools
35
import torch
@@ -13,7 +15,7 @@ def ref_program(x, y):
1315
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
1416

1517
@T.prim_func
16-
def main(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N),
18+
def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N),
1719
out_dtype)):
1820
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
1921
start_x = bx * block_N
@@ -23,7 +25,7 @@ def main(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tens
2325
x = start_x + local_x
2426
C[y, x] = A[y, x] + B[y, x]
2527

26-
return main
28+
return elem_add
2729

2830

2931
def get_configs(M, N):
@@ -49,13 +51,12 @@ def kernel(block_M=None, block_N=None, threads=None):
4951
)
5052
return autotuner.run(warmup=3, rep=20)
5153

52-
53-
if __name__ == "__main__":
54+
def main():
5455
parser = argparse.ArgumentParser()
5556
parser.add_argument("--m", type=int, default=512)
5657
parser.add_argument("--n", type=int, default=1024)
5758
parser.add_argument("--use_autotune", action="store_true", default=False)
58-
args = parser.parse_args()
59+
args, _ = parser.parse_known_args()
5960
M, N = args.m, args.n
6061

6162
a = torch.randn(M, N, dtype=torch.float32, device="cuda")
@@ -72,3 +73,7 @@ def kernel(block_M=None, block_N=None, threads=None):
7273

7374
out = kernel(a, b)
7475
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
76+
77+
78+
if __name__ == "__main__":
79+
main()
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright (c) Tile-AI Corporation.
2+
# Licensed under the MIT License.
3+
import tilelang.testing
4+
import example_elementwise_add
5+
6+
7+
def test_example_elementwise_add():
8+
example_elementwise_add.main()
9+
10+
11+
if __name__ == "__main__":
12+
tilelang.testing.main()

examples/gemv/example_gemv.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,11 +304,11 @@ def check_correctness_and_bench(kernel, N, K, bench_ref=True):
304304
print(f"TileLang Latency: {latency} ms\n")
305305

306306

307-
if __name__ == "__main__":
307+
def main():
308308
parser = argparse.ArgumentParser(description="GEMV Example")
309309
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
310310
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
311-
args = parser.parse_args()
311+
args, _ = parser.parse_known_args()
312312
N, K = args.n, args.k
313313
check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K)
314314
check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K)
@@ -318,13 +318,15 @@ def check_correctness_and_bench(kernel, N, K, bench_ref=True):
318318
print("Test passed!")
319319

320320
best_result = get_best_config(N, K)
321-
best_latency = best_result.latency
322321
best_config = best_result.config
323-
ref_latency = best_result.ref_latency
324322
kernel = splitk_gemv_vectorized_tvm(N, K, *best_config)
325323
kernel = tl.compile(kernel, out_idx=-1)
326324
profiler = kernel.get_profiler()
327325
latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500)
328326
print(f"Torch Latency: {latency} ms")
329327
latency = profiler.do_bench(kernel, warmup=500)
330328
print(f"TileLang Latency: {latency} ms\n")
329+
330+
331+
if __name__ == "__main__":
332+
main()

examples/gemv/test_example_gemv.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) Tile-AI Corporation.
2+
# Licensed under the MIT License.
3+
import tilelang.testing
4+
5+
import example_gemv
6+
7+
8+
def test_example_gemv():
9+
example_gemv.main()
10+
11+
12+
if __name__ == "__main__":
13+
tilelang.testing.main()

0 commit comments

Comments
 (0)