Skip to content

Commit 1b95576

Browse files
authored
[Bugfix] Use AutoTune cache_input_tensors properly (#483)
1 parent 961df37 commit 1b95576

File tree

1 file changed

+31
-37
lines changed

1 file changed

+31
-37
lines changed

tilelang/autotuner/__init__.py

Lines changed: 31 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -251,6 +251,19 @@ def _compile(*config_arg):
251251
if self.jit_compile is None:
252252
self.jit_compile = _compile
253253

254+
# Factory functions for generating input tensors.
255+
# This encapsulates the logic of using either a custom supply program (`supply_prog`)
256+
# or the default profiler input generation (`profiler._get_inputs`).
257+
def get_input_tensors_supply(supply_prog, profiler, with_output: bool):
258+
259+
def func():
260+
if supply_prog is not None:
261+
return supply_prog(profiler._get_params(with_output=with_output))
262+
else:
263+
return profiler._get_inputs(with_output=with_output)
264+
265+
return func
266+
254267
def target_fn(jit_context: JITContext):
255268
# Unpack the context
256269
kernel = jit_context.kernel
@@ -266,57 +279,30 @@ def target_fn(jit_context: JITContext):
266279

267280
profiler = kernel.get_profiler(tensor_supply_type=supply_type)
268281

269-
# Factory functions for generating input tensors.
270-
# This encapsulates the logic of using either a custom supply program (`supply_prog`)
271-
# or the default profiler input generation (`profiler._get_inputs`).
272-
def get_input_tensors_supply(with_output: bool):
273-
274-
def func():
275-
if supply_prog is not None:
276-
return supply_prog(profiler._get_params(with_output=with_output))
277-
else:
278-
return profiler._get_inputs(with_output=with_output)
279-
280-
return func
281-
282-
jit_input_tensors_supply = get_input_tensors_supply(with_output=False)
283-
ref_input_tensors_supply = get_input_tensors_supply(with_output=False)
284-
285-
if cache_input_tensors:
286-
jit_input_tensors = jit_input_tensors_supply()
287-
if self.jit_input_tensors is not None:
288-
if not check_tensor_list_compatibility(self.jit_input_tensors,
289-
jit_input_tensors):
290-
logger.warning(
291-
"Incompatible input tensor properties detected between cached tensors and "
292-
"tensors regenerated for the current configuration trial. "
293-
"This can happen if different tuning configurations require different input shapes/dtypes "
294-
"and input tensor caching is enabled.\n"
295-
"To ensure fresh, compatible inputs are generated for every trial "
296-
"you can disable caching by setting:\n"
297-
" `cache_input_tensors=False`\n"
298-
"within your `.set_compile_args(...)` call.\n")
299-
self.jit_input_tensors = jit_input_tensors
300-
self.jit_input_tensors = jit_input_tensors
282+
if cache_input_tensors and self.jit_input_tensors is not None:
283+
jit_input_tensors = self.jit_input_tensors
301284
else:
302-
self.jit_input_tensors = jit_input_tensors_supply()
285+
jit_input_tensors_supply = get_input_tensors_supply(
286+
supply_prog, profiler, with_output=False)
287+
jit_input_tensors = jit_input_tensors_supply()
303288

304289
if (not skip_check) and (ref_prog is not None):
305290
if manual_check_prog is not None:
306291
profiler.manual_assert_close(
307292
ref_prog,
308-
input_tensors=self.jit_input_tensors,
293+
input_tensors=jit_input_tensors,
309294
manual_check_prog=manual_check_prog)
310295
else:
311296
profiler.assert_allclose(
312297
ref_prog,
313-
input_tensors=self.jit_input_tensors,
298+
input_tensors=jit_input_tensors,
314299
rtol=rtol,
315300
atol=atol,
316301
max_mismatched_ratio=max_mismatched_ratio)
317-
latency = profiler.do_bench(
318-
warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
302+
latency = profiler.do_bench(warmup=warmup, rep=rep, input_tensors=jit_input_tensors)
319303
if self.ref_latency_cache is None and ref_prog is not None:
304+
ref_input_tensors_supply = get_input_tensors_supply(
305+
supply_prog, profiler, with_output=False)
320306
self.ref_input_tensors = ref_input_tensors_supply()
321307
self.ref_latency_cache = profiler.do_bench(
322308
ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)
@@ -369,6 +355,14 @@ def device_wrapper(func, device, *config_arg):
369355
continue
370356

371357
ref_latency = None
358+
if results_with_configs[0][0].cache_input_tensors:
359+
supply_prog = results_with_configs[0][0].supply_prog
360+
supply_type = results_with_configs[0][0].supply_type
361+
profiler = results_with_configs[0][0].kernel.get_profiler(
362+
tensor_supply_type=supply_type)
363+
jit_input_tensors_supply = get_input_tensors_supply(
364+
supply_prog, profiler, with_output=False)
365+
self.jit_input_tensors = jit_input_tensors_supply()
372366
progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations")
373367
for i in progress_bar:
374368
jit_context, config = results_with_configs[i]

0 commit comments

Comments
 (0)