@@ -232,28 +232,7 @@ def main(
232232 return main
233233
234234
235- def main (n = 128 , c = 128 , h = 64 , w = 64 , f = 128 , k = 3 , s = 1 , d = 1 , p = 1 , use_autotune = True , with_roller = True ):
236- N , C , H , W , F , K , S , D , P = n , c , h , w , f , k , s , d , p
237- a = torch .randn (N , H , W , C ).cuda ().half ()
238- b = torch .randn (K , K , C , F ).cuda ().half ()
239- use_autotune = use_autotune
240- with_roller = with_roller
241- if use_autotune :
242- result = get_best_config (N , C , H , W , F , K , S , D , P , with_roller )
243- print (f"best latency { result .latency } " )
244- kernel = result .kernel
245- else :
246- kernel = tilelang .compile (
247- convolution (N , C , H , W , F , K , S , D , P , 64 , 128 , 32 , 3 , 256 ), out_idx = [2 ])
248-
249- out_c = kernel (a , b )
250- ref_c = ref_program (S , P , D )(a , b )
251- print (out_c )
252- print (ref_c )
253- # torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
254-
255-
256- if __name__ == "__main__" :
235+ def main (argv = None ):
257236 parser = argparse .ArgumentParser ()
258237 parser .add_argument ('--n' , type = int , default = 128 , help = 'n' )
259238 parser .add_argument ('--c' , type = int , default = 128 , help = 'c' )
@@ -274,6 +253,25 @@ def main(n=128, c=128, h=64, w=64, f=128, k=3, s=1, d=1, p=1, use_autotune=True,
274253 action = "store_true" ,
275254 default = True ,
276255 help = "Whether to enable BitBLAS roller for search space" )
277- args = parser .parse_args ()
278- main (args .n , args .c , args .h , args .w , args .f , args .k , args .s , args .d , args .p , args .use_autotune ,
279- args .with_roller )
256+
257+ args = parser .parse_args (argv )
258+ N , C , H , W , F , K , S , D , P = args .n , args .c , args .h , args .w , args .f , args .k , args .s , args .d , args .p
259+ a = torch .randn (N , H , W , C ).cuda ().half ()
260+ b = torch .randn (K , K , C , F ).cuda ().half ()
261+ use_autotune = args .use_autotune
262+ with_roller = args .with_roller
263+ if use_autotune :
264+ result = get_best_config (N , C , H , W , F , K , S , D , P , with_roller )
265+ print (f"best latency { result .latency } " )
266+ kernel = result .kernel
267+ else :
268+ kernel = tilelang .compile (
269+ convolution (N , C , H , W , F , K , S , D , P , 64 , 128 , 32 , 3 , 256 ), out_idx = [2 ])
270+
271+ out_c = kernel (a , b )
272+ ref_c = ref_program (S , P , D )(a , b )
273+ torch .testing .assert_close (out_c , ref_c , rtol = 1e-2 , atol = 1e-2 )
274+
275+
276+ if __name__ == "__main__" :
277+ main ()
0 commit comments