@@ -38,7 +38,7 @@ def check_if_then_else(dev, n, dtype):
3838 sch = tvm .tir .Schedule (func )
3939 (x ,) = sch .get_loops (sch .get_block ("C" ))
4040 sch .bind (x , "threadIdx.x" )
41- fun = tvm .compile (sch .mod , target = target )
41+ fun = tvm .tir . build (sch .mod , target = target )
4242 a = tvm .nd .empty ((n ,), A .dtype , dev )
4343 c = tvm .nd .empty ((n ,), A .dtype , dev )
4444 # Only need to test compiling here
@@ -55,7 +55,7 @@ def check_select(dev, n, dtype):
5555 sch = tvm .tir .Schedule (func )
5656 (x ,) = sch .get_loops (sch .get_block ("C" ))
5757 sch .bind (x , "threadIdx.x" )
58- fun = tvm .compile (sch .mod , target = target )
58+ fun = tvm .tir . build (sch .mod , target = target )
5959
6060 a = tvm .nd .empty ((n ,), A .dtype , dev )
6161 c = tvm .nd .empty ((n ,), A .dtype , dev )
@@ -85,7 +85,7 @@ def check_inf_nan(dev, n, value, dtype):
8585 sch = tvm .tir .Schedule (func )
8686 (x ,) = sch .get_loops (sch .get_block ("C" ))
8787 sch .bind (x , "threadIdx.x" )
88- fun = tvm .compile (sch .mod , target = target )
88+ fun = tvm .tir . build (sch .mod , target = target )
8989 a = tvm .nd .empty ((n ,), A .dtype , dev )
9090 c = tvm .nd .empty ((n ,), A .dtype , dev )
9191 # Only need to test compiling here
@@ -113,7 +113,7 @@ def check_max(dev, n, dtype):
113113 sch = tvm .tir .Schedule (func )
114114 (x ,) = sch .get_loops (sch .get_block ("C" ))
115115 sch .bind (x , "threadIdx.x" )
116- fun = tvm .compile (sch .mod , target = target )
116+ fun = tvm .tir . build (sch .mod , target = target )
117117
118118 a = tvm .nd .empty ((n ,), A .dtype , dev )
119119 c = tvm .nd .empty ((n ,), A .dtype , dev )
@@ -178,7 +178,7 @@ def check_type_casting(ctx, n, dtype):
178178 sch .bind (tx , "threadIdx.x" )
179179 sch .vectorize (vx )
180180
181- fun = tvm .compile (sch .mod , target = target )
181+ fun = tvm .tir . build (sch .mod , target = target )
182182 c = tvm .nd .empty ((n ,), dtype , ctx )
183183 assembly = fun .imported_modules [0 ].get_source ()
184184 lcond = "convert_int4(((convert_uint4(((uint4)(((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3)))))"
@@ -210,7 +210,7 @@ def _check(target, n, dtype):
210210 (x ,) = sch .get_loops (sch .get_block ("C" ))
211211 sch .bind (x , "threadIdx.x" )
212212
213- fun = tvm .compile (sch .mod , target = target )
213+ fun = tvm .tir . build (sch .mod , target = target )
214214 assembly = fun .imported_modules [0 ].get_source ()
215215 if "adreno" in target :
216216 pattern = "convert_float"
@@ -225,7 +225,7 @@ def _get_maximum_kernel_args(source):
225225 def get_kernel_args (source ):
226226 import re
227227
228- p = re .compile (r"__kernel void .+\((.*)\)" )
228+ p = re .tir . build (r"__kernel void .+\((.*)\)" )
229229 args = p .findall (source )
230230 return args
231231
0 commit comments