Skip to content

Commit 721baed

Browse files
authored
[Bugfix] Fix autotune cache (#1315)
1 parent 470eb74 commit 721baed

File tree

1 file changed

+153
-45
lines changed

1 file changed

+153
-45
lines changed

tilelang/autotuner/param.py

Lines changed: 153 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,25 @@
1313
from tilelang.jit import JITKernel
1414
import cloudpickle
1515
import os
16-
import shutil
1716
from tilelang.engine.param import KernelParam
1817
from tilelang import logger
1918
import json
2019
import hashlib
20+
import uuid
21+
from tilelang import env
22+
from tvm.runtime import Executable
2123

2224
BEST_CONFIG_PATH = "best_config.json"
2325
FUNCTION_PATH = "function.pkl"
2426
LATENCY_PATH = "latency.json"
25-
KERNEL_PATH = "kernel.cu"
26-
WRAPPED_KERNEL_PATH = "wrapped_kernel.cu"
27+
28+
# Align file names with cache/kernel_cache.py
29+
DEVICE_KERNEL_PATH = "device_kernel.cu"
30+
HOST_KERNEL_PATH = "host_kernel.cu"
31+
EXECUTABLE_PATH = "executable.so"
2732
KERNEL_LIB_PATH = "kernel_lib.so"
33+
KERNEL_CUBIN_PATH = "kernel.cubin"
34+
KERNEL_PY_PATH = "kernel.py"
2835
PARAMS_PATH = "params.pkl"
2936

3037

@@ -143,6 +150,31 @@ class AutotuneResult:
143150
func: Callable | None = None
144151
kernel: Callable | None = None
145152

153+
@staticmethod
def _load_binary(path: str):
    """Read the file at ``path`` in binary mode and return its full contents.

    Args:
        path: Location of the file to read.

    Returns:
        The bytes read from the file.
    """
    with open(path, "rb") as handle:
        return handle.read()
158+
159+
@staticmethod
def _safe_write_file(path: str, mode: str, operation: Callable[[Any], None]):
    """Write a file by staging it in a temporary file and moving it into place.

    Args:
        path: Final destination of the file.
        mode: Mode handed to ``open`` for the temporary file (e.g. ``"w"``
            or ``"wb"``).
        operation: Callback that receives the open temporary file object
            and writes the content into it.

    Raises:
        Any exception raised by ``open``, ``operation`` or ``os.replace``.
        The temporary file is removed before the exception propagates.
    """
    # Stage the content under a unique name (pid + uuid4) in the tilelang
    # temporary directory.
    # NOTE(review): os.replace may fail if the temporary directory and
    # ``path`` are on different filesystems - confirm TILELANG_TMP_DIR
    # shares a filesystem with the cache directory.
    tmp_dir = env.TILELANG_TMP_DIR
    os.makedirs(tmp_dir, exist_ok=True)
    temp_path = os.path.join(tmp_dir, f"{os.getpid()}_{uuid.uuid4()}")
    try:
        with open(temp_path, mode) as temp_file:
            operation(temp_file)
        # Use atomic replace, so other processes cannot see a partial write
        os.replace(temp_path, path)
    except BaseException:
        # Best-effort cleanup so a failed write does not leave an orphaned
        # temporary file behind; the original error is re-raised below.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise
169+
170+
@staticmethod
def _safe_write_executable(executable: Executable, path: str):
    """Export ``executable`` to ``path`` through a temporary ``.so`` file.

    The library is first exported under a unique name (pid + uuid4) in the
    tilelang temporary directory and then moved over ``path`` with
    ``os.replace``.

    Args:
        executable: Object whose ``export_library`` method writes the
            library to a given file path.
        path: Final destination of the exported library.

    Raises:
        Any exception raised by ``export_library`` or ``os.replace``. The
        temporary file is removed before the exception propagates.
    """
    tmp_dir = env.TILELANG_TMP_DIR
    os.makedirs(tmp_dir, exist_ok=True)
    temp_path = os.path.join(tmp_dir, f"{os.getpid()}_{uuid.uuid4()}.so")
    try:
        executable.export_library(temp_path)
        # NOTE(review): os.replace may fail if the temporary directory and
        # ``path`` are on different filesystems - confirm the configuration.
        os.replace(temp_path, path)
    except BaseException:
        # Best-effort cleanup so a failed export does not leave an orphaned
        # temporary library behind; the original error is re-raised below.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise
177+
146178
def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: bool = False):
147179
"""
148180
Persists a compiled kernel to disk cache.
@@ -161,34 +193,68 @@ def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: boo
161193
"""
162194
os.makedirs(cache_path, exist_ok=True) # Ensure directory exists
163195

164-
# Save kernel source code
196+
# Save device kernel source code
165197
try:
166-
kernel_path = os.path.join(cache_path, KERNEL_PATH)
198+
device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH)
167199
if verbose:
168-
logger.debug(f"Saving kernel source code to file: {kernel_path}")
200+
logger.debug(f"Saving kernel source code to file: {device_kernel_path}")
169201
if kernel.kernel_source is not None:
170-
with open(kernel_path, "w") as f:
171-
f.write(kernel.kernel_source)
202+
self._safe_write_file(device_kernel_path, "w",
203+
lambda f: f.write(kernel.kernel_source))
172204
except Exception as e:
173205
logger.error(f"Error saving kernel source code to disk: {e}")
174206

175-
# Save wrapped kernel source code
207+
# Save host kernel source code (wrapped)
176208
try:
177-
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
209+
host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH)
178210
if verbose:
179-
logger.debug(f"Saving wrapped kernel source code to file: {wrapped_kernel_path}")
180-
with open(wrapped_kernel_path, "w") as f:
181-
f.write(kernel.get_kernel_source())
211+
logger.debug(f"Saving wrapped kernel source code to file: {host_kernel_path}")
212+
# Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel
213+
if kernel.execution_backend == "tvm_ffi":
214+
self._safe_write_file(host_kernel_path, "w",
215+
lambda f: f.write(kernel.adapter.get_host_source()))
216+
else:
217+
self._safe_write_file(host_kernel_path, "w",
218+
lambda f: f.write(kernel.adapter.get_kernel_source()))
182219
except Exception as e:
183220
logger.error(f"Error saving wrapped kernel source code to disk: {e}")
184221

185-
# Save kernel library
222+
# Save kernel library (backend-specific)
186223
try:
187-
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
188-
src_lib_path = kernel.adapter.libpath
189-
if verbose:
190-
logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
191-
shutil.copy(src_lib_path, kernel_lib_path)
224+
if kernel.execution_backend == "nvrtc":
225+
kernel_lib_file = KERNEL_CUBIN_PATH
226+
elif kernel.execution_backend == "tvm_ffi":
227+
kernel_lib_file = EXECUTABLE_PATH
228+
else:
229+
kernel_lib_file = KERNEL_LIB_PATH
230+
231+
kernel_lib_path = os.path.join(cache_path, kernel_lib_file)
232+
233+
if kernel.execution_backend == "nvrtc":
234+
# Save cubin and python helper file
235+
src_lib_path = kernel.adapter.libpath
236+
kernel_py_path = os.path.join(cache_path, KERNEL_PY_PATH)
237+
py_src_path = src_lib_path.replace(".cubin", ".py")
238+
if verbose:
239+
logger.debug(f"Saving kernel nvrtc python code to file: {kernel_py_path}")
240+
self._safe_write_file(kernel_py_path, "wb",
241+
lambda f: f.write(self._load_binary(py_src_path)))
242+
if verbose:
243+
logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
244+
self._safe_write_file(kernel_lib_path, "wb",
245+
lambda f: f.write(self._load_binary(src_lib_path)))
246+
elif kernel.execution_backend == "tvm_ffi":
247+
executable = kernel.adapter.executable
248+
if verbose:
249+
logger.debug(f"Saving kernel executable to file: {kernel_lib_path}")
250+
self._safe_write_executable(executable, kernel_lib_path)
251+
else:
252+
src_lib_path = kernel.adapter.libpath
253+
if verbose:
254+
logger.debug(f"Saving kernel library to file: {kernel_lib_path}")
255+
self._safe_write_file(kernel_lib_path, "wb",
256+
lambda f: f.write(self._load_binary(src_lib_path)))
257+
192258
except Exception as e:
193259
logger.error(f"Error saving kernel library to disk: {e}")
194260

@@ -197,8 +263,7 @@ def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: boo
197263
params_path = os.path.join(cache_path, PARAMS_PATH)
198264
if verbose:
199265
logger.debug(f"Saving kernel parameters to disk: {params_path}")
200-
with open(params_path, "wb") as f:
201-
cloudpickle.dump(kernel.params, f)
266+
self._safe_write_file(params_path, "wb", lambda f: cloudpickle.dump(kernel.params, f))
202267
except Exception as e:
203268
logger.error(f"Error saving kernel parameters to disk: {e}")
204269

@@ -210,6 +275,7 @@ def _load_kernel_from_disk(
210275
out_idx: list[int] | int | None = None,
211276
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi",
212277
pass_configs: dict = None,
278+
compile_flags: list[str] | str | None = None,
213279
func: Callable = None,
214280
verbose: bool = False,
215281
) -> JITKernel:
@@ -233,41 +299,66 @@ def _load_kernel_from_disk(
233299
if not os.path.exists(cache_path):
234300
return None
235301

236-
kernel_global_source: str | None = None
302+
# Resolve backend to pick correct file names
303+
if execution_backend == "nvrtc":
304+
kernel_lib_file = KERNEL_CUBIN_PATH
305+
elif execution_backend == "tvm_ffi":
306+
kernel_lib_file = EXECUTABLE_PATH
307+
else:
308+
kernel_lib_file = KERNEL_LIB_PATH
309+
310+
device_kernel_path = os.path.join(cache_path, DEVICE_KERNEL_PATH)
311+
host_kernel_path = os.path.join(cache_path, HOST_KERNEL_PATH)
312+
kernel_lib_path = os.path.join(cache_path, kernel_lib_file)
313+
params_path = os.path.join(cache_path, PARAMS_PATH)
314+
315+
if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]):
316+
return None
317+
318+
device_kernel_source: str | None = None
319+
host_kernel_source: str | None = None
237320
kernel_params: list[KernelParam] | None = None
238321

322+
# Load optional device kernel source
239323
try:
240-
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
241324
if verbose:
242-
logger.debug(f"Loading wrapped kernel source code from file: {wrapped_kernel_path}")
243-
with open(wrapped_kernel_path) as f:
244-
kernel_global_source = f.read()
325+
logger.debug(f"Loading kernel source code from file: {device_kernel_path}")
326+
with open(device_kernel_path) as f:
327+
device_kernel_source = f.read()
245328
except Exception as e:
246-
logger.error(f"Error loading wrapped kernel source code from disk: {e}")
329+
logger.error(f"Error loading kernel source code from disk: {e}")
247330

248-
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
331+
# Load optional host kernel source
332+
try:
333+
if verbose:
334+
logger.debug(f"Loading wrapped kernel source code from file: {host_kernel_path}")
335+
with open(host_kernel_path) as f:
336+
host_kernel_source = f.read()
337+
except Exception as e:
338+
logger.error(f"Error loading host kernel source code from disk: {e}")
249339

250340
# Load kernel parameters
251341
try:
252-
params_path = os.path.join(cache_path, PARAMS_PATH)
253342
if verbose:
254343
logger.debug(f"Loading kernel parameters from file: {params_path}")
255344
with open(params_path, "rb") as f:
256345
kernel_params = cloudpickle.load(f)
257346
except Exception as e:
258347
logger.error(f"Error loading kernel parameters from disk: {e}")
259348

260-
if kernel_global_source and kernel_params:
349+
if host_kernel_source and device_kernel_source and kernel_params:
261350
return JITKernel.from_database(
262351
func=func,
263-
kernel_global_source=kernel_global_source,
352+
host_kernel_source=host_kernel_source,
353+
device_kernel_source=device_kernel_source,
264354
kernel_lib_path=kernel_lib_path,
265355
params=kernel_params,
266356
target=target,
267357
target_host=target_host,
268358
out_idx=out_idx,
269359
execution_backend=execution_backend,
270360
pass_configs=pass_configs,
361+
compile_flags=compile_flags,
271362
)
272363
else:
273364
return None
@@ -276,26 +367,29 @@ def save_to_disk(self, path: Path, verbose: bool = False):
276367
if not os.path.exists(path):
277368
os.makedirs(path)
278369

279-
# save best config
370+
# save best config (atomic)
280371
if verbose:
281372
logger.debug(f"Saving best config to file: {path / BEST_CONFIG_PATH}")
282-
with open(path / BEST_CONFIG_PATH, "w") as f:
283-
json.dump(self.config, f)
373+
self._safe_write_file(
374+
str(path / BEST_CONFIG_PATH), "w", lambda f: json.dump(self.config, f))
284375

285-
# save function
376+
# save function (atomic)
286377
if verbose:
287378
logger.debug(f"Saving function to file: {path / FUNCTION_PATH}")
288-
with open(path / FUNCTION_PATH, "wb") as f:
289-
cloudpickle.dump(self.func, f)
379+
self._safe_write_file(
380+
str(path / FUNCTION_PATH), "wb", lambda f: cloudpickle.dump(self.func, f))
290381

291-
# save ref latency
382+
# save ref latency (atomic)
292383
if verbose:
293384
logger.debug(f"Saving latency to file: {path / LATENCY_PATH}")
294-
with open(path / LATENCY_PATH, "w") as f:
295-
json.dump({
385+
self._safe_write_file(
386+
str(path / LATENCY_PATH),
387+
"w",
388+
lambda f: json.dump({
296389
"latency": self.latency,
297390
"ref_latency": self.ref_latency,
298-
}, f)
391+
}, f),
392+
)
299393

300394
# save kernel
301395
self._save_kernel_to_disk(path, self.kernel)
@@ -306,6 +400,13 @@ def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> AutotuneResult
306400
return None
307401

308402
verbose = compile_args.verbose
403+
# Normalize target and resolve execution backend for loading
404+
from tilelang.utils.target import determine_target as _determine_target
405+
from tilelang.jit.execution_backend import resolve_execution_backend
406+
norm_target = Target(_determine_target(compile_args.target)) if isinstance(
407+
compile_args.target, str) else compile_args.target
408+
requested_backend = compile_args.execution_backend
409+
resolved_backend = resolve_execution_backend(requested_backend, norm_target)
309410
# load best config
310411
if verbose:
311412
logger.debug(f"Loading best config from file: {path / BEST_CONFIG_PATH}")
@@ -325,10 +426,17 @@ def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> AutotuneResult
325426
latency = json.load(f)
326427
latency, ref_latency = latency["latency"], latency["ref_latency"]
327428

328-
kernel = cls._load_kernel_from_disk(cls, path, compile_args.target,
329-
compile_args.target_host, compile_args.out_idx,
330-
compile_args.execution_backend,
331-
compile_args.pass_configs, func)
429+
kernel = cls._load_kernel_from_disk(
430+
cls,
431+
path,
432+
norm_target,
433+
compile_args.target_host,
434+
compile_args.out_idx,
435+
resolved_backend,
436+
compile_args.pass_configs,
437+
None, # compile_flags not tracked here
438+
func,
439+
)
332440
if kernel is None:
333441
return None
334442
kernel.update_tuner_result(

0 commit comments

Comments
 (0)