@@ -251,19 +251,6 @@ def _compile(*config_arg):
251251 if self .jit_compile is None :
252252 self .jit_compile = _compile
253253
254- # Factory functions for generating input tensors.
255- # This encapsulates the logic of using either a custom supply program (`supply_prog`)
256- # or the default profiler input generation (`profiler._get_inputs`).
def get_input_tensors_supply(supply_prog, profiler, with_output: bool):
    """Build a zero-argument callable that produces input tensors.

    Args:
        supply_prog: Optional custom supply program. When it is not ``None``
            it is called with the result of
            ``profiler._get_params(with_output=with_output)``.
        profiler: Object providing ``_get_params`` and ``_get_inputs``.
        with_output: Forwarded unchanged to the profiler as ``with_output``.

    Returns:
        A callable taking no arguments. Every call generates the inputs
        anew, either through ``supply_prog`` or, when no supply program was
        given, through ``profiler._get_inputs(with_output=with_output)``.
    """

    def _generate():
        # No custom supply program: use the profiler's own input generation.
        if supply_prog is None:
            return profiler._get_inputs(with_output=with_output)
        params = profiler._get_params(with_output=with_output)
        return supply_prog(params)

    return _generate
266-
267254 def target_fn (jit_context : JITContext ):
268255 # Unpack the context
269256 kernel = jit_context .kernel
@@ -279,30 +266,57 @@ def target_fn(jit_context: JITContext):
279266
280267 profiler = kernel .get_profiler (tensor_supply_type = supply_type )
281268
282- if cache_input_tensors and self .jit_input_tensors is not None :
283- jit_input_tensors = self .jit_input_tensors
284- else :
285- jit_input_tensors_supply = get_input_tensors_supply (
286- supply_prog , profiler , with_output = False )
269+ # Factory functions for generating input tensors.
270+ # This encapsulates the logic of using either a custom supply program (`supply_prog`)
271+ # or the default profiler input generation (`profiler._get_inputs`).
def get_input_tensors_supply(with_output: bool):
    """Build a zero-argument callable that produces input tensors.

    ``supply_prog`` and ``profiler`` are not parameters: they are read from
    the enclosing scope each time the returned callable is invoked.

    Args:
        with_output: Forwarded unchanged to the profiler as ``with_output``.

    Returns:
        A callable taking no arguments. Every call generates the inputs
        anew: through ``supply_prog`` (fed with
        ``profiler._get_params(with_output=with_output)``) when a supply
        program is set, otherwise through
        ``profiler._get_inputs(with_output=with_output)``.
    """

    def _generate():
        # No custom supply program: use the profiler's own input generation.
        if supply_prog is None:
            return profiler._get_inputs(with_output=with_output)
        params = profiler._get_params(with_output=with_output)
        return supply_prog(params)

    return _generate
281+
282+ jit_input_tensors_supply = get_input_tensors_supply (with_output = False )
283+ ref_input_tensors_supply = get_input_tensors_supply (with_output = False )
284+
285+ if cache_input_tensors :
287286 jit_input_tensors = jit_input_tensors_supply ()
287+ if self .jit_input_tensors is not None :
288+ if not check_tensor_list_compatibility (self .jit_input_tensors ,
289+ jit_input_tensors ):
290+ logger .warning (
291+ "Incompatible input tensor properties detected between cached tensors and "
292+ "tensors regenerated for the current configuration trial. "
293+ "This can happen if different tuning configurations require different input shapes/dtypes "
294+ "and input tensor caching is enabled.\n "
295+ "To ensure fresh, compatible inputs are generated for every trial "
296+ "you can disable caching by setting:\n "
297+ " `cache_input_tensors=False`\n "
298+ "within your `.set_compile_args(...)` call.\n " )
299+ self .jit_input_tensors = jit_input_tensors
300+ self .jit_input_tensors = jit_input_tensors
301+ else :
302+ self .jit_input_tensors = jit_input_tensors_supply ()
288303
289304 if (not skip_check ) and (ref_prog is not None ):
290305 if manual_check_prog is not None :
291306 profiler .manual_assert_close (
292307 ref_prog ,
293- input_tensors = jit_input_tensors ,
308+ input_tensors = self . jit_input_tensors ,
294309 manual_check_prog = manual_check_prog )
295310 else :
296311 profiler .assert_allclose (
297312 ref_prog ,
298- input_tensors = jit_input_tensors ,
313+ input_tensors = self . jit_input_tensors ,
299314 rtol = rtol ,
300315 atol = atol ,
301316 max_mismatched_ratio = max_mismatched_ratio )
302- latency = profiler .do_bench (warmup = warmup , rep = rep , input_tensors = jit_input_tensors )
317+ latency = profiler .do_bench (
318+ warmup = warmup , rep = rep , input_tensors = self .jit_input_tensors )
303319 if self .ref_latency_cache is None and ref_prog is not None :
304- ref_input_tensors_supply = get_input_tensors_supply (
305- supply_prog , profiler , with_output = False )
306320 self .ref_input_tensors = ref_input_tensors_supply ()
307321 self .ref_latency_cache = profiler .do_bench (
308322 ref_prog , n_warmup = warmup , n_repeat = rep , input_tensors = self .ref_input_tensors )
@@ -355,14 +369,6 @@ def device_wrapper(func, device, *config_arg):
355369 continue
356370
357371 ref_latency = None
358- if results_with_configs [0 ][0 ].cache_input_tensors :
359- supply_prog = results_with_configs [0 ][0 ].supply_prog
360- supply_type = results_with_configs [0 ][0 ].supply_type
361- profiler = results_with_configs [0 ][0 ].kernel .get_profiler (
362- tensor_supply_type = supply_type )
363- jit_input_tensors_supply = get_input_tensors_supply (
364- supply_prog , profiler , with_output = False )
365- self .jit_input_tensors = jit_input_tensors_supply ()
366372 progress_bar = tqdm (range (len (results_with_configs )), desc = "Bench configurations" )
367373 for i in progress_bar :
368374 jit_context , config = results_with_configs [i ]
0 commit comments