Commit f1dc96c ("finish")

1 parent ed15baa

File tree: 12 files changed, +698 −37 lines


flashinfer_bench/compile/builder.py

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@ def create_pkg_name(sol: Solution, prefix: str = "") -> str:
         h.update(src.path.encode())
         h.update(src.content.encode())
 
-    return prefix + s + "_" + h.hexdigest()[:4]
+    return prefix + s + "_" + h.hexdigest()[:6]
 
 
 class BuildError(RuntimeError):
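
Note on this hunk: widening the content-hash suffix from 4 to 6 hex characters grows the suffix space from 65,536 (16 bits) to about 16.8 million (24 bits) values, making package-name collisions between solutions far less likely. A minimal sketch of the naming scheme, assuming SHA-256 and a simple sanitization step for `s` (both are set up above the hunk shown, so they are assumptions here):

import hashlib

def create_pkg_name_sketch(name: str, sources: dict, prefix: str = "") -> str:
    # Assumed sanitization; the real `s` is computed before the lines shown above.
    s = "".join(c if c.isalnum() else "_" for c in name)
    h = hashlib.sha256()  # assumed digest; any stable hash fits the pattern
    for path, content in sources.items():
        h.update(path.encode())
        h.update(content.encode())
    # 6 hex chars = 24 bits of suffix, versus 16 bits with 4 chars.
    return prefix + s + "_" + h.hexdigest()[:6]

print(create_pkg_name_sketch("my-gemm", {"kernel.cu": "__global__ void k() {}"}))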
flashinfer_bench/compile/builders/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,5 +1,6 @@
 from .cuda_builder import CUDABuilder
 from .python_builder import PythonBuilder
 from .triton_builder import TritonBuilder
+from .tvm_ffi_builder import TVMFFIBuilder
 
-__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder"]
+__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder", "TVMFFIBuilder"]

flashinfer_bench/compile/builders/cuda_builder.py

Lines changed: 20 additions & 19 deletions

@@ -1,12 +1,13 @@
 from __future__ import annotations
 
+import logging
 import os
 import re
 import shutil
 import sys
 from importlib import resources
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 from flashinfer_bench.compile.builder import (
     Builder,
@@ -16,21 +17,14 @@
 )
 from flashinfer_bench.compile.runnable import Runnable
 from flashinfer_bench.data import Definition, Solution, SourceFile, SupportedLanguages
+from flashinfer_bench.utils import is_cuda_available
 
 CUDA_ALLOWED_EXTS = [".cu", ".cpp", ".cc", ".cxx", ".c"]
 
+logger = logging.getLogger(__name__)
 
-def _verify_cuda() -> bool:
-    try:
-        import torch
-        import torch.utils.cpp_extension
-
-        return torch.cuda.is_available()
-    except ImportError:
-        return False
 
-
-def _get_package_paths(pkg_name: str, lib_names: List[str] = None):
+def _get_package_paths(pkg_name: str, lib_names: Optional[List[str]] = None):
     include_path = None
     ldflags = []
 
@@ -64,8 +58,11 @@ def _get_package_paths(pkg_name: str, lib_names: List[str] = None):
         ldflags = [f"/LIBPATH:{lib_path}"] + lib_names
 
     except Exception:
-        # TODO(shanli): add logger to print warning
-        pass
+        logger.warning(
+            "Failed to discover resources for CUDA package '%s'; continuing without it.",
+            pkg_name,
+            exc_info=True,
+        )
 
     return include_path, ldflags
 
@@ -125,7 +122,7 @@ class CUDABuilder(Builder):
     @classmethod
     def _get_cuda_available(cls) -> bool:
         if cls._cuda_available is None:
-            cls._cuda_available = _verify_cuda()
+            cls._cuda_available = is_cuda_available()
         return cls._cuda_available
 
     def __init__(self) -> None:
@@ -142,16 +139,19 @@ def _make_key(self, solution: Solution) -> str:
         return f"cuda::{create_pkg_name(solution)}"
 
     def _make_closer(self):
-        # We keep build dirs for torch extension caching. The temp dirs can be cleaned by calling `clear_cache` on program exit.
+        # We keep build dirs for torch extension caching. The temp dirs can be cleaned by
+        # calling `clear_cache` on program exit.
         return lambda: None
 
     def _build(self, defn: Definition, sol: Solution) -> Runnable:
         # CUDA solutions must provide a C/CUDA symbol as entry point.
-        # If user prefer a Python wrapper, set language to `python` and ensure compilation and binding are properly handled.
+        # If the user prefers a Python wrapper, set language to `python` and ensure
+        # compilation and binding are properly handled.
         entry_file_extension = "." + sol.spec.entry_point.split("::")[0].split(".")[-1]
         if entry_file_extension not in CUDA_ALLOWED_EXTS:
             raise BuildError(
-                f"Entry file type not recognized. Must be one of {CUDA_ALLOWED_EXTS}, got {entry_file_extension}."
+                f"Entry file type not recognized. Must be one of {CUDA_ALLOWED_EXTS}, "
+                f"got {entry_file_extension}."
             )
 
         if not self._get_cuda_available():
@@ -184,7 +184,8 @@ def _build(self, defn: Definition, sol: Solution) -> Runnable:
             inc_path = self._extra_include_paths.get(dep)
             if not inc_path:
                 raise BuildError(
-                    f"{dep} is not available in the current environment but referenced by {sol.name}"
+                    f"{dep} is not available in the current environment but referenced "
+                    f"by {sol.name}"
                )
             extra_include_paths.append(inc_path)
             ldflags = self._extra_ldflags.get(dep)
@@ -218,7 +219,7 @@ def _kw_adapter(**kwargs):
             return fn(*args)
 
         meta = {
-            "definition": defn.name,
+            "definition": defn.name,
             "solution": sol.name,
             "language": "cuda",
             "name": name,
flashinfer_bench/compile/builders/tvm_ffi_builder.py (new file)

Lines changed: 143 additions & 0 deletions

@@ -0,0 +1,143 @@
+"""TVM-FFI based builder for CUDA kernels with automatic caching."""
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import tvm_ffi
+
+from flashinfer_bench.compile.builder import Builder, BuildError, create_pkg_name
+from flashinfer_bench.compile.runnable import Runnable, TVMFFIRunnable
+from flashinfer_bench.data import Definition, Solution, SupportedLanguages
+from flashinfer_bench.env import get_fib_cache_path
+
+logger = logging.getLogger(__name__)
+
+CUDA_EXTENSIONS = [".cu"]
+CPP_EXTENSIONS = [".cpp", ".cc", ".cxx", ".c"]
+
+
+class TVMFFIBuilder(Builder):
+    """Builder using TVM-FFI with automatic caching and multi-process sharing.
+
+    Build strategy:
+    1. Check whether the .so exists in the cache (multi-process safe)
+    2. If not, compile with tvm_ffi.cpp.build() into the cache
+    3. Load with tvm_ffi.load_module()
+
+    Benefits:
+    - Multi-process benchmark: only the first process compiles; others load from cache
+    - Cross-framework: the same .so works with PyTorch, JAX, CuPy (DLPack)
+    - No JIT/AOT distinction: smart caching handles both cases
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._extra_include_paths: Dict[str, str] = {}
+        self._extra_ldflags: Dict[str, List[str]] = {}
+
+    def can_build(self, sol: Solution) -> bool:
+        return sol.spec.language == SupportedLanguages.CUDA
+
+    def _make_key(self, solution: Solution) -> str:
+        return f"tvm_ffi_{create_pkg_name(solution)}"
+
+    def _make_closer(self):
+        return lambda: None
+
+    def _get_build_path(self, key: str) -> Path:
+        return get_fib_cache_path() / "tvm_ffi" / key
+
+    def _write_sources(self, path: Path, sol: Solution) -> Tuple[List[str], List[str]]:
+        """Extract and write all source files to the given path."""
+        path.mkdir(parents=True, exist_ok=True)
+        cpp_files: List[str] = []
+        cuda_files: List[str] = []
+        for src in sol.sources:
+            src_path = path / src.path
+            if src_path.is_dir():
+                raise BuildError(f"Source path is a directory: {src_path}")
+
+            src_path.write_text(src.content)
+
+            if str(src_path).endswith(tuple(CPP_EXTENSIONS)):
+                cpp_files.append(str(src_path))
+            elif str(src_path).endswith(tuple(CUDA_EXTENSIONS)):
+                cuda_files.append(str(src_path))
+
+        if len(cpp_files) == 0 and len(cuda_files) == 0:
+            raise BuildError("No sources found")
+        return cpp_files, cuda_files
+
+    def _get_language(self, cpp_files: List[str], cuda_files: List[str]) -> str:
+        return "cuda" if len(cuda_files) > 0 else "cpp"
+
+    def _get_entry_symbol(self, sol: Solution) -> str:
+        """Extract the function symbol from entry_point."""
+        entry_point = sol.spec.entry_point
+        if "::" not in entry_point:
+            raise BuildError(
+                f"Invalid entry_point format: {entry_point}. Expected 'file.cu::symbol'"
+            )
+        return entry_point.split("::")[-1]
+
+    def _make_runnable(
+        self, mod: tvm_ffi.Module, entry_symbol: str, defn: Definition, metadata: Dict[str, Any]
+    ) -> Runnable:
+        """Create a Runnable from a TVM-FFI module."""
+        try:
+            fn = getattr(mod, entry_symbol)
+        except AttributeError as e:
+            raise BuildError(f"Entry point '{entry_symbol}' not found in module") from e
+
+        # Create a keyword adapter to match the definition interface
+        arg_order = list(defn.inputs.keys()) + list(defn.outputs.keys())
+
+        def _kw_adapter(**kwargs):
+            args = [kwargs[name] for name in arg_order]
+            return fn(*args)
+
+        return TVMFFIRunnable(
+            fn=_kw_adapter, closer=self._make_closer(), meta=metadata, definition=defn
+        )
+
+    def _build(self, defn: Definition, sol: Solution) -> Runnable:
+        """Build with automatic caching - compile once, load from cache afterwards."""
+        key = self._make_key(sol)
+        build_path = self._get_build_path(key)
+        entry_symbol = self._get_entry_symbol(sol)
+        cpp_files, cuda_files = self._write_sources(build_path, sol)
+        language = self._get_language(cpp_files, cuda_files)
+        extra_include_paths = [str(build_path)]
+
+        try:
+            # Compile into the shared cache directory so other processes can reuse it
+            output_lib_path = tvm_ffi.cpp.build(
+                name=key,
+                cpp_files=cpp_files,
+                cuda_files=cuda_files,
+                extra_include_paths=extra_include_paths,
+                build_directory=build_path,
+            )
+        except Exception as e:
+            raise BuildError(f"TVM-FFI compilation failed for '{sol.name}': {e}") from e
+
+        # Load the compiled module
+        try:
+            mod = tvm_ffi.load_module(output_lib_path)
+        except Exception as e:
+            raise BuildError(f"Failed to load compiled module: {e}") from e
+
+        metadata = {
+            "definition": defn.name,
+            "solution": sol.name,
+            "language": language,
+            "binding": "tvm_ffi",
+            "key": key,
+            "symbol": entry_symbol,
+            "binary": output_lib_path,
+        }
+
+        return self._make_runnable(mod, entry_symbol, defn, metadata)
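
For orientation, a hedged sketch of the compile-or-reuse flow the docstring describes. The `tvm_ffi.cpp.build` and `tvm_ffi.load_module` calls mirror the ones in `_build` above; the explicit cache probe and the `.so` artifact name are assumptions taken from the docstring, since `_build` as written always invokes the compiler:

from pathlib import Path

import tvm_ffi

def build_or_load(key: str, cpp_files: list, cuda_files: list, build_dir: Path):
    build_dir.mkdir(parents=True, exist_ok=True)
    cached = build_dir / f"{key}.so"  # hypothetical artifact path
    if cached.exists():
        # Later processes skip compilation and just load the shared object.
        return tvm_ffi.load_module(str(cached))
    lib_path = tvm_ffi.cpp.build(
        name=key,
        cpp_files=cpp_files,
        cuda_files=cuda_files,
        extra_include_paths=[str(build_dir)],
        build_directory=build_dir,
    )
    return tvm_ffi.load_module(lib_path)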

flashinfer_bench/compile/registry.py

Lines changed: 5 additions & 3 deletions

@@ -52,11 +52,13 @@ def build_reference(self, defn: Definition) -> Runnable:
 def get_builder_registry() -> BuilderRegistry:
     global _registry
     if _registry is None:
-        from .builders import CUDABuilder, PythonBuilder, TritonBuilder
+        from .builders import CUDABuilder, PythonBuilder, TritonBuilder, TVMFFIBuilder
 
         py = PythonBuilder()
         triton = TritonBuilder(py_builder=py)
-        cuda = CUDABuilder()
+        tvm_ffi = TVMFFIBuilder()
+        cuda = CUDABuilder()  # Fallback for backward compatibility
 
-        _registry = BuilderRegistry((py, triton, cuda))
+        # Priority: Python > Triton > TVM-FFI > CUDA (pybind11)
+        _registry = BuilderRegistry((py, triton, tvm_ffi, cuda))
     return _registry
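
Why the ordering matters: `TVMFFIBuilder.can_build` and `CUDABuilder.can_build` both accept CUDA solutions, so the tuple order decides which builder sees them first. The registry's selection logic is not part of this diff; the sketch below only assumes first-match dispatch over the tuple, with later builders acting as fallbacks:

def select_builder(builders, sol):
    # Assumed first-match dispatch; actual BuilderRegistry internals are not shown.
    for builder in builders:
        if builder.can_build(sol):
            return builder
    raise RuntimeError(f"No builder can handle solution '{sol.name}'")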

flashinfer_bench/compile/runnable.py

Lines changed: 49 additions & 5 deletions

@@ -2,21 +2,21 @@
 
 from typing import Any, Callable, Dict, Optional
 
+from flashinfer_bench.data import Definition
+from flashinfer_bench.utils import dtype_str_to_torch_dtype
+
 
 class Runnable:
     def __init__(
-        self,
-        fn: Callable[..., Any],
-        closer: Callable[[], None],
-        meta: Optional[Dict[str, Any]] = None,
+        self, fn: Callable[..., Any], closer: Optional[Callable[[], None]], meta: Dict[str, Any]
     ) -> None:
         """A runnable callable with a required resource closer.
 
         closer: must be provided by the builder and be idempotent.
         """
         self._fn = fn
         self._closer: Optional[Callable[[], None]] = closer
-        self.meta: Dict[str, Any] = dict(meta or {})
+        self.meta: Dict[str, Any] = meta
 
     def __call__(self, **kwargs: Any) -> Any:
         """
@@ -36,3 +36,47 @@ def close(self) -> None:
             self._closer()
         finally:
             self._closer = None
+
+
+class TVMFFIRunnable(Runnable):
+    def __init__(
+        self,
+        fn: Callable[..., Any],
+        closer: Optional[Callable[[], None]],
+        meta: Dict[str, Any],
+        definition: Definition,
+    ) -> None:
+        super().__init__(fn, closer, meta)
+        self._definition = definition
+
+    def __call__(self, **kwargs: Any) -> Any:
+        import torch
+
+        # Allocate output tensors first
+        var_values = self._definition.get_var_values(
+            {name: list(tensor.shape) for name, tensor in kwargs.items()}
+        )
+        output_shapes = self._definition.get_output_shapes(var_values)
+        output_tensors: Dict[str, torch.Tensor] = {}
+        device = next(iter(kwargs.values())).device if len(kwargs) > 0 else "cpu"
+        for name, shape in output_shapes.items():
+            output_tensors[name] = torch.empty(
+                shape, dtype=dtype_str_to_torch_dtype(self._definition.outputs[name].dtype)
+            ).to(device)
+
+        self.call_dest(**kwargs, **output_tensors)
+
+        results = list(output_tensors.values())
+        if len(results) == 1:
+            return results[0]
+        return results
+
+    def call_dest(self, **kwargs: Any) -> None:
+        """Call the underlying function with destination-passing style."""
+        self._fn(**kwargs)
+
+    def close(self) -> None:
+        if self._closer:
+            self._closer()
+        self._closer = None
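
A hedged sketch of the destination-passing contract that `TVMFFIRunnable` wraps: the compiled symbol writes into preallocated outputs rather than returning new tensors, and `__call__` allocates those outputs from the `Definition`'s shapes before delegating to `call_dest`. The kernel stand-in below is illustrative, not part of this diff:

import torch

def vector_add(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> None:
    c.copy_(a + b)  # stand-in for a compiled kernel filling its destination

# What __call__ does by hand: allocate outputs, call in dest style, return them.
a, b = torch.randn(4), torch.randn(4)
c = torch.empty_like(a)  # allocated from the definition's output shapes
vector_add(a, b, c)      # equivalent to call_dest(**inputs, **outputs)
assert torch.allclose(c, a + b)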
