Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import argparse
import itertools
import torch
Expand All @@ -13,7 +15,7 @@ def ref_program(x, y):
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):

@T.prim_func
def main(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N),
def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor((M, N),
out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
start_x = bx * block_N
Expand All @@ -23,7 +25,7 @@ def main(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tens
x = start_x + local_x
C[y, x] = A[y, x] + B[y, x]

return main
return elem_add


def get_configs(M, N):
Expand All @@ -49,13 +51,12 @@ def kernel(block_M=None, block_N=None, threads=None):
)
return autotuner.run(warmup=3, rep=20)


if __name__ == "__main__":
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=512)
parser.add_argument("--n", type=int, default=1024)
parser.add_argument("--use_autotune", action="store_true", default=False)
args = parser.parse_args()
args, _ = parser.parse_known_args()
M, N = args.m, args.n

a = torch.randn(M, N, dtype=torch.float32, device="cuda")
Expand All @@ -72,3 +73,7 @@ def kernel(block_M=None, block_N=None, threads=None):

out = kernel(a, b)
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
main()
12 changes: 12 additions & 0 deletions examples/elementwise/test_example_elementwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import tilelang.testing
import example_elementwise_add


def test_example_elementwise_add():
example_elementwise_add.main()


if __name__ == "__main__":
tilelang.testing.main()
10 changes: 6 additions & 4 deletions examples/gemv/example_gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,11 @@ def check_correctness_and_bench(kernel, N, K, bench_ref=True):
print(f"TileLang Latency: {latency} ms\n")


if __name__ == "__main__":
def main():
parser = argparse.ArgumentParser(description="GEMV Example")
parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K")
args = parser.parse_args()
args, _ = parser.parse_known_args()
N, K = args.n, args.k
check_correctness_and_bench(naive_gemv(N, K, 128, 128), N, K)
check_correctness_and_bench(naive_splitk_gemv(N, K, 32, 32), N, K)
Expand All @@ -318,13 +318,15 @@ def check_correctness_and_bench(kernel, N, K, bench_ref=True):
print("Test passed!")

best_result = get_best_config(N, K)
best_latency = best_result.latency
best_config = best_result.config
ref_latency = best_result.ref_latency
kernel = splitk_gemv_vectorized_tvm(N, K, *best_config)
kernel = tl.compile(kernel, out_idx=-1)
profiler = kernel.get_profiler()
latency = profiler.do_bench(lambda x, y: x @ y.T, warmup=500)
print(f"Torch Latency: {latency} ms")
latency = profiler.do_bench(kernel, warmup=500)
print(f"TileLang Latency: {latency} ms\n")


if __name__ == "__main__":
main()
13 changes: 13 additions & 0 deletions examples/gemv/test_example_gemv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import tilelang.testing

import example_gemv


def test_example_gemv():
example_gemv.main()


if __name__ == "__main__":
tilelang.testing.main()