diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py index 21485ed1f..249dd0dff 100644 --- a/iree/turbine/kernel/wave/wave.py +++ b/iree/turbine/kernel/wave/wave.py @@ -10,6 +10,7 @@ from ..compiler import builder, dispatch_codegen, kernel_codegen, host_codegen from ..compiler.ir import Context, Operation +from ..compiler.builder import ModuleBuilder from .codegen import WaveEmitter from .constraints import ( Constraint, @@ -288,8 +289,24 @@ def _trace_and_get_kernel_signature( if kwargs.get("canonicalize", False): canonicalize_module(mb.module_op) + if kwargs.get("add_host_codegen_call", False): + self.add_host_codegen_call(mb, exe, kernel_sig, entrypoint_name, kwargs) + return mb, graph, exe, kernel_sig, entrypoint_name + def add_host_codegen_call( + self, + mb: ModuleBuilder, + exe: dispatch_codegen.StreamExecutable, + kernel_sig: kernel_codegen.KernelSignature, + entrypoint_name: str, + kwargs, + ): + dynamic_symbols = kwargs.get("dynamic_symbols", []) + host_codegen.isolated_test_call( + mb, exe, kernel_sig, entrypoint_name, dynamic_symbols + ) + def test_execute(self, args, kwargs): ( mb, @@ -303,10 +320,7 @@ def test_execute(self, args, kwargs): run_bench = kwargs.get("run_bench", False) if run or run_bench: # TODO: cache compiled code - dynamic_symbols = kwargs.get("dynamic_symbols", []) - host_codegen.isolated_test_call( - mb, exe, kernel_sig, entrypoint_name, dynamic_symbols - ) + self.add_host_codegen_call(mb, exe, kernel_sig, entrypoint_name, kwargs) asm = mb.module_op.get_asm() kernel_inputs = []