Skip to content

Commit 8770b6e

Browse files
authored
Revert "[Bugfix] Use AutoTune cache_input_tensors properly (#483)" (#488)
This reverts commit 1b95576.
1 parent 1b95576 commit 8770b6e

File tree

1 file changed

+37
-31
lines changed

1 file changed

+37
-31
lines changed

tilelang/autotuner/__init__.py

Lines changed: 37 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -251,19 +251,6 @@ def _compile(*config_arg):
251251
if self.jit_compile is None:
252252
self.jit_compile = _compile
253253

254-
# Factory functions for generating input tensors.
255-
# This encapsulates the logic of using either a custom supply program (`supply_prog`)
256-
# or the default profiler input generation (`profiler._get_inputs`).
257-
def get_input_tensors_supply(supply_prog, profiler, with_output: bool):
258-
259-
def func():
260-
if supply_prog is not None:
261-
return supply_prog(profiler._get_params(with_output=with_output))
262-
else:
263-
return profiler._get_inputs(with_output=with_output)
264-
265-
return func
266-
267254
def target_fn(jit_context: JITContext):
268255
# Unpack the context
269256
kernel = jit_context.kernel
@@ -279,30 +266,57 @@ def target_fn(jit_context: JITContext):
279266

280267
profiler = kernel.get_profiler(tensor_supply_type=supply_type)
281268

282-
if cache_input_tensors and self.jit_input_tensors is not None:
283-
jit_input_tensors = self.jit_input_tensors
284-
else:
285-
jit_input_tensors_supply = get_input_tensors_supply(
286-
supply_prog, profiler, with_output=False)
269+
# Factory functions for generating input tensors.
270+
# This encapsulates the logic of using either a custom supply program (`supply_prog`)
271+
# or the default profiler input generation (`profiler._get_inputs`).
272+
def get_input_tensors_supply(with_output: bool):
273+
274+
def func():
275+
if supply_prog is not None:
276+
return supply_prog(profiler._get_params(with_output=with_output))
277+
else:
278+
return profiler._get_inputs(with_output=with_output)
279+
280+
return func
281+
282+
jit_input_tensors_supply = get_input_tensors_supply(with_output=False)
283+
ref_input_tensors_supply = get_input_tensors_supply(with_output=False)
284+
285+
if cache_input_tensors:
287286
jit_input_tensors = jit_input_tensors_supply()
287+
if self.jit_input_tensors is not None:
288+
if not check_tensor_list_compatibility(self.jit_input_tensors,
289+
jit_input_tensors):
290+
logger.warning(
291+
"Incompatible input tensor properties detected between cached tensors and "
292+
"tensors regenerated for the current configuration trial. "
293+
"This can happen if different tuning configurations require different input shapes/dtypes "
294+
"and input tensor caching is enabled.\n"
295+
"To ensure fresh, compatible inputs are generated for every trial "
296+
"you can disable caching by setting:\n"
297+
" `cache_input_tensors=False`\n"
298+
"within your `.set_compile_args(...)` call.\n")
299+
self.jit_input_tensors = jit_input_tensors
300+
self.jit_input_tensors = jit_input_tensors
301+
else:
302+
self.jit_input_tensors = jit_input_tensors_supply()
288303

289304
if (not skip_check) and (ref_prog is not None):
290305
if manual_check_prog is not None:
291306
profiler.manual_assert_close(
292307
ref_prog,
293-
input_tensors=jit_input_tensors,
308+
input_tensors=self.jit_input_tensors,
294309
manual_check_prog=manual_check_prog)
295310
else:
296311
profiler.assert_allclose(
297312
ref_prog,
298-
input_tensors=jit_input_tensors,
313+
input_tensors=self.jit_input_tensors,
299314
rtol=rtol,
300315
atol=atol,
301316
max_mismatched_ratio=max_mismatched_ratio)
302-
latency = profiler.do_bench(warmup=warmup, rep=rep, input_tensors=jit_input_tensors)
317+
latency = profiler.do_bench(
318+
warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors)
303319
if self.ref_latency_cache is None and ref_prog is not None:
304-
ref_input_tensors_supply = get_input_tensors_supply(
305-
supply_prog, profiler, with_output=False)
306320
self.ref_input_tensors = ref_input_tensors_supply()
307321
self.ref_latency_cache = profiler.do_bench(
308322
ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors)
@@ -355,14 +369,6 @@ def device_wrapper(func, device, *config_arg):
355369
continue
356370

357371
ref_latency = None
358-
if results_with_configs[0][0].cache_input_tensors:
359-
supply_prog = results_with_configs[0][0].supply_prog
360-
supply_type = results_with_configs[0][0].supply_type
361-
profiler = results_with_configs[0][0].kernel.get_profiler(
362-
tensor_supply_type=supply_type)
363-
jit_input_tensors_supply = get_input_tensors_supply(
364-
supply_prog, profiler, with_output=False)
365-
self.jit_input_tensors = jit_input_tensors_supply()
366372
progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations")
367373
for i in progress_bar:
368374
jit_context, config = results_with_configs[i]

0 commit comments

Comments (0)