@@ -251,6 +251,19 @@ def _compile(*config_arg):
251251 if self .jit_compile is None :
252252 self .jit_compile = _compile
253253
254+ # Factory functions for generating input tensors.
255+ # This encapsulates the logic of using either a custom supply program (`supply_prog`)
256+ # or the default profiler input generation (`profiler._get_inputs`).
def get_input_tensors_supply(supply_prog, profiler, with_output: bool):
    """Build a zero-argument callable that produces input tensors on demand.

    Args:
        supply_prog: Optional custom supply program. When it is not ``None``,
            it is called with ``profiler._get_params(with_output=with_output)``
            and its return value is used as the inputs.
        profiler: Object exposing ``_get_params`` and ``_get_inputs``; both
            are called with the ``with_output`` keyword argument.
        with_output: Forwarded unchanged to whichever profiler method is used.

    Returns:
        A callable taking no arguments. Nothing is generated when the factory
        itself is called; every invocation of the returned callable performs
        a fresh call into ``supply_prog`` / ``profiler``.
    """

    def func():
        # A caller-provided supply program takes precedence over the
        # profiler's default input generation.
        if supply_prog is not None:
            return supply_prog(profiler._get_params(with_output=with_output))
        return profiler._get_inputs(with_output=with_output)

    return func
266+
254267 def target_fn (jit_context : JITContext ):
255268 # Unpack the context
256269 kernel = jit_context .kernel
@@ -266,57 +279,30 @@ def target_fn(jit_context: JITContext):
266279
267280 profiler = kernel .get_profiler (tensor_supply_type = supply_type )
268281
269- # Factory functions for generating input tensors.
270- # This encapsulates the logic of using either a custom supply program (`supply_prog`)
271- # or the default profiler input generation (`profiler._get_inputs`).
272- def get_input_tensors_supply (with_output : bool ):
273-
274- def func ():
275- if supply_prog is not None :
276- return supply_prog (profiler ._get_params (with_output = with_output ))
277- else :
278- return profiler ._get_inputs (with_output = with_output )
279-
280- return func
281-
282- jit_input_tensors_supply = get_input_tensors_supply (with_output = False )
283- ref_input_tensors_supply = get_input_tensors_supply (with_output = False )
284-
285- if cache_input_tensors :
286- jit_input_tensors = jit_input_tensors_supply ()
287- if self .jit_input_tensors is not None :
288- if not check_tensor_list_compatibility (self .jit_input_tensors ,
289- jit_input_tensors ):
290- logger .warning (
291- "Incompatible input tensor properties detected between cached tensors and "
292- "tensors regenerated for the current configuration trial. "
293- "This can happen if different tuning configurations require different input shapes/dtypes "
294- "and input tensor caching is enabled.\n "
295- "To ensure fresh, compatible inputs are generated for every trial "
296- "you can disable caching by setting:\n "
297- " `cache_input_tensors=False`\n "
298- "within your `.set_compile_args(...)` call.\n " )
299- self .jit_input_tensors = jit_input_tensors
300- self .jit_input_tensors = jit_input_tensors
282+ if cache_input_tensors and self .jit_input_tensors is not None :
283+ jit_input_tensors = self .jit_input_tensors
301284 else :
302- self .jit_input_tensors = jit_input_tensors_supply ()
285+ jit_input_tensors_supply = get_input_tensors_supply (
286+ supply_prog , profiler , with_output = False )
287+ jit_input_tensors = jit_input_tensors_supply ()
303288
304289 if (not skip_check ) and (ref_prog is not None ):
305290 if manual_check_prog is not None :
306291 profiler .manual_assert_close (
307292 ref_prog ,
308- input_tensors = self . jit_input_tensors ,
293+ input_tensors = jit_input_tensors ,
309294 manual_check_prog = manual_check_prog )
310295 else :
311296 profiler .assert_allclose (
312297 ref_prog ,
313- input_tensors = self . jit_input_tensors ,
298+ input_tensors = jit_input_tensors ,
314299 rtol = rtol ,
315300 atol = atol ,
316301 max_mismatched_ratio = max_mismatched_ratio )
317- latency = profiler .do_bench (
318- warmup = warmup , rep = rep , input_tensors = self .jit_input_tensors )
302+ latency = profiler .do_bench (warmup = warmup , rep = rep , input_tensors = jit_input_tensors )
319303 if self .ref_latency_cache is None and ref_prog is not None :
304+ ref_input_tensors_supply = get_input_tensors_supply (
305+ supply_prog , profiler , with_output = False )
320306 self .ref_input_tensors = ref_input_tensors_supply ()
321307 self .ref_latency_cache = profiler .do_bench (
322308 ref_prog , n_warmup = warmup , n_repeat = rep , input_tensors = self .ref_input_tensors )
@@ -369,6 +355,14 @@ def device_wrapper(func, device, *config_arg):
369355 continue
370356
371357 ref_latency = None
358+ if results_with_configs [0 ][0 ].cache_input_tensors :
359+ supply_prog = results_with_configs [0 ][0 ].supply_prog
360+ supply_type = results_with_configs [0 ][0 ].supply_type
361+ profiler = results_with_configs [0 ][0 ].kernel .get_profiler (
362+ tensor_supply_type = supply_type )
363+ jit_input_tensors_supply = get_input_tensors_supply (
364+ supply_prog , profiler , with_output = False )
365+ self .jit_input_tensors = jit_input_tensors_supply ()
372366 progress_bar = tqdm (range (len (results_with_configs )), desc = "Bench configurations" )
373367 for i in progress_bar :
374368 jit_context , config = results_with_configs [i ]
0 commit comments