Skip to content

Commit 73d866e

Browse files
authored
Merge branch 'main' into main
2 parents 9d1d0e3 + d79a41b commit 73d866e

File tree

2 files changed

+24
-26
lines changed

2 files changed

+24
-26
lines changed

examples/convolution/example_convolution.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -232,28 +232,7 @@ def main(
232 232
return main
233 233

234 234

235-
def main(n=128, c=128, h=64, w=64, f=128, k=3, s=1, d=1, p=1, use_autotune=True, with_roller=True):
236-
N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p
237-
a = torch.randn(N, H, W, C).cuda().half()
238-
b = torch.randn(K, K, C, F).cuda().half()
239-
use_autotune = use_autotune
240-
with_roller = with_roller
241-
if use_autotune:
242-
result = get_best_config(N, C, H, W, F, K, S, D, P, with_roller)
243-
print(f"best latency {result.latency}")
244-
kernel = result.kernel
245-
else:
246-
kernel = tilelang.compile(
247-
convolution(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256), out_idx=[2])
248-
249-
out_c = kernel(a, b)
250-
ref_c = ref_program(S, P, D)(a, b)
251-
print(out_c)
252-
print(ref_c)
253-
# torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
254-
255-
256-
if __name__ == "__main__":
235+
def main(argv=None):
257 236
parser = argparse.ArgumentParser()
258 237
parser.add_argument('--n', type=int, default=128, help='n')
259 238
parser.add_argument('--c', type=int, default=128, help='c')
@@ -274,6 +253,25 @@ def main(n=128, c=128, h=64, w=64, f=128, k=3, s=1, d=1, p=1, use_autotune=True,
274 253
action="store_true",
275 254
default=True,
276 255
help="Whether to enable BitBLAS roller for search space")
277-
args = parser.parse_args()
278-
main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune,
279-
args.with_roller)
256+
257+
args = parser.parse_args(argv)
258+
N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p
259+
a = torch.randn(N, H, W, C).cuda().half()
260+
b = torch.randn(K, K, C, F).cuda().half()
261+
use_autotune = args.use_autotune
262+
with_roller = args.with_roller
263+
if use_autotune:
264+
result = get_best_config(N, C, H, W, F, K, S, D, P, with_roller)
265+
print(f"best latency {result.latency}")
266+
kernel = result.kernel
267+
else:
268+
kernel = tilelang.compile(
269+
convolution(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256), out_idx=[2])
270+
271+
out_c = kernel(a, b)
272+
ref_c = ref_program(S, P, D)(a, b)
273+
torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
274+
275+
276+
if __name__ == "__main__":
277+
main()

examples/convolution/test_example_convolution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
7 7

8 8
@tilelang.testing.requires_cuda
9 9
def test_example_convolution():
10-
example_convolution.main()
10+
example_convolution.main([])
11 11

12 12

13 13
if __name__ == "__main__":

0 commit comments

Comments
 (0)