2 changes: 1 addition & 1 deletion flashinfer_bench/compile/builder.py
@@ -42,7 +42,7 @@ def create_pkg_name(sol: Solution, prefix: str = "") -> str:
h.update(src.path.encode())
h.update(src.content.encode())

return prefix + s + "_" + h.hexdigest()[:4]
return prefix + s + "_" + h.hexdigest()[:6]


class BuildError(RuntimeError):
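For context on the suffix change above, widening the hash from four to six hex characters lowers the chance that two solutions with different sources map to the same package name. A minimal sketch of the naming scheme, assuming a SHA-256 digest seeded with each source's path and content as in the hunk (the sanitized name `s` and the actual hash constructor are outside this diff):

import hashlib

def make_pkg_name(prefix: str, s: str, sources) -> str:
    # Hash every source path and its content so any edit changes the suffix.
    h = hashlib.sha256()
    for path, content in sources:  # iterable of (path, content) pairs
        h.update(path.encode())
        h.update(content.encode())
    # Six hex chars (24 bits) collide far less often than the previous four (16 bits).
    return prefix + s + "_" + h.hexdigest()[:6]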
3 changes: 2 additions & 1 deletion flashinfer_bench/compile/builders/__init__.py
@@ -1,5 +1,6 @@
from .cuda_builder import CUDABuilder
from .python_builder import PythonBuilder
from .triton_builder import TritonBuilder
from .tvm_ffi_builder import TVMFFIBuilder

__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder"]
__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder", "TVMFFIBuilder"]
39 changes: 20 additions & 19 deletions flashinfer_bench/compile/builders/cuda_builder.py
@@ -1,12 +1,13 @@
from __future__ import annotations

import logging
import os
import re
import shutil
import sys
from importlib import resources
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Optional

from flashinfer_bench.compile.builder import (
Builder,
@@ -16,21 +17,14 @@
)
from flashinfer_bench.compile.runnable import Runnable
from flashinfer_bench.data import Definition, Solution, SourceFile, SupportedLanguages
from flashinfer_bench.utils import is_cuda_available

CUDA_ALLOWED_EXTS = [".cu", ".cpp", ".cc", ".cxx", ".c"]

logger = logging.getLogger(__name__)

def _verify_cuda() -> bool:
try:
import torch
import torch.utils.cpp_extension

return torch.cuda.is_available()
except ImportError:
return False


def _get_package_paths(pkg_name: str, lib_names: List[str] = None):
def _get_package_paths(pkg_name: str, lib_names: Optional[List[str]] = None):
include_path = None
ldflags = []

@@ -64,8 +58,11 @@ def _get_package_paths(pkg_name: str, lib_names: List[str] = None):
ldflags = [f"/LIBPATH:{lib_path}"] + lib_names

except Exception:
# TODO(shanli): add logger to print warning
pass
logger.warning(
"Failed to discover resources for CUDA package '%s'; continuing without it.",
pkg_name,
exc_info=True,
)

return include_path, ldflags

@@ -125,7 +122,7 @@ class CUDABuilder(Builder):
@classmethod
def _get_cuda_available(cls) -> bool:
if cls._cuda_available is None:
cls._cuda_available = _verify_cuda()
cls._cuda_available = is_cuda_available()
return cls._cuda_available

def __init__(self) -> None:
@@ -142,16 +139,19 @@ def _make_key(self, solution: Solution) -> str:
return f"cuda::{create_pkg_name(solution)}"

def _make_closer(self):
# We keep build dirs for torch extension caching. The temp dirs can be cleaned by calling `clear_cache` on program exit.
# We keep build dirs for torch extension caching. The temp dirs can be cleaned by
# calling `clear_cache` on program exit.
return lambda: None

def _build(self, defn: Definition, sol: Solution) -> Runnable:
# CUDA solutions must provide a C/CUDA symbol as entry point.
# If user prefer a Python wrapper, set language to `python` and ensure compilation and binding are properly handled.
# If user prefer a Python wrapper, set language to `python` and ensure compilation and
# binding are properly handled.
entry_file_extension = "." + sol.spec.entry_point.split("::")[0].split(".")[-1]
if entry_file_extension not in CUDA_ALLOWED_EXTS:
raise BuildError(
f"Entry file type not recognized. Must be one of {CUDA_ALLOWED_EXTS}, got {entry_file_extension}."
f"Entry file type not recognized. Must be one of {CUDA_ALLOWED_EXTS}, "
f"got {entry_file_extension}."
)

if not self._get_cuda_available():
@@ -184,7 +184,8 @@ def _build(self, defn: Definition, sol: Solution) -> Runnable:
inc_path = self._extra_include_paths.get(dep)
if not inc_path:
raise BuildError(
f"{dep} is not available in the current environment but referenced by {sol.name}"
f"{dep} is not available in the current environment but referenced "
f"by {sol.name}"
)
extra_include_paths.append(inc_path)
ldflags = self._extra_ldflags.get(dep)
@@ -218,7 +219,7 @@ def _kw_adapter(**kwargs):
return fn(*args)

meta = {
"definition": defn.name,
-"definition": defn.name,
⚠️ Potential issue | 🔴 Critical

Critical: Syntax error in meta dictionary.

Line 222 contains invalid Python syntax with a leading - character:

-"definition": defn.name,

If this is intended to remove the "definition" field from the metadata, the line should be deleted entirely. If it's a diff artifact, it must be corrected before merge.

Apply this fix to remove the syntax error:

     meta = {
-        -"definition": defn.name,
         "solution": sol.name,
         "language": "cuda",

"solution": sol.name,
"language": "cuda",
"name": name,
143 changes: 143 additions & 0 deletions flashinfer_bench/compile/builders/tvm_ffi_builder.py
@@ -0,0 +1,143 @@
"""TVM-FFI based builder for CUDA kernels with automatic caching."""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, Dict, List, Tuple

import tvm_ffi

from flashinfer_bench.compile.builder import Builder, BuildError, create_pkg_name
from flashinfer_bench.compile.runnable import Runnable, TVMFFIRunnable
from flashinfer_bench.data import Definition, Solution, SupportedLanguages
from flashinfer_bench.env import get_fib_cache_path

logger = logging.getLogger(__name__)

CUDA_EXTENSIONS = [".cu"]
CPP_EXTENSIONS = [".cpp", ".cc", ".cxx", ".c"]


class TVMFFIBuilder(Builder):
"""Builder using TVM-FFI with automatic caching and multi-process sharing.
Build strategy:
1. Check if .so exists in cache (multi-process safe)
2. If not, compile with tvm_ffi.cpp.build_inline() to cache
3. Load with tvm_ffi.load_module()
Benefits:
- Multi-process benchmark: Only first process compiles, others load from cache
- Cross-framework: Same .so works with PyTorch, JAX, CuPy (DLPack)
- No JIT/AOT distinction: Smart caching handles both cases
"""
Comment on lines +22 to +34
⚠️ Potential issue | 🟡 Minor

Update docstring to match actual API usage.

The docstring (line 27) mentions tvm_ffi.cpp.build_inline(), but the implementation uses tvm_ffi.cpp.build() (line 117). This inconsistency can confuse maintainers.

Update the docstring:

     Build strategy:
     1. Check if .so exists in cache (multi-process safe)
-    2. If not, compile with tvm_ffi.cpp.build_inline() to cache
+    2. If not, compile with tvm_ffi.cpp.build() to cache
     3. Load with tvm_ffi.load_module()


def __init__(self) -> None:
super().__init__()
self._extra_include_paths: Dict[str, str] = {}
self._extra_ldflags: Dict[str, List[str]] = {}

def can_build(self, sol: Solution) -> bool:
return sol.spec.language == SupportedLanguages.CUDA

def _make_key(self, solution: Solution) -> str:
return f"tvm_ffi_{create_pkg_name(solution)}"

def _make_closer(self):
return lambda: None

def _get_build_path(self, key: str) -> Path:
return get_fib_cache_path() / "tvm_ffi" / key

def _write_sources(self, path: Path, sol: Solution) -> Tuple[List[str], List[str]]:
"""Extract and write all source files to the given path."""
path.mkdir(parents=True, exist_ok=True)
cpp_files: List[str] = []
cuda_files: List[str] = []
for src in sol.sources:
src_path = path / src.path
if src_path.is_dir():
raise BuildError(f"Source path is a directory: {src_path}")

src_path.write_text(src.content)

if str(src_path).endswith(tuple(CPP_EXTENSIONS)):
cpp_files.append(str(src_path))
elif str(src_path).endswith(tuple(CUDA_EXTENSIONS)):
cuda_files.append(str(src_path))

if len(cpp_files) == 0 and len(cuda_files) == 0:
raise BuildError("No sources found")
return cpp_files, cuda_files

def _get_language(self, cpp_files: List[str], cuda_files: List[str]) -> str:
return "cuda" if len(cuda_files) > 0 else "cpp"

def _get_entry_symbol(self, sol: Solution) -> str:
"""Extract function symbol from entry_point."""
entry_point = sol.spec.entry_point
if "::" not in entry_point:
raise BuildError(
f"Invalid entry_point format: {entry_point}. Expected 'file.cu::symbol'"
)
return entry_point.split("::")[-1]

def _make_runnable(
self, mod: tvm_ffi.Module, entry_symbol: str, defn: Definition, metadata: Dict[str, Any]
) -> Runnable:
"""Create Runnable from TVM-FFI module."""
try:
fn = getattr(mod, entry_symbol)
except AttributeError as e:
raise BuildError(f"Entry point '{entry_symbol}' not found in module") from e

# Create keyword adapter to match definition interface
arg_order = list(defn.inputs.keys()) + list(defn.outputs.keys())

def _kw_adapter(**kwargs):
args = [kwargs[name] for name in arg_order]
return fn(*args)

return TVMFFIRunnable(
fn=_kw_adapter, closer=self._make_closer(), meta=metadata, definition=defn
)

def _build(self, defn: Definition, sol: Solution) -> Runnable:
"""Build with automatic caching - compile once, load from cache afterwards."""
key = self._make_key(sol)
build_path = self._get_build_path(key)
entry_symbol = self._get_entry_symbol(sol)
cpp_files, cuda_files = self._write_sources(build_path, sol)
language = self._get_language(cpp_files, cuda_files)
extra_include_paths = [str(build_path)]

try:
# Use build_inline instead of build to
output_lib_path = tvm_ffi.cpp.build(
name=key,
cpp_files=cpp_files,
cuda_files=cuda_files,
extra_include_paths=extra_include_paths,
build_directory=build_path,
)
Comment on lines +115 to +123
⚠️ Potential issue | 🟡 Minor

Fix incomplete comment.

Line 116 contains an incomplete comment: "Use build_inline instead of build to". Either complete this comment or remove it to avoid confusion.

         try:
-            # Use build_inline instead of build to
+            # Build the TVM-FFI module
             output_lib_path = tvm_ffi.cpp.build(

except Exception as e:
raise BuildError(f"TVM-FFI compilation failed for '{sol.name}': {e}") from e

# Load the compiled module
try:
mod = tvm_ffi.load_module(output_lib_path)
except Exception as e:
raise BuildError(f"Failed to load compiled module: {e}") from e

metadata = {
"definition": defn.name,
"solution": sol.name,
"language": language,
"binding": "tvm_ffi",
"key": key,
"symbol": entry_symbol,
"binary": output_lib_path,
}

return self._make_runnable(mod, entry_symbol, defn, metadata)
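As a usage note on _make_runnable above: the keyword adapter reorders named arguments into the positional layout the compiled symbol expects (definition inputs first, then outputs). A self-contained sketch with hypothetical names:

def make_kw_adapter(fn, arg_order):
    # Reorder keyword arguments into the positional order the symbol expects.
    def _kw_adapter(**kwargs):
        return fn(*[kwargs[name] for name in arg_order])
    return _kw_adapter

# Suppose a Definition declares inputs "x", "y" and an output "out":
adapter = make_kw_adapter(print, ["x", "y", "out"])  # print stands in for the loaded symbol
adapter(x=1, y=2, out=3)  # invokes the underlying fn as fn(1, 2, 3)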
8 changes: 5 additions & 3 deletions flashinfer_bench/compile/registry.py
@@ -52,11 +52,13 @@ def build_reference(self, defn: Definition) -> Runnable:
def get_builder_registry() -> BuilderRegistry:
global _registry
if _registry is None:
from .builders import CUDABuilder, PythonBuilder, TritonBuilder
from .builders import CUDABuilder, PythonBuilder, TritonBuilder, TVMFFIBuilder

py = PythonBuilder()
triton = TritonBuilder(py_builder=py)
cuda = CUDABuilder()
tvm_ffi = TVMFFIBuilder()
cuda = CUDABuilder() # Fallback for backward compatibility

_registry = BuilderRegistry((py, triton, cuda))
# Priority: Python > Triton > TVM-FFI > CUDA (pybind11)
_registry = BuilderRegistry((py, triton, tvm_ffi, cuda))
return _registry
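The priority comment reflects registration order; a plausible sketch of the resulting dispatch, assuming BuilderRegistry simply asks each builder can_build() in order (this loop is an assumption, not code from this PR):

from flashinfer_bench.compile.builder import BuildError

def pick_builder(builders, sol):
    # First builder that accepts the solution wins, so registration order is priority.
    # TVMFFIBuilder and CUDABuilder both accept CUDA solutions, so listing tvm_ffi
    # first makes TVM-FFI the default and keeps the pybind11 CUDABuilder as fallback.
    for builder in builders:
        if builder.can_build(sol):
            return builder
    raise BuildError(f"No builder available for solution '{sol.name}'")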
54 changes: 49 additions & 5 deletions flashinfer_bench/compile/runnable.py
@@ -2,21 +2,21 @@

from typing import Any, Callable, Dict, Optional

from flashinfer_bench.data import Definition
from flashinfer_bench.utils import dtype_str_to_torch_dtype


class Runnable:
def __init__(
self,
fn: Callable[..., Any],
closer: Callable[[], None],
meta: Optional[Dict[str, Any]] = None,
self, fn: Callable[..., Any], closer: Optional[Callable[[], None]], meta: Dict[str, Any]
) -> None:
"""A runnable callable with a required resource closer.

closer: must be provided by the builder and be idempotent.
"""
self._fn = fn
self._closer: Optional[Callable[[], None]] = closer
self.meta: Dict[str, Any] = dict(meta or {})
self.meta: Dict[str, Any] = meta

def __call__(self, **kwargs: Any) -> Any:
"""
@@ -36,3 +36,47 @@ def close(self) -> None:
self._closer()
finally:
self._closer = None


class TVMFFIRunnable(Runnable):
def __init__(
self,
fn: Callable[..., Any],
closer: Optional[Callable[[], None]],
meta: Dict[str, Any],
definition: Definition,
) -> None:
super().__init__(fn, closer, meta)
self._definition = definition

def __call__(self, **kwargs: Any) -> Any:
import torch

# Allocate output tensors first

var_values = self._definition.get_var_values(
{name: list(tensor.shape) for name, tensor in kwargs.items()}
)
output_shapes = self._definition.get_output_shapes(var_values)
output_tensors: Dict[str, torch.Tensor] = {}
device = next(iter(kwargs.values())).device if len(kwargs) > 0 else "cpu"
for name, shape in output_shapes.items():
output_tensors[name] = torch.empty(
shape, dtype=dtype_str_to_torch_dtype(self._definition.outputs[name].dtype)
).to(device)

self.call_dest(**kwargs, **output_tensors)

results = list(output_tensors.values())
if len(results) == 1:
return results[0]
return results

def call_dest(self, **kwargs: Any) -> None:
"""Call the underlying function with destination passing style."""
self._fn(**kwargs)

def close(self) -> None:
if self._closer:
self._closer()
self._closer = None
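Callers therefore pass only input tensors: __call__ derives the output shapes from the Definition, allocates them on the inputs' device, invokes the kernel in destination-passing style, and returns the outputs. A hedged usage sketch (the runnable and the tensor names "x" and "y" are illustrative):

import torch

# `runnable` is assumed to be a TVMFFIRunnable built for a definition with one
# input "x" and one output "y" (hypothetical names).
x = torch.randn(1024, device="cuda")

y = runnable(x=x)  # output allocated internally and returned

# Destination-passing style is also available when the caller owns the buffer:
y_buf = torch.empty(1024, device="cuda")
runnable.call_dest(x=x, y=y_buf)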