[tvm-ffi] TVMFFIBuilder and PrebuiltLibraryManager #111
base: main
Conversation
Walkthrough

The PR introduces TVM-FFI-based CUDA kernel building with automatic caching and multi-process safety. It adds a PrebuiltLibraryManager for managing prebuilt libraries, a TVMFFIBuilder class integrating with TVM-FFI, CUDA dependency discovery, comprehensive test coverage, and an example demonstrating cross-framework kernel execution.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant TVMFFIBuilder
    participant PrebuiltLibraryManager
    participant tvm_ffi
    participant Cache as Disk Cache
    User->>TVMFFIBuilder: build(solution, definition)
    TVMFFIBuilder->>TVMFFIBuilder: can_build() check
    TVMFFIBuilder->>TVMFFIBuilder: _extract_cuda_sources()
    TVMFFIBuilder->>TVMFFIBuilder: _make_key() generate cache key
    TVMFFIBuilder->>PrebuiltLibraryManager: find(cache_key)
    alt Library cached
        PrebuiltLibraryManager->>Cache: lookup
        Cache-->>PrebuiltLibraryManager: return path
        PrebuiltLibraryManager-->>TVMFFIBuilder: return lib_path
    else Library not cached
        PrebuiltLibraryManager-->>TVMFFIBuilder: None
        TVMFFIBuilder->>TVMFFIBuilder: _collect_dependencies()
        TVMFFIBuilder->>tvm_ffi: cpp.build_inline(sources, ...)
        tvm_ffi-->>TVMFFIBuilder: compiled module
        TVMFFIBuilder->>Cache: save module
    end
    TVMFFIBuilder->>TVMFFIBuilder: _make_runnable()
    TVMFFIBuilder-->>User: return Runnable
    User->>User: execute kernel via Runnable
```
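In code, the flow above amounts to a cache-first lookup with compilation as the fallback. The following is a minimal sketch of that control flow; method names mirror the diagram, but the signatures and the `_compile` helper are illustrative stand-ins, not the PR's actual implementation.

```python
# Minimal sketch of the diagrammed flow; `_compile` and the exact signatures
# are illustrative, not the PR's actual implementation.
def build(self, definition, solution):
    if not self.can_build(solution):
        raise BuildError("TVM-FFI builder cannot handle this solution")
    sources = self._extract_cuda_sources(solution)      # gather .cu sources
    cache_key = self._make_key(definition, solution)    # content-derived key
    lib_path = self._prebuilt_manager.find(cache_key)   # disk-cache lookup
    if lib_path is None:
        include_paths, ldflags = self._collect_dependencies(sources)
        lib_path = self._compile(                       # compile and persist
            sources, include_paths, ldflags,
            out_dir=self._prebuilt_manager.get_cache_dir(),
        )
    return self._make_runnable(lib_path, definition)    # kwargs -> positional adapter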
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25-35 minutes
Areas requiring extra attention:
Pre-merge checks and finishing touches

✅ Passed checks (2 passed)
Summary of Changes

Hello @Ubospica, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates TVM-FFI into the project.

Highlights
Code Review
This PR introduces TVM-FFI integration through TVMFFIBuilder and PrebuiltLibraryManager, which is a great addition for kernel compilation and caching. The implementation is solid, with good test coverage and a helpful example. I've found a critical bug in argument handling for the compiled kernels and a high-severity issue in the prebuilt library path logic that could lead to incorrect behavior. I've also included a few medium-severity suggestions to improve an example's efficiency, enhance error handling, and add a test case for the path logic bug. Overall, great work on this feature.
```python
raise BuildError(f"Symbol '{symbol}' not found in module") from e

# Create keyword adapter to match definition interface
arg_order = list(defn.inputs.keys())
```
The `arg_order` is constructed using only `defn.inputs`. This will lead to incorrect arguments being passed to the native function, as it omits outputs and other scalar arguments (axes). Based on the example provided, the arguments should include inputs, outputs, and axes values.
```diff
-arg_order = list(defn.inputs.keys())
+arg_order = list(defn.inputs.keys()) + list(defn.outputs.keys()) + list(defn.axes.keys())
```
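With the extended ordering in place, the example's call would resolve every keyword (a hypothetical mapping for illustration, assuming inputs `a`/`b`, output `c`, and axis `n` as in the PR's example):

```python
# Hypothetical call after the fix: kwargs are mapped positionally in
# inputs + outputs + axes order, so all four arguments reach the kernel.
runnable(a=a_torch, b=b_torch, c=c_torch, n=n)  # -> vector_add(a, b, c, n)
```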
```python
def get_cache_dir(self) -> str:
    """Get the cache directory for compiling new libraries.

    Returns the last path in search_paths, which is always the local cache.
    """
    cache_dir = self._search_paths[-1]  # Always the local cache
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir
```
This method incorrectly assumes the cache directory is always the last path in `self._search_paths`. This assumption fails if `extra_paths` are provided and the cache directory is also specified with higher priority (e.g., via `FIB_PREBUILT_PATH`). This can lead to compiled libraries being written to the wrong directory.

To fix this, the cache directory should be explicitly managed instead of being inferred from the search path order. A more robust approach (sketched after this list) would be:

- Store the cache directory path in an instance variable in `__init__`.
- `_build_search_paths` should ensure this cache path is included in the search paths.
- `get_cache_dir` should then simply return the stored instance variable and ensure it exists.
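A minimal sketch of that approach, assuming the manager's existing helpers (`get_fib_cache_path`, `_build_search_paths`) and attribute names; this is not the PR's code:

```python
import os

from flashinfer_bench.env import get_fib_cache_path


class PrebuiltLibraryManager:
    def __init__(self, extra_paths=None):
        # Own the cache path explicitly instead of inferring it from order.
        self._cache_dir = os.path.join(get_fib_cache_path(), "tvm_ffi")
        # Assumed: _build_search_paths includes self._cache_dir in the list.
        self._search_paths = self._build_search_paths(extra_paths)

    def get_cache_dir(self) -> str:
        # Return the stored path regardless of its position in search_paths.
        os.makedirs(self._cache_dir, exist_ok=True)
        return self._cache_dir
```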
```python
a_jax = jnp.array(a_torch.cpu().numpy())
b_jax = jnp.array(b_torch.cpu().numpy())
```
The current method of creating JAX tensors from PyTorch tensors is inefficient as it involves a gpu -> cpu -> gpu roundtrip via .cpu().numpy(). For this example, it's clearer and more efficient to generate random JAX tensors directly on the device, similar to the CuPy example. This makes the test case for each framework independent and avoids the performance overhead of data transfer.
Note: You'll need to add import jax at the beginning of the try block for this suggestion to work.
```diff
-a_jax = jnp.array(a_torch.cpu().numpy())
-b_jax = jnp.array(b_torch.cpu().numpy())
+key = jax.random.PRNGKey(0)
+a_jax, b_jax = jax.random.normal(key, (2, n), dtype=jnp.float32)
```
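If the goal is instead to keep the PyTorch-seeded data while staying on-device, a DLPack hand-off is another option. A sketch, assuming CUDA-enabled builds of both frameworks and the capsule-based `jax.dlpack.from_dlpack` API:

```python
# Zero-copy hand-off via DLPack (sketch; assumes CUDA builds of both frameworks).
import jax.dlpack
import torch.utils.dlpack

a_jax = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(a_torch))
b_jax = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(b_torch))
```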
```python
except Exception:
    pass
```
Catching a broad `Exception` and silently passing with `pass` can hide important errors. If an error occurs here, dependency resolution might fail silently, leading to compilation errors later that are hard to debug. It would be better to catch more specific exceptions (e.g., `ModuleNotFoundError` if a package for `resources.files` doesn't exist) or at least log a warning that a dependency path could not be resolved.
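A sketch of the narrower handling, assuming `resources` is `importlib.resources` and a module-level `logger` exists as elsewhere in the builder:

```python
from importlib import resources

# Sketch: catch the specific failure and leave a trace instead of passing silently.
try:
    pkg_path = resources.files(pkg_name)  # raises ModuleNotFoundError for a missing package
except ModuleNotFoundError as e:
    logger.warning(f"Could not resolve dependency path for {pkg_name}: {e}")
```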
```python
def test_get_cache_dir(self):
    """Test getting cache directory."""
    manager = PrebuiltLibraryManager()
    cache_dir = manager.get_cache_dir()

    # Should exist and be writable
    assert os.path.exists(cache_dir)
    assert os.path.isdir(cache_dir)
    assert os.access(cache_dir, os.W_OK)
```
This test for get_cache_dir is good, but it doesn't cover a potential bug where an incorrect cache directory could be returned. This can happen if the cache path is also specified as a higher-priority path (e.g., via FIB_PREBUILT_PATH) and extra_paths are also provided.
Please consider adding a test case to cover this scenario, which would fail with the current implementation and pass after fixing the issue in PrebuiltLibraryManager. Here is a suggested test:
```python
def test_get_cache_dir_priority(self, monkeypatch):
    """Test get_cache_dir returns correct path even with higher priority overrides."""
    from flashinfer_bench.env import get_fib_cache_path
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        extra_dir = os.path.join(tmpdir, "extra")
        os.makedirs(extra_dir)
        cache_dir = os.path.join(get_fib_cache_path(), "tvm_ffi")
        monkeypatch.setenv("FIB_PREBUILT_PATH", cache_dir)
        manager = PrebuiltLibraryManager(extra_paths=[extra_dir])
        assert manager.get_cache_dir() == cache_dir
```
💡 Codex Review
Here are some automated review suggestions for this pull request.
```python
# Create keyword adapter to match definition interface
arg_order = list(defn.inputs.keys())

def _kw_adapter(**kwargs):
    args = [kwargs[name] for name in arg_order]
    return fn(*args)
```
Drop output and scalar arguments when invoking TVM FFI kernels
The adapter returned from _make_runnable only forwards values for defn.inputs (arg_order = list(defn.inputs.keys())). Any keyword such as outputs or axis/scalar parameters are ignored, so calls like runnable(a=a_torch, b=b_torch, c=c_torch, n=n) are translated to vector_add(a, b) and raise a TypeError or write to uninitialised memory. Kernel definitions always contain outputs and usually dimension arguments, so the builder cannot execute compiled libraries. The adapter should pass through all required kwargs (inputs, outputs, scalars) or match the entry-point signature explicitly.
Actionable comments posted: 3
🧹 Nitpick comments (13)
flashinfer_bench/compile/registry.py (1)
55-64: Builder ordering LGTM; add a tiny trace to see which backend was chosen at runtime.

Priority Python > Triton > TVM-FFI > CUDA is sensible. Consider logging the chosen builder once in BuilderRegistry.build for observability during benchmarks.
flashinfer_bench/compile/builders/__init__.py (1)
4-6: Sort `__all__` to satisfy Ruff RUF022. Apply isort-style ordering.

```diff
-from .tvm_ffi_builder import TVMFFIBuilder
-
-__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder", "TVMFFIBuilder"]
+from .tvm_ffi_builder import TVMFFIBuilder
+
+__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder", "TVMFFIBuilder"]
```

Note: if you want strict lint compliance, change to alphabetical order:

```diff
-__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder", "TVMFFIBuilder"]
+__all__ = ["CUDABuilder", "PythonBuilder", "TVMFFIBuilder", "TritonBuilder"]
```

As per static analysis hints.
flashinfer_bench/compile/__init__.py (1)
7-7: New exports are fine; keep `__all__` sorted to appease Ruff. Sort entries alphabetically to pass RUF022.

```diff
-__all__ = [
-    "Builder",
-    "BuildError",
-    "BuilderRegistry",
-    "Runnable",
-    "get_builder_registry",
-    "PrebuiltLibraryManager",
-    "get_prebuilt_manager",
-]
+__all__ = [
+    "Builder",
+    "BuilderRegistry",
+    "BuildError",
+    "PrebuiltLibraryManager",
+    "Runnable",
+    "get_builder_registry",
+    "get_prebuilt_manager",
+]
```

examples/tvm_ffi_example.py (2)
84-101: JAX section likely runs on CPU arrays; ensure GPU device or skip.

For a CUDA kernel, place arrays on a GPU device before calling, or skip if no GPU is present to avoid runtime errors.

```diff
-    import jax.numpy as jnp
+    import jax, jax.numpy as jnp
@@
-    a_jax = jnp.array(a_torch.cpu().numpy())
-    b_jax = jnp.array(b_torch.cpu().numpy())
-    c_jax = jnp.empty((n,), dtype=jnp.float32)
+    gpus = jax.devices("gpu")
+    if not gpus:
+        raise ImportError("No JAX GPU device")
+    dev = gpus[0]
+    a_jax = jax.device_put(jnp.array(a_torch.cpu().numpy()), dev)
+    b_jax = jax.device_put(jnp.array(b_torch.cpu().numpy()), dev)
+    c_jax = jax.device_put(jnp.empty((n,), dtype=jnp.float32), dev)
```
9-13: Remove unused import. `flashinfer_bench as fib` is unused.

```diff
-import flashinfer_bench as fib
 from flashinfer_bench.compile import get_builder_registry
```

flashinfer_bench/compile/prebuilt.py (3)
15-21: Docstring search order outdated.

The implementation also supports extra_paths and appends the local cache as step 4. Update the docstring to reflect the actual order.
103-110: Edge case: ensure get_cache_dir always returns the actual cache path even if it appeared earlier.

If an identical cache path is already in search_paths (e.g., via env/extra_paths), the "last element" assumption may break. Compute and return the canonical cache path directly.

```diff
-        cache_dir = self._search_paths[-1]  # Always the local cache
-        os.makedirs(cache_dir, exist_ok=True)
-        return cache_dir
+        cache_dir = os.path.join(get_fib_cache_path(), "tvm_ffi")
+        os.makedirs(cache_dir, exist_ok=True)
+        return cache_dir
```
88-101: Cross-platform: only searches for .so.

Support .dylib (macOS) and .dll (Windows) to make prebuilt discovery portable.

```diff
-        # Try both with and without .so extension
-        if not lib_name.endswith(".so"):
-            filename = f"{lib_name}.so"
-        else:
-            filename = lib_name
-
-        for search_path in self._search_paths:
-            lib_path = os.path.join(search_path, filename)
-            if os.path.exists(lib_path):
-                logger.debug(f"Found prebuilt library: {lib_path}")
-                return lib_path
+        # Try common extensions across platforms
+        candidates = [lib_name] if any(lib_name.endswith(ext) for ext in (".so", ".dylib", ".dll")) \
+            else [f"{lib_name}{ext}" for ext in (".so", ".dylib", ".dll")]
+        for search_path in self._search_paths:
+            for fname in candidates:
+                lib_path = os.path.join(search_path, fname)
+                if os.path.exists(lib_path):
+                    logger.debug(f"Found prebuilt library: {lib_path}")
+                    return lib_path
```

tests/compile/test_prebuilt.py (1)
79-93: Make tests portable across OS library suffixes.

Hardcoding .so breaks on Windows/macOS. Parameterize the extension or probe via multiple suffixes.

```diff
-        lib_path = os.path.join(tmpdir, "test_lib.so")
+        import sys
+        ext = ".dll" if sys.platform.startswith("win") else (".dylib" if sys.platform == "darwin" else ".so")
+        lib_path = os.path.join(tmpdir, f"test_lib{ext}")
@@
-        found2 = manager.find("test_lib.so")
+        found2 = manager.find(f"test_lib{ext}")
```

tests/compile/test_tvm_ffi_builder.py (2)
178-179: Escape regex metacharacters in match=.

Use a raw string or escape dots so the message matches literally.

```diff
-with pytest.raises(BuildError, match="No .cu CUDA sources"):
+with pytest.raises(BuildError, match=r"No \.cu CUDA sources"):
```

As per static analysis hint (RUF043).
114-122: Avoid relying on private method in test (builder._get_lib_path).

Accessing a private method couples tests to internals. Prefer asserting on the returned path from build (e.g., via runnable.meta["cache_dir"] and expected filename) or expose a small public helper.

```diff
-    lib_path = builder._get_lib_path(simple_solution)
+    cache_dir = builder._prebuilt_manager.get_cache_dir()
+    # Derive expected filename from solution name if needed, or assert any *.so exists under cache_dir.
+    lib_path = next((p for p in (os.path.join(cache_dir, f) for f in os.listdir(cache_dir)) if p.endswith(".so")), None)
```

flashinfer_bench/compile/builders/tvm_ffi_builder.py (2)
49-64: Use unpacking for list concatenation.

On line 64, prefer unpacking over concatenation for better performance and readability. Apply this diff:

```diff
-    elif sys.platform == "win32":
-        ldflags = [f"/LIBPATH:{lib_path}"] + lib_names
+    elif sys.platform == "win32":
+        ldflags = [f"/LIBPATH:{lib_path}", *lib_names]
```
66-69: Consider logging exceptions during package path discovery.

The bare `except Exception: pass` silently swallows all errors during package path discovery, making it difficult to diagnose issues when dependencies are misconfigured or packages are malformed. Apply this diff:

```diff
         except Exception:
-            pass
+            logger.debug(f"Failed to discover package paths for {pkg_name}")
         return include_path, ldflags
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- examples/tvm_ffi_example.py (1 hunks)
- flashinfer_bench/compile/__init__.py (1 hunks)
- flashinfer_bench/compile/builders/__init__.py (1 hunks)
- flashinfer_bench/compile/builders/tvm_ffi_builder.py (1 hunks)
- flashinfer_bench/compile/prebuilt.py (1 hunks)
- flashinfer_bench/compile/registry.py (1 hunks)
- pyproject.toml (1 hunks)
- tests/compile/test_prebuilt.py (1 hunks)
- tests/compile/test_tvm_ffi_builder.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (8)
flashinfer_bench/compile/registry.py (3)

- flashinfer_bench/compile/builders/cuda_builder.py (1): CUDABuilder (122-238)
- flashinfer_bench/compile/builders/python_builder.py (1): PythonBuilder (20-104)
- flashinfer_bench/compile/builders/triton_builder.py (1): TritonBuilder (18-51)

tests/compile/test_prebuilt.py (1)

- flashinfer_bench/compile/prebuilt.py (5): PrebuiltLibraryManager (14-110), get_prebuilt_manager (113-131), search_paths (71-73), find (75-101), get_cache_dir (103-110)

flashinfer_bench/compile/builders/__init__.py (1)

- flashinfer_bench/compile/builders/tvm_ffi_builder.py (1): TVMFFIBuilder (123-287)

flashinfer_bench/compile/__init__.py (1)

- flashinfer_bench/compile/prebuilt.py (2): PrebuiltLibraryManager (14-110), get_prebuilt_manager (113-131)

examples/tvm_ffi_example.py (2)

- flashinfer_bench/compile/registry.py (2): get_builder_registry (52-64), build (26-31)
- flashinfer_bench/data/solution.py (3): BuildSpec (60-91), SourceFile (27-57), SupportedLanguages (12-24)

flashinfer_bench/compile/builders/tvm_ffi_builder.py (5)

- flashinfer_bench/compile/builder.py (3): Builder (52-95), BuildError (48-49), create_pkg_name (33-45)
- flashinfer_bench/compile/prebuilt.py (3): PrebuiltLibraryManager (14-110), get_cache_dir (103-110), find (75-101)
- flashinfer_bench/compile/runnable.py (1): Runnable (6-38)
- flashinfer_bench/data/solution.py (2): SourceFile (27-57), SupportedLanguages (12-24)
- flashinfer_bench/logging.py (1): get_logger (9-12)

flashinfer_bench/compile/prebuilt.py (2)

- flashinfer_bench/env.py (1): get_fib_cache_path (46-57)
- flashinfer_bench/logging.py (1): get_logger (9-12)

tests/compile/test_tvm_ffi_builder.py (4)

- flashinfer_bench/compile/builders/tvm_ffi_builder.py (4): TVMFFIBuilder (123-287), _verify_tvm_ffi (23-31), can_build (155-156), _get_lib_path (186-189)
- flashinfer_bench/data/solution.py (3): BuildSpec (60-91), SourceFile (27-57), SupportedLanguages (12-24)
- flashinfer_bench/compile/registry.py (1): build (26-31)
- flashinfer_bench/compile/builder.py (1): BuildError (48-49)
🪛 GitHub Actions: .github/workflows/linting.yaml
flashinfer_bench/compile/builders/tvm_ffi_builder.py
[error] 27-27: F401 'tvm_ffi.cpp' imported but unused. Remove unused import to satisfy linter.
🪛 Ruff (0.14.3)
flashinfer_bench/compile/builders/__init__.py
6-6: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer_bench/compile/__init__.py
11-19: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
flashinfer_bench/compile/builders/tvm_ffi_builder.py
29-29: Consider moving this statement to an else block
(TRY300)
64-64: Consider [f"/LIBPATH:{lib_path}", *lib_names] instead of concatenation
Replace with [f"/LIBPATH:{lib_path}", *lib_names]
(RUF005)
66-67: try-except-pass detected, consider logging the exception
(S110)
66-66: Do not catch blind exception: Exception
(BLE001)
169-169: Avoid specifying long messages outside the exception class
(TRY003)
173-173: Avoid specifying long messages outside the exception class
(TRY003)
181-183: Avoid specifying long messages outside the exception class
(TRY003)
200-203: Avoid specifying long messages outside the exception class
(TRY003)
215-215: Avoid specifying long messages outside the exception class
(TRY003)
230-230: Do not catch blind exception: Exception
(BLE001)
254-254: Avoid specifying long messages outside the exception class
(TRY003)
260-260: Avoid specifying long messages outside the exception class
(TRY003)
269-269: Avoid specifying long messages outside the exception class
(TRY003)
tests/compile/test_tvm_ffi_builder.py
178-178: Pattern passed to match= contains metacharacters but is neither escaped nor raw
(RUF043)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.12
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.13
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.10
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.11
🔇 Additional comments (9)
pyproject.toml (1)
26-32: Wheel availability confirmed for specified targets; the review comment is correct.

The verification shows wheels exist for your CI targets:
- Linux/x86_64: cp310, cp311, cp312-abi3 (covers 3.13 via stable ABI)
- Python 3.10–3.13: All versions covered via direct wheels (cp310, cp311) and abi3 wheels (cp312+)
The package name and import path are correct per PyPI documentation. No action needed.
flashinfer_bench/compile/builders/tvm_ffi_builder.py (8)
1-21: LGTM! The imports and constants are well-structured and appropriate for the TVM-FFI builder implementation.

72-86: LGTM! The dependency configuration and detection patterns are well-structured and comprehensive.

89-96: LGTM! The dependency discovery logic correctly populates include paths and linker flags.

99-120: LGTM! The dependency checking logic is thorough, with appropriate comment stripping to avoid false positives and an early optimization check.

123-156: LGTM! The class initialization and availability checks are well-designed, with appropriate caching of the TVM-FFI availability check and proper integration with the PrebuiltLibraryManager.

158-189: LGTM! The helper methods are well-structured with clear error handling. The error messages provide good context for debugging.

191-210: LGTM! The dependency collection logic correctly validates that required dependencies are available and provides clear error messages when they're missing.

212-287: LGTM! The build and runnable creation logic is well-designed with proper fallback handling. The bare exception catch on line 230 is appropriate here: if loading a prebuilt library fails for any reason, the builder correctly falls back to recompilation. The keyword-to-positional adapter ensures compatibility with the Definition interface.
```python
# 3. Build with TVM-FFI (compiles on first run, cached afterwards)
print("Building kernel with TVM-FFI...")
builder_registry = get_builder_registry()
runnable = builder_registry.build(definition, solution)
print(f"✓ Built successfully: {runnable.meta}")

# 4. Use in PyTorch (DLPack auto-conversion)
print("\n=== PyTorch Test ===")
n = 1000000
a_torch = torch.randn(n, device="cuda", dtype=torch.float32)
b_torch = torch.randn(n, device="cuda", dtype=torch.float32)
c_torch = torch.empty(n, device="cuda", dtype=torch.float32)

runnable(a=a_torch, b=b_torch, c=c_torch, n=n)

expected = a_torch + b_torch
torch.testing.assert_close(c_torch, expected, rtol=1e-5, atol=1e-5)
print("✓ PyTorch: Result correct")
```
Critical: runnable drops c and n; call will pass only a,b.
The adapter in TVMFFIBuilder (and CUDA builder) builds args from defn.inputs only, so kwargs c and n are ignored. Your kernel signature requires (a, b, c, n) and will fail. Fix the adapter to include outputs and axes in a defined order (inputs + outputs + axes), or change the example to match a two‑arg typed function.
```diff
-runnable(a=a_torch, b=b_torch, c=c_torch, n=n)
+# After fixing adapter to pass inputs+outputs+axes:
+runnable(a=a_torch, b=b_torch, c=c_torch, n=n)
```

Follow-up suggested change in builder (flashinfer_bench/compile/builders/tvm_ffi_builder.py):

```diff
-arg_order = list(defn.inputs.keys())
-def _kw_adapter(**kwargs):
-    args = [kwargs[name] for name in arg_order]
-    return fn(*args)
+# Pass inputs + outputs + named axes (if provided) to match C/CUDA signature.
+arg_order = [*defn.inputs.keys(), *defn.outputs.keys(), *sorted(getattr(defn, "axes", {}).keys())]
+def _kw_adapter(**kwargs):
+    args = [kwargs[name] for name in arg_order if name in kwargs]
+    return fn(*args)
```

Also consider documenting the expected argument order for C/CUDA entry points.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents

```
In examples/tvm_ffi_example.py around lines 65 to 83, the runnable call
currently passes a, b, c, n as kwargs but the TVMFFI/CUDA adapter builds
argument lists from definition.inputs only, so c and n are dropped and the
kernel signature (a, b, c, n) will fail; fix the adapter in
flashinfer_bench/compile/builders/tvm_ffi_builder.py (and the CUDA builder) to
build positional args in the canonical order inputs + outputs + axes (or inputs
+ outputs + scalar/axis params) and ensure the runnable wrapper accepts kwargs
and maps them into that ordered positional list before calling the C entry
point, and update the example or add docs to clearly state the required argument
ordering for C/CUDA entry points.
```
```python
def _verify_tvm_ffi() -> bool:
    """Check if tvm_ffi is available."""
    try:
        import tvm_ffi
        import tvm_ffi.cpp

        return True
    except ImportError:
        return False
```
Fix unused import flagged by linter.
The import tvm_ffi.cpp on line 27 is flagged as unused by the linter. Since this import is only for verification purposes and tvm_ffi.cpp is imported again in the _build method where it's actually used, you can remove it here.
Apply this diff:
```diff
 def _verify_tvm_ffi() -> bool:
     """Check if tvm_ffi is available."""
     try:
         import tvm_ffi
-        import tvm_ffi.cpp

         return True
     except ImportError:
         return False
```

📝 Committable suggestion
return False📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def _verify_tvm_ffi() -> bool:
    """Check if tvm_ffi is available."""
    try:
        import tvm_ffi

        return True
    except ImportError:
        return False
```
🧰 Tools
🪛 GitHub Actions: .github/workflows/linting.yaml
[error] 27-27: F401 'tvm_ffi.cpp' imported but unused. Remove unused import to satisfy linter.
🪛 Ruff (0.14.3)
29-29: Consider moving this statement to an else block
(TRY300)
🤖 Prompt for AI Agents

```
In flashinfer_bench/compile/builders/tvm_ffi_builder.py around lines 23 to 31,
the helper _verify_tvm_ffi currently imports tvm_ffi.cpp but that import is
unused and flagged by the linter; modify the function to only import tvm_ffi
(not tvm_ffi.cpp) inside the try block and return True on success, leaving the
ImportError handling unchanged so the actual tvm_ffi.cpp import remains in the
_build method where it is used.
```
```python
from flashinfer_bench.compile.builders.tvm_ffi_builder import TVMFFIBuilder, _verify_tvm_ffi
from flashinfer_bench.data import BuildSpec, Definition, Solution, SourceFile, SupportedLanguages
```
🧩 Analysis chain
Optional: verify build API name.
Docs showcase tvm_ffi.cpp.load_inline; your builder uses tvm_ffi.cpp.build_inline (which may be newer). Please verify the API name/signature to prevent runtime errors.
Cites for current docs indicating load_inline and installation usage. (tvm.apache.org)
🌐 Web query:

```
Is `tvm_ffi.cpp.build_inline(...)` the correct and public API (vs `tvm_ffi.cpp.load_inline(...)`), and what is its current Python signature?
```
💡 Result:
Short answer: use tvm_ffi.cpp.load_inline (that is the documented convenience API). tvm_ffi.cpp.build_inline was an internal/older helper in the codebase and the project is moving toward build / build_and_load APIs; the recommended/public inline loader is tvm_ffi.cpp.load_inline. [1][2]
Approximate current Python signature for the inline loader (same shape as typical torch.utils.cpp_extension.load_inline):
```python
tvm_ffi.cpp.load_inline(
    name,
    cpp_sources,
    cuda_sources=None,
    functions=None,
    extra_cflags=None,
    extra_cuda_cflags=None,
    extra_ldflags=None,
    extra_include_paths=None,
    build_directory=None,
    verbose=False,
    with_cuda=None,
    is_python_module=True,
    with_pytorch_error_handling=True,
    keep_intermediates=True,
)
```

[1][3]
Sources:
- TVM-FFI quick start / inline loading docs showing tvm_ffi.cpp.load_inline as the convenience API. [1]
- TVM‑FFI repository / recent commits and notes showing evolution toward build / build_and_load and replacing older internal build_inline usage. [2]
If you want, I can fetch the exact def signature from the current repo tree and paste it verbatim.
---
🏁 Script executed:

```shell
#!/bin/bash
# Locate and examine the tvm_ffi_builder.py file
fd -e py "tvm_ffi_builder" --exec cat {} \;
```

Length of output: 16707
🏁 Script executed:

```shell
#!/bin/bash
# Search for usages of build_inline and load_inline in the codebase
rg -n "build_inline|load_inline" --type=py -C3
```

Length of output: 2080
Replace tvm_ffi.cpp.build_inline() with the documented public API tvm_ffi.cpp.load_inline().
The recommended and public inline loader is tvm_ffi.cpp.load_inline, while build_inline is an older/internal helper. Line 242 uses the deprecated API with incompatible parameter names (output_dir should align with the public API's build_directory and other parameters). Refactor to use the documented load_inline API to prevent runtime errors and ensure forward compatibility.
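For illustration, a minimal call shaped after the approximate signature quoted above; `cuda_source` and `cache_dir` are placeholder variables, and parameter names may differ across tvm_ffi versions:

```python
import tvm_ffi.cpp

# Sketch only: follows the approximate load_inline signature quoted above.
mod = tvm_ffi.cpp.load_inline(
    "vector_add_mod",
    cpp_sources="",
    cuda_sources=cuda_source,    # placeholder: the .cu text extracted from the solution
    functions=["vector_add"],
    build_directory=cache_dir,   # placeholder: replaces build_inline's output_dir
    verbose=False,
)
```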
🤖 Prompt for AI Agents

```
In tests/compile/test_tvm_ffi_builder.py around lines 7 to 9 (and update at the
call site around line 242), replace the deprecated tvm_ffi.cpp.build_inline()
usage with the documented public API tvm_ffi.cpp.load_inline(); change the
keyword argument output_dir to build_directory and align any other parameter
names to the load_inline signature, adjust how the return value is consumed if
load_inline returns a different object/shape than build_inline, and run the test
to ensure the new API call succeeds without runtime keyword/name mismatches.
```
This PR adds TVMFFIBuilder and PrebuiltLibraryManager, initializing the integration of TVM-FFI.
Summary by CodeRabbit
Release Notes
New Features
Documentation
Tests