1313from tilelang .jit import JITKernel
1414import cloudpickle
1515import os
16- import shutil
1716from tilelang .engine .param import KernelParam
1817from tilelang import logger
1918import json
2019import hashlib
20+ import uuid
21+ from tilelang import env
22+ from tvm .runtime import Executable
2123
# File names of the autotuner result artifacts written by save_to_disk.
BEST_CONFIG_PATH = "best_config.json"  # JSON dump of the best config
FUNCTION_PATH = "function.pkl"  # cloudpickled function
LATENCY_PATH = "latency.json"  # JSON with "latency" and "ref_latency"
25- KERNEL_PATH = "kernel.cu"
26- WRAPPED_KERNEL_PATH = "wrapped_kernel.cu"
27+
# Kernel artifact file names, kept aligned with cache/kernel_cache.py.
DEVICE_KERNEL_PATH = "device_kernel.cu"  # device kernel source
HOST_KERNEL_PATH = "host_kernel.cu"  # host (wrapped) kernel source
EXECUTABLE_PATH = "executable.so"  # library file for the "tvm_ffi" backend
KERNEL_LIB_PATH = "kernel_lib.so"  # library file for the remaining backends
KERNEL_CUBIN_PATH = "kernel.cubin"  # library file for the "nvrtc" backend
KERNEL_PY_PATH = "kernel.py"  # python helper saved next to the nvrtc cubin
PARAMS_PATH = "params.pkl"  # cloudpickled kernel parameters
2936
3037
@@ -143,6 +150,31 @@ class AutotuneResult:
143150 func : Callable | None = None
144151 kernel : Callable | None = None
145152
153+ @staticmethod
154+ def _load_binary (path : str ):
155+ with open (path , "rb" ) as file :
156+ binary = file .read ()
157+ return binary
158+
@staticmethod
def _safe_write_file(path: str, mode: str, operation: Callable[[Any], None]):
    """Write a file via a temporary file and move it into place.

    The content is first written to a uniquely named file inside
    ``env.TILELANG_TMP_DIR`` and then moved over ``path`` with
    ``os.replace`` so that readers of ``path`` do not observe a
    partially written file.

    Args:
        path: Final destination of the file.
        mode: Mode passed to ``open`` for the temporary file (e.g. "w", "wb").
        operation: Callable that receives the open temporary file object
            and writes the desired content to it.

    Raises:
        Whatever ``operation``, ``open`` or ``os.replace`` raises; the
        temporary file is removed before the exception propagates.
    """
    tmp_dir = env.TILELANG_TMP_DIR
    os.makedirs(tmp_dir, exist_ok=True)
    # pid + uuid keeps the name unique across concurrent writers.
    temp_path = os.path.join(tmp_dir, f"{os.getpid()}_{uuid.uuid4()}")
    try:
        with open(temp_path, mode) as temp_file:
            operation(temp_file)
        # NOTE(review): os.replace needs the temp dir and ``path`` to be on
        # the same filesystem -- confirm env.TILELANG_TMP_DIR satisfies this.
        os.replace(temp_path, path)
    except BaseException:
        # Do not leave a stale temporary file behind when the write or the
        # move fails; cleanup is best-effort and must not mask the error.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise
169+
@staticmethod
def _safe_write_executable(executable: Executable, path: str):
    """Export an executable to a temporary ``.so`` and move it into place.

    The library is exported to a uniquely named file inside
    ``env.TILELANG_TMP_DIR`` and then moved over ``path`` with
    ``os.replace`` so that readers of ``path`` do not observe a
    partially written library.

    Args:
        executable: Object whose ``export_library`` method writes the library.
        path: Final destination of the exported library.

    Raises:
        Whatever ``export_library`` or ``os.replace`` raises; any temporary
        file that was created is removed before the exception propagates.
    """
    tmp_dir = env.TILELANG_TMP_DIR
    os.makedirs(tmp_dir, exist_ok=True)
    # pid + uuid keeps the name unique across concurrent writers.
    temp_path = os.path.join(tmp_dir, f"{os.getpid()}_{uuid.uuid4()}.so")
    try:
        executable.export_library(temp_path)
        # NOTE(review): os.replace needs the temp dir and ``path`` to be on
        # the same filesystem -- confirm env.TILELANG_TMP_DIR satisfies this.
        os.replace(temp_path, path)
    except BaseException:
        # Remove a partially exported library, if any; cleanup is
        # best-effort and must not mask the original error.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise
177+
146178 def _save_kernel_to_disk (self , cache_path : Path , kernel : JITKernel , verbose : bool = False ):
147179 """
148180 Persists a compiled kernel to disk cache.
@@ -161,34 +193,68 @@ def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: boo
161193 """
162194 os .makedirs (cache_path , exist_ok = True ) # Ensure directory exists
163195
164- # Save kernel source code
196+ # Save device kernel source code
165197 try :
166- kernel_path = os .path .join (cache_path , KERNEL_PATH )
198+ device_kernel_path = os .path .join (cache_path , DEVICE_KERNEL_PATH )
167199 if verbose :
168- logger .debug (f"Saving kernel source code to file: { kernel_path } " )
200+ logger .debug (f"Saving kernel source code to file: { device_kernel_path } " )
169201 if kernel .kernel_source is not None :
170- with open ( kernel_path , "w" ) as f :
171- f .write (kernel .kernel_source )
202+ self . _safe_write_file ( device_kernel_path , "w" ,
203+ lambda f : f .write (kernel .kernel_source ) )
172204 except Exception as e :
173205 logger .error (f"Error saving kernel source code to disk: { e } " )
174206
175- # Save wrapped kernel source code
207+ # Save host kernel source code (wrapped)
176208 try :
177- wrapped_kernel_path = os .path .join (cache_path , WRAPPED_KERNEL_PATH )
209+ host_kernel_path = os .path .join (cache_path , HOST_KERNEL_PATH )
178210 if verbose :
179- logger .debug (f"Saving wrapped kernel source code to file: { wrapped_kernel_path } " )
180- with open (wrapped_kernel_path , "w" ) as f :
181- f .write (kernel .get_kernel_source ())
211+ logger .debug (f"Saving wrapped kernel source code to file: { host_kernel_path } " )
212+ # Match kernel_cache behavior: use host source for tvm_ffi, otherwise wrapped kernel
213+ if kernel .execution_backend == "tvm_ffi" :
214+ self ._safe_write_file (host_kernel_path , "w" ,
215+ lambda f : f .write (kernel .adapter .get_host_source ()))
216+ else :
217+ self ._safe_write_file (host_kernel_path , "w" ,
218+ lambda f : f .write (kernel .adapter .get_kernel_source ()))
182219 except Exception as e :
183220 logger .error (f"Error saving wrapped kernel source code to disk: { e } " )
184221
185- # Save kernel library
222+ # Save kernel library (backend-specific)
186223 try :
187- kernel_lib_path = os .path .join (cache_path , KERNEL_LIB_PATH )
188- src_lib_path = kernel .adapter .libpath
189- if verbose :
190- logger .debug (f"Saving kernel library to file: { kernel_lib_path } " )
191- shutil .copy (src_lib_path , kernel_lib_path )
224+ if kernel .execution_backend == "nvrtc" :
225+ kernel_lib_file = KERNEL_CUBIN_PATH
226+ elif kernel .execution_backend == "tvm_ffi" :
227+ kernel_lib_file = EXECUTABLE_PATH
228+ else :
229+ kernel_lib_file = KERNEL_LIB_PATH
230+
231+ kernel_lib_path = os .path .join (cache_path , kernel_lib_file )
232+
233+ if kernel .execution_backend == "nvrtc" :
234+ # Save cubin and python helper file
235+ src_lib_path = kernel .adapter .libpath
236+ kernel_py_path = os .path .join (cache_path , KERNEL_PY_PATH )
237+ py_src_path = src_lib_path .replace (".cubin" , ".py" )
238+ if verbose :
239+ logger .debug (f"Saving kernel nvrtc python code to file: { kernel_py_path } " )
240+ self ._safe_write_file (kernel_py_path , "wb" ,
241+ lambda f : f .write (self ._load_binary (py_src_path )))
242+ if verbose :
243+ logger .debug (f"Saving kernel library to file: { kernel_lib_path } " )
244+ self ._safe_write_file (kernel_lib_path , "wb" ,
245+ lambda f : f .write (self ._load_binary (src_lib_path )))
246+ elif kernel .execution_backend == "tvm_ffi" :
247+ executable = kernel .adapter .executable
248+ if verbose :
249+ logger .debug (f"Saving kernel executable to file: { kernel_lib_path } " )
250+ self ._safe_write_executable (executable , kernel_lib_path )
251+ else :
252+ src_lib_path = kernel .adapter .libpath
253+ if verbose :
254+ logger .debug (f"Saving kernel library to file: { kernel_lib_path } " )
255+ self ._safe_write_file (kernel_lib_path , "wb" ,
256+ lambda f : f .write (self ._load_binary (src_lib_path )))
257+
192258 except Exception as e :
193259 logger .error (f"Error saving kernel library to disk: { e } " )
194260
@@ -197,8 +263,7 @@ def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: boo
197263 params_path = os .path .join (cache_path , PARAMS_PATH )
198264 if verbose :
199265 logger .debug (f"Saving kernel parameters to disk: { params_path } " )
200- with open (params_path , "wb" ) as f :
201- cloudpickle .dump (kernel .params , f )
266+ self ._safe_write_file (params_path , "wb" , lambda f : cloudpickle .dump (kernel .params , f ))
202267 except Exception as e :
203268 logger .error (f"Error saving kernel parameters to disk: { e } " )
204269
@@ -210,6 +275,7 @@ def _load_kernel_from_disk(
210275 out_idx : list [int ] | int | None = None ,
211276 execution_backend : Literal ["tvm_ffi" , "ctypes" , "cython" , "nvrtc" , "torch" ] = "tvm_ffi" ,
212277 pass_configs : dict = None ,
278+ compile_flags : list [str ] | str | None = None ,
213279 func : Callable = None ,
214280 verbose : bool = False ,
215281 ) -> JITKernel :
@@ -233,41 +299,66 @@ def _load_kernel_from_disk(
233299 if not os .path .exists (cache_path ):
234300 return None
235301
236- kernel_global_source : str | None = None
302+ # Resolve backend to pick correct file names
303+ if execution_backend == "nvrtc" :
304+ kernel_lib_file = KERNEL_CUBIN_PATH
305+ elif execution_backend == "tvm_ffi" :
306+ kernel_lib_file = EXECUTABLE_PATH
307+ else :
308+ kernel_lib_file = KERNEL_LIB_PATH
309+
310+ device_kernel_path = os .path .join (cache_path , DEVICE_KERNEL_PATH )
311+ host_kernel_path = os .path .join (cache_path , HOST_KERNEL_PATH )
312+ kernel_lib_path = os .path .join (cache_path , kernel_lib_file )
313+ params_path = os .path .join (cache_path , PARAMS_PATH )
314+
315+ if not all ([os .path .exists (file ) for file in (kernel_lib_path , params_path )]):
316+ return None
317+
318+ device_kernel_source : str | None = None
319+ host_kernel_source : str | None = None
237320 kernel_params : list [KernelParam ] | None = None
238321
322+ # Load optional device kernel source
239323 try :
240- wrapped_kernel_path = os .path .join (cache_path , WRAPPED_KERNEL_PATH )
241324 if verbose :
242- logger .debug (f"Loading wrapped kernel source code from file: { wrapped_kernel_path } " )
243- with open (wrapped_kernel_path ) as f :
244- kernel_global_source = f .read ()
325+ logger .debug (f"Loading kernel source code from file: { device_kernel_path } " )
326+ with open (device_kernel_path ) as f :
327+ device_kernel_source = f .read ()
245328 except Exception as e :
246- logger .error (f"Error loading wrapped kernel source code from disk: { e } " )
329+ logger .error (f"Error loading kernel source code from disk: { e } " )
247330
248- kernel_lib_path = os .path .join (cache_path , KERNEL_LIB_PATH )
331+ # Load optional host kernel source
332+ try :
333+ if verbose :
334+ logger .debug (f"Loading wrapped kernel source code from file: { host_kernel_path } " )
335+ with open (host_kernel_path ) as f :
336+ host_kernel_source = f .read ()
337+ except Exception as e :
338+ logger .error (f"Error loading host kernel source code from disk: { e } " )
249339
250340 # Load kernel parameters
251341 try :
252- params_path = os .path .join (cache_path , PARAMS_PATH )
253342 if verbose :
254343 logger .debug (f"Loading kernel parameters from file: { params_path } " )
255344 with open (params_path , "rb" ) as f :
256345 kernel_params = cloudpickle .load (f )
257346 except Exception as e :
258347 logger .error (f"Error loading kernel parameters from disk: { e } " )
259348
260- if kernel_global_source and kernel_params :
349+ if host_kernel_source and device_kernel_source and kernel_params :
261350 return JITKernel .from_database (
262351 func = func ,
263- kernel_global_source = kernel_global_source ,
352+ host_kernel_source = host_kernel_source ,
353+ device_kernel_source = device_kernel_source ,
264354 kernel_lib_path = kernel_lib_path ,
265355 params = kernel_params ,
266356 target = target ,
267357 target_host = target_host ,
268358 out_idx = out_idx ,
269359 execution_backend = execution_backend ,
270360 pass_configs = pass_configs ,
361+ compile_flags = compile_flags ,
271362 )
272363 else :
273364 return None
@@ -276,26 +367,29 @@ def save_to_disk(self, path: Path, verbose: bool = False):
276367 if not os .path .exists (path ):
277368 os .makedirs (path )
278369
279- # save best config
370+ # save best config (atomic)
280371 if verbose :
281372 logger .debug (f"Saving best config to file: { path / BEST_CONFIG_PATH } " )
282- with open ( path / BEST_CONFIG_PATH , "w" ) as f :
283- json .dump (self .config , f )
373+ self . _safe_write_file (
374+ str ( path / BEST_CONFIG_PATH ), "w" , lambda f : json .dump (self .config , f ) )
284375
285- # save function
376+ # save function (atomic)
286377 if verbose :
287378 logger .debug (f"Saving function to file: { path / FUNCTION_PATH } " )
288- with open ( path / FUNCTION_PATH , "wb" ) as f :
289- cloudpickle .dump (self .func , f )
379+ self . _safe_write_file (
380+ str ( path / FUNCTION_PATH ), "wb" , lambda f : cloudpickle .dump (self .func , f ) )
290381
291- # save ref latency
382+ # save ref latency (atomic)
292383 if verbose :
293384 logger .debug (f"Saving latency to file: { path / LATENCY_PATH } " )
294- with open (path / LATENCY_PATH , "w" ) as f :
295- json .dump ({
385+ self ._safe_write_file (
386+ str (path / LATENCY_PATH ),
387+ "w" ,
388+ lambda f : json .dump ({
296389 "latency" : self .latency ,
297390 "ref_latency" : self .ref_latency ,
298- }, f )
391+ }, f ),
392+ )
299393
300394 # save kernel
301395 self ._save_kernel_to_disk (path , self .kernel )
@@ -306,6 +400,13 @@ def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> AutotuneResult
306400 return None
307401
308402 verbose = compile_args .verbose
403+ # Normalize target and resolve execution backend for loading
404+ from tilelang .utils .target import determine_target as _determine_target
405+ from tilelang .jit .execution_backend import resolve_execution_backend
406+ norm_target = Target (_determine_target (compile_args .target )) if isinstance (
407+ compile_args .target , str ) else compile_args .target
408+ requested_backend = compile_args .execution_backend
409+ resolved_backend = resolve_execution_backend (requested_backend , norm_target )
309410 # load best config
310411 if verbose :
311412 logger .debug (f"Loading best config from file: { path / BEST_CONFIG_PATH } " )
@@ -325,10 +426,17 @@ def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> AutotuneResult
325426 latency = json .load (f )
326427 latency , ref_latency = latency ["latency" ], latency ["ref_latency" ]
327428
328- kernel = cls ._load_kernel_from_disk (cls , path , compile_args .target ,
329- compile_args .target_host , compile_args .out_idx ,
330- compile_args .execution_backend ,
331- compile_args .pass_configs , func )
429+ kernel = cls ._load_kernel_from_disk (
430+ cls ,
431+ path ,
432+ norm_target ,
433+ compile_args .target_host ,
434+ compile_args .out_idx ,
435+ resolved_backend ,
436+ compile_args .pass_configs ,
437+ None , # compile_flags not tracked here
438+ func ,
439+ )
332440 if kernel is None :
333441 return None
334442 kernel .update_tuner_result (
0 commit comments