 from ..virtualized import V
 from .aoti_hipify_utils import maybe_hipify_code_wrapper
 from .codegen_device_driver import cuda_kernel_driver, cuda_kernel_header
-from .cpp_utils import DTYPE_TO_CPP
+from .cpp_utils import cexpr, DTYPE_TO_CPP
 from .cpp_wrapper_cpu import CppWrapperCpu
 from .wrapper import SymbolicCallArg
 
@@ -61,6 +61,98 @@ def _new_line(self, line):
         return DeferredCudaKernelLine(self.kernel_name, line, self.keys)
 
 
+class DeferredCudaDefaultGrid:
+    """
+    A marker to defer computing the default Triton grid until the kernel's
+    autotuned launch config is available in CudaKernelParamCache.
+    """
+
+    def __init__(
+        self,
+        kernel_name: str,
+        grid,
+        grid_callable: Optional[Callable[..., Any]] = None,
+        **grid_extra_kwargs,
+    ):
+        self.kernel_name = kernel_name
+        self.grid = grid
+        self.grid_callable = grid_callable
+        self.grid_extra_kwargs = grid_extra_kwargs
+
+    def __call__(self):
+        grid = self.grid
+        assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
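+        # Unwrap SymbolicCallArg entries to their inner expressions before
+        # computing the grid.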
+        grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
+        grid_callable = self.grid_callable or default_grid
+        if not self.grid_extra_kwargs:
+            grid_fn = grid_callable(*grid)
+        else:
+            grid_fn = grid_callable(*grid, **self.grid_extra_kwargs)
+
+        params = CudaKernelParamCache.get(self.kernel_name)
+        assert (
+            params is not None
+        ), f"{self.kernel_name} not found in CudaKernelParamCache"
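+        # Feed the autotuned block sizes to the grid function so it can compute
+        # the launch dimensions.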
+        block_cfg = {
+            "XBLOCK": params["x_block"],
+            "YBLOCK": params["y_block"],
+            "ZBLOCK": params["z_block"],
+        }
+        return grid_fn(block_cfg)
+
+
+class DeferredCudaGridLine(DeferredLineBase):
+    """
+    When using the cpp wrapper, CUDA kernel load and launch need to wait for Triton
+    kernels to be tuned and stored as cubin files, so use a deferred line to backfill
+    that information.
+    """
+
+    def __init__(
+        self,
+        kernel_name: str,
+        grid_var: str,
+        grid,
+        autotune_configs,
+    ):
+        super().__init__("")
+        self.kernel_name = kernel_name
+        self.grid_var = grid_var
+        self.grid = grid
+        self.autotune_configs = autotune_configs
+
+    def __call__(self):
+        params = CudaKernelParamCache.get(self.kernel_name)
+        assert (
+            params is not None
+        ), f"{self.kernel_name} not found in CudaKernelParamCache"
+
+        if self.autotune_configs is not None:
+            # This indicates the Triton kernel is a user-defined one.
+            grid = None
+            if len(self.grid) == 1:
+                grid = self.grid[0]
+            else:
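+                # Pick the grid recorded for the autotune config whose kwargs
+                # match the cached kernel meta.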
+                for i, c in enumerate(self.autotune_configs):
+                    if all(arg == params["meta"][key] for key, arg in c.kwargs.items()):
+                        grid = self.grid[i]
+                        break
+            assert grid is not None
+        elif isinstance(self.grid, DeferredCudaDefaultGrid):
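+            # The default grid was deferred; resolve it now that the autotuned
+            # kernel params are cached.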
+            grid = self.grid()
+        else:
+            grid = self.grid
+
+        assert len(grid) != 0, "Grid can't be empty"
+        grid_args_str = ", ".join(
+            [cexpr(V.graph.sizevars.simplify(item)) for item in grid]
+        )
+        return f"Grid {self.grid_var} = Grid({grid_args_str});"
+
+    def _new_line(self, line):
+        return DeferredCudaGridLine(
+            self.kernel_name, self.grid_var, self.grid, self.autotune_configs
+        )
+
+
 class CppWrapperCuda(CppWrapperCpu):
     """
     Generates cpp wrapper for running on GPU and calls CUDA kernels
@@ -116,28 +208,20 @@ def generate(self, is_inference):
         return super().generate(is_inference)
 
     def generate_user_defined_triton_kernel(
-        self, kernel_name, raw_args, grid, configs, triton_meta, constexprs
+        self,
+        kernel_name: str,
+        raw_args: List[Any],
+        grid: List[Any],
+        configs,
+        triton_meta,
+        constexprs,
     ):
         # in C++ wrapper, we don't pass constexpr args, as they don't
         # get added as parameters to the PTX code compiled from the
         # user-defined Triton kernel (only non-constexpr args do)
         raw_args = [
             raw_arg for i, raw_arg in enumerate(raw_args) if i not in constexprs
         ]
-
-        assert len(grid) != 0
-        if len(grid) == 1:
-            grid_decision = grid[0]
-        else:
-            meta = CudaKernelParamCache.get(kernel_name)
-            assert meta is not None
-            grid_decision = None
-            for i, c in enumerate(configs):
-                if all(arg == meta["meta"][key] for key, arg in c.kwargs.items()):
-                    grid_decision = grid[i]
-                    break
-            assert grid_decision is not None
-
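+        # The grid selection is now deferred: the raw grid and autotune configs
+        # are forwarded below and resolved by DeferredCudaGridLine after
+        # autotuning.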
         args = [self.val_to_arg_str(v) for v in raw_args]
         arg_types = [
             arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg)
@@ -147,10 +231,12 @@ def generate_user_defined_triton_kernel(
             kernel_name,
             args,
             arg_types=arg_types,
-            grid=grid_decision,
+            raw_args=raw_args,
+            grid=grid,
             cuda=True,
             triton=True,
             triton_meta=triton_meta,
+            autotune_configs=configs,
         )
 
     @functools.lru_cache(None)  # noqa: B019
@@ -228,39 +314,27 @@ def generate_args_decl(self, call_args, arg_types):
 
     def generate_default_grid(
         self,
-        name: str,
+        kernel_name: str,
         grid: List[Any],
         cuda: bool = True,
         grid_callable: Optional[Callable[..., Any]] = None,
         **grid_extra_kwargs,
     ):
         """
         Generate grid configs for launching a CUDA kernel using the grid
-        function from triton_heuristics.
+        function from triton_heuristics. Because the grid computation needs to
+        read the autotuned kernel config, it is done in a deferred way via
+        DeferredCudaDefaultGrid.
         """
         if not cuda:
             return grid
-        assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
-        grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
-        grid_callable = grid_callable or default_grid
-        if not grid_extra_kwargs:
-            grid_fn = grid_callable(*grid)
-        else:
-            grid_fn = grid_callable(*grid, **grid_extra_kwargs)
-        params = CudaKernelParamCache.get(name)
-        assert (
-            params is not None
-        ), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}"
-        block_cfg = {
-            "XBLOCK": params["x_block"],
-            "YBLOCK": params["y_block"],
-            "ZBLOCK": params["z_block"],
-        }
-        return grid_fn(block_cfg)
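+        # Return a deferred marker; it is called later, once autotuning has
+        # stored the kernel's params in CudaKernelParamCache.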
+        return DeferredCudaDefaultGrid(
+            kernel_name, grid, grid_callable, **grid_extra_kwargs
+        )
 
     def generate_kernel_call(
         self,
-        kernel_name,
+        kernel_name: str,
         call_args,
         grid=None,
         device_index=None,
@@ -270,6 +344,7 @@ def generate_kernel_call(
         raw_args=None,
         grid_fn: str = "grid",
         triton_meta=None,
+        autotune_configs=None,
         grid_extra_kwargs="",
     ):
         assert arg_types is not None and len(call_args) == len(
@@ -279,7 +354,18 @@ def generate_kernel_call(
         if not cuda:
             # Even in CppWrapperCuda, we may see cpp kernels
             return super().generate_kernel_call(
-                kernel_name, call_args, grid, device_index, cuda, triton, arg_types
+                kernel_name,
+                call_args,
+                grid,
+                device_index,
+                cuda,
+                triton,
+                arg_types,
+                raw_args,
+                grid_fn,
+                triton_meta,
+                autotune_configs,
+                grid_extra_kwargs,
             )
 
         device_index, call_args = self.prepare_triton_kernel_call(
@@ -307,33 +393,26 @@ def generate_kernel_call(
             if V.graph.aot_mode
             else self.write_get_raw_stream(device_index, V.graph)
         )
-        grid_name = f"{kernel_name}_grid_{next(self.grid_id)}"
-        assert isinstance(
-            grid, (list, tuple)
-        ), f"expected grid to be a list or tuple but got: {grid=}"
-
-        grid = [V.graph.sizevars.simplify(item) for item in grid]
-        grid_uses_symbolic_shapes = any(item.free_symbols for item in grid)
-        grid_args = [self.expr_printer(item) for item in grid]
-        grid_args_str = ", ".join(grid_args)
-        self.writeline(f"Grid {grid_name} = Grid({grid_args_str});")
-
-        if grid_uses_symbolic_shapes:
-            self.writeline(f"if ({grid_name}.is_non_zero()) {{")
+
+        grid_var = f"{kernel_name}_grid_{next(self.grid_id)}"
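+        # Deferred: the concrete grid values are filled in after Triton
+        # autotuning has populated CudaKernelParamCache for this kernel.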
+        self.writeline(
+            DeferredCudaGridLine(kernel_name, grid_var, grid, autotune_configs)
+        )
+
         kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name
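+        # Skip the launch entirely when the (possibly symbolic) grid is empty.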
+        self.writeline(f"if ({grid_var}.is_non_zero()) {{")
         self.writeline(
             DeferredCudaKernelLine(
                 kernel_name,
                 r"launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format(
                     kernel_var_name,
-                    f"{grid_name}.grid_x",
-                    f"{grid_name}.grid_y",
-                    f"{grid_name}.grid_z",
+                    f"{grid_var}.grid_x",
+                    f"{grid_var}.grid_y",
+                    f"{grid_var}.grid_z",
                     kernel_args_var,
                     stream,
                 ),
                 ("num_warps", "shared_mem"),
             ),
         )
-        if grid_uses_symbolic_shapes:
-            self.writeline("}")
+        self.writeline("}")