@@ -220,9 +220,9 @@ def create_dispatch_func(self, code, function_informations):
220220 raise ValueError (
221221 f"Parameter { param } is not in the buffer map of the primary function." )
222222 # Add dynamic symbols as integer arguments
223- for dyn_sym in dynamic_symbolic_set :
223+ for dyn_sym , dyn_sym_dtype in dynamic_symbolic_set :
224224 if dyn_sym not in [arg ["name" ] for arg in function_args ]:
225- function_args .append ({"name" : dyn_sym , "type" : "int" })
225+ function_args .append ({"name" : dyn_sym , "type" : self . _lookup_type ( dyn_sym_dtype ) })
226226
227227 function_args .append (self .get_stream_type ())
228228
@@ -405,28 +405,30 @@ def parse_source_information(self):
405405
406406 def get_dynamic_symbolic_set (self , prim_func ):
407407 # Determine the set of dynamic symbols used in the function
408- dynamic_symbolic_set : list [str ] = []
408+ dynamic_symbolic_set : dict [str , str ] = {}
409409
410- def unique_push_back (name : str ):
410+ def unique_push_back (name : str , dtype : str ):
411411 if name not in dynamic_symbolic_set :
412- dynamic_symbolic_set .append (name )
412+ dynamic_symbolic_set [name ] = dtype
413+ else :
414+ assert dtype == dynamic_symbolic_set [name ]
413415
414416 for param in prim_func .params :
415417 if param in prim_func .buffer_map :
416418 buffer = prim_func .buffer_map [param ]
417419 for dim in buffer .shape :
418420 if isinstance (dim , tvm .tir .Var ):
419- unique_push_back (dim .name )
421+ unique_push_back (dim .name , str ( dim . dtype ) )
420422
421423 # Note: In buffer definitions, any dynamic symbols appearing in strides are listed after those in the shape.
422424 for param in prim_func .params :
423425 if param in prim_func .buffer_map :
424426 buffer = prim_func .buffer_map [param ]
425427 for stride in buffer .strides :
426428 if isinstance (stride , tvm .tir .Var ):
427- unique_push_back (stride .name )
429+ unique_push_back (stride .name , str ( stride . dtype ) )
428430
429- return dynamic_symbolic_set
431+ return list ( dynamic_symbolic_set . items ())
430432
431433 def get_init_func (self ):
432434 # Initialize an empty string for the CUDA function call
@@ -665,8 +667,8 @@ def create_call_func(self, code, function_informations):
665667 raise ValueError (
666668 f"Parameter { param } is not in the buffer map of the primary function." )
667669 # Add dynamic symbols as integer arguments
668- for dyn_sym in dynamic_symbolic_set :
669- function_args .append ({"name" : dyn_sym , "type" : "int" })
670+ for dyn_sym , dyn_sym_dtype in dynamic_symbolic_set :
671+ function_args .append ({"name" : dyn_sym , "type" : self . _lookup_type ( dyn_sym_dtype ) })
670672 # Format the function arguments for declaration
671673 def_args = ", " .join ([f"{ arg ['type' ]} { arg ['name' ]} " for arg in function_args ])
672674
@@ -715,14 +717,14 @@ def parse_source_information(self):
715717
716718 def get_dynamic_symbolic_set (self , prim_func ):
717719 # Determine the set of dynamic symbols used in the function
718- dynamic_symbolic_set : list [str ] = []
720+ dynamic_symbolic_set : dict [str , str ] = {}
719721 for param in prim_func .params :
720722 if param in prim_func .buffer_map :
721723 buffer = prim_func .buffer_map [param ]
722724 for dim in buffer .shape :
723725 if isinstance (dim , tvm .tir .Var ) and (dim .name not in dynamic_symbolic_set ):
724- dynamic_symbolic_set . append (dim .name )
725- return dynamic_symbolic_set
726+ dynamic_symbolic_set [ dim . name ] = str (dim .dtype )
727+ return list ( dynamic_symbolic_set . items ())
726728
727729 def get_cpu_init_func (self ):
728730 # Provide init() and get_last_error() for CPU backend
0 commit comments