diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py
index 725fd105add0..8aa4817a302d 100644
--- a/python/tvm/relax/backend/cuda/flashinfer.py
+++ b/python/tvm/relax/backend/cuda/flashinfer.py
@@ -21,6 +21,8 @@
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 from typing import List
+import hashlib
+import json
 
 import tvm
 from tvm.target import Target
@@ -37,7 +39,57 @@ def _compile_flashinfer_kernels(
         FLASHINFER_TVM_BINDING_DIR,
     )
 
-    # Todo(tvm-team): enable compilation cache
+    # ------------------------------------------------------------------------
+    # Caching flow: create the build directory and compute the cache hash.
+    # ------------------------------------------------------------------------
+    build_directory = FLASHINFER_JIT_DIR / name
+    build_directory.mkdir(parents=True, exist_ok=True)
+
+    def get_object_file_path(src: Path) -> Path:
+        obj_name = src.stem + ".o"
+        obj_path = build_directory / obj_name
+        return obj_path
+
+    # Compute the latest modification time among all source files
+    latest_src_mtime = max(src.stat().st_mtime for src in source_paths)
+
+    # Get the modification time of the current file (the one that contains this function)
+    current_file_mtime = Path(__file__).stat().st_mtime
+
+    # Build the hash key from metadata
+    hash_key = {
+        "name": name,
+        "target": str(target),
+        "latest_src_mtime": latest_src_mtime,
+        "current_file_mtime": current_file_mtime,
+    }
+
+    hash_value = hashlib.md5(
+        json.dumps(hash_key, sort_keys=True, indent=2).encode("utf-8")
+    ).hexdigest()
+
+    # Check whether a valid hash already exists in the build directory
+    hash_file = build_directory / "hash.md5"
+    if hash_file.exists():
+        with open(hash_file, "r") as f:
+            cached_hash = f.read().strip()
+        if cached_hash == hash_value:
+            # Check that all object files exist
+            object_files = []
+            all_exist = True
+            for src in source_paths:
+                obj_path = get_object_file_path(src)
+                if not obj_path.exists():
+                    all_exist = False
+                    break
+                object_files.append(obj_path)
+            if all_exist:
+                return object_files
+
+    # If we are here, the cache is missing or outdated.
+    # Write the new hash and compile the source files.
+    with open(hash_file, "w") as f:
+        f.write(hash_value)
+
     # ------------------------------------------------------------------------
     # 1) Common CUDA compile flags
     # ------------------------------------------------------------------------
@@ -82,17 +134,12 @@ def _compile_flashinfer_kernels(
         Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
     ] + CUTLASS_INCLUDE_DIRS
 
-    # Where object files will be placed
-    build_directory = FLASHINFER_JIT_DIR / name
-    build_directory.mkdir(parents=True, exist_ok=True)
-
     # ------------------------------------------------------------------------
     # 3) Function to compile a single source file
     # ------------------------------------------------------------------------
     def compile_single_source(src: Path) -> Path:
         # Derive the .o filename from the source filename
-        obj_name = src.stem + ".o"
-        obj_path = build_directory / obj_name
+        obj_path = get_object_file_path(src)
 
         # Construct the command
         cmd = (
@@ -202,7 +249,12 @@ def gen_flashinfer_prefill_module(
     )
     jit_args = {
         "backend": backend,
-        "uri": "batch_prefill_tvm",
+        "uri": f"batch_prefill_tvm_dtype_q_{dtype_q}_"
+        + f"dtype_kv_{dtype_kv}_"
+        + f"dtype_o_{dtype_o}_"
+        + f"qk_head_dim_{qk_head_dim}_"
+        + f"v_head_dim_{v_head_dim}_"
+        + f"enable_inline_rope_{enable_inline_rope}",
         "dtype_q": torch_dtype_q,
         "dtype_kv": torch_dtype_kv,
         "dtype_o": torch_dtype_o,
@@ -273,7 +325,11 @@ def gen_flashinfer_decode_module(
     torch_dtype_kv = getattr(torch, dtype_kv)
     torch_dtype_o = getattr(torch, dtype_o)
     jit_args = {
-        "uri": "batch_decode_tvm",
+        "uri": f"batch_decode_tvm_dtype_q_{dtype_q}_"
+        + f"dtype_kv_{dtype_kv}_"
+        + f"dtype_o_{dtype_o}_"
+        + f"qk_head_dim_{qk_head_dim}_"
+        + f"v_head_dim_{v_head_dim}",
         "dtype_q": torch_dtype_q,
         "dtype_kv": torch_dtype_kv,
         "dtype_o": torch_dtype_o,
@@ -343,7 +399,11 @@ def gen_flashinfer_mla_module(
     torch_dtype_kv = getattr(torch, dtype_kv)
     torch_dtype_o = getattr(torch, dtype_o)
     jit_args = {
-        "uri": "batch_mla_tvm",
+        "uri": f"batch_mla_tvm_dtype_q_{dtype_q}_"
+        + f"dtype_kv_{dtype_kv}_"
+        + f"dtype_o_{dtype_o}_"
+        + f"head_dim_ckv_{head_dim_ckv}_"
+        + f"head_dim_kpe_{head_dim_kpe}",
         "dtype_q": torch_dtype_q,
         "dtype_kv": torch_dtype_kv,
         "dtype_o": torch_dtype_o,
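Note: the cache-validity check in the hunk above reduces to the standalone sketch below. This is illustrative only; the helper name cache_is_valid and its signature are not part of the patch.

    import hashlib
    import json
    from pathlib import Path
    from typing import List

    def cache_is_valid(build_dir: Path, hash_value: str, objects: List[Path]) -> bool:
        # The cache hits only when the stored MD5 matches the freshly
        # computed one AND every expected object file is still on disk.
        hash_file = build_dir / "hash.md5"
        if not hash_file.exists():
            return False
        if hash_file.read_text().strip() != hash_value:
            return False
        return all(obj.exists() for obj in objects)

Because the hash key folds in the newest source mtime and the mtime of flashinfer.py itself, touching any kernel source or the binding script invalidates the cache on the next call. Also note the hash file is written before compilation starts: if a build fails partway, the next run still recompiles, because the object-file existence check fails even though the hash matches.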
diff --git a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
index 3e17f6436600..81acf5ee863d 100644
--- a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
+++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
@@ -169,7 +169,7 @@ def set_global_func(head_dim, dtype):
         mod = tvm.IRModule({"main": tir_func})
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
-        f = tvm.compile(mod["main"], target=target)
+        f = tvm.tir.build(mod["main"], target=target)
         builts.append(f.entry_func)
 
     (
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index 41743efeeea2..ffd345229200 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -155,7 +155,7 @@ def load_module(name: str, static_modules: List[tvm.runtime.Module]):
         mod = tvm.IRModule({"main": tir_func})
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
-        f = tvm.compile(mod["main"], target=target)
+        f = tvm.tir.build(mod["main"], target=target)
         builts.append(f.entry_func)
 
     (
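Note: these tests (and the MLA test below) all switch from tvm.compile to tvm.tir.build with the same pattern: schedule a single TIR PrimFunc with dlight's GPU fallback, then build it directly. A minimal sketch of that flow, assuming a CUDA-enabled TVM build; the add_one kernel is a stand-in for the tests' actual kernels:

    import tvm
    from tvm import dlight as dl
    from tvm.script import tir as T

    @T.prim_func
    def add_one(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
        for i in range(128):
            with T.block("add"):
                vi = T.axis.spatial(128, i)
                B[vi] = A[vi] + T.float32(1)

    target = tvm.target.Target("cuda")
    mod = tvm.IRModule({"main": add_one})
    with target:
        mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
    # tvm.tir.build replaces the previous tvm.compile call for single PrimFuncs
    f = tvm.tir.build(mod["main"], target=target)
    entry = f.entry_func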
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
index 84b50125ee15..2f726064a71b 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
@@ -168,7 +168,7 @@ def load_module(name: str, static_modules: List[tvm.runtime.Module]):
         mod = tvm.IRModule({"main": tir_func})
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
-        f = tvm.compile(mod["main"], target=target)
+        f = tvm.tir.build(mod["main"], target=target)
         builts.append(f.entry_func)
 
     (
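Note: with the parameterized URIs in the flashinfer.py hunks, kernels that differ only in dtype or head dimension no longer share a JIT cache entry. A quick illustration of the prefill URI with hypothetical values:

    dtype_q = dtype_kv = dtype_o = "float16"
    qk_head_dim, v_head_dim, enable_inline_rope = 128, 128, False
    uri = (
        f"batch_prefill_tvm_dtype_q_{dtype_q}_"
        + f"dtype_kv_{dtype_kv}_"
        + f"dtype_o_{dtype_o}_"
        + f"qk_head_dim_{qk_head_dim}_"
        + f"v_head_dim_{v_head_dim}_"
        + f"enable_inline_rope_{enable_inline_rope}"
    )
    print(uri)
    # Output (one line, wrapped here):
    # batch_prefill_tvm_dtype_q_float16_dtype_kv_float16_dtype_o_float16_
    # qk_head_dim_128_v_head_dim_128_enable_inline_rope_False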