[FFI] Use tvm ffi as the default execution backend #1259
Conversation
* Refactored FFI type definitions in runtime and code generation files to use `TVMFFIAny` instead of `TVMValue`, enhancing type clarity.
* Updated function registration in `runtime.cc` to utilize canonical names for better consistency.
* Simplified argument handling in the `simplify` transformation, ensuring unused buffer parameters are removed only when simplification is enabled.
* Adjusted autotuner and profiler parameters to standardize the execution backend to `tvm_ffi`, improving clarity in backend selection.
* Removed obsolete `adapt_torch2tvm` function from tensor utilities to streamline the codebase and reduce complexity.
* Updated the TVM submodule to commit cdc2aced, ensuring compatibility with recent changes.
* Added functionality to print kernel source in `example_blocksparse_gemm.py` for better debugging.
* Commented out the main execution call in test files to prevent unintended execution during testing.
* Introduced `tilelang.disable_cache()` in various test files to streamline testing and avoid cache-related issues.
* Refactored kernel source retrieval methods to improve clarity and consistency across different execution backends.
* Removed unused import of `tilelang.testing` in `test_example_blocksparse_gemm.py` to streamline the code.
* Reformatted several lines in `arg_binder.cc`, `make_packed_api.cc`, `tvm_ffi.py`, and `adapter.py` for improved readability and consistency.
* Updated comments and spacing in `tvm_ffi.py` to enhance clarity without altering functionality.
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Caution: Review failed. The pull request is closed.

Walkthrough
Pivot runtime and JIT to a TVM‑FFI–first model: add backend resolution, TVM‑FFI adapter and FFI‑packed APIs, NULL‑safe DLTensor binding, host/device source separation, a TileLang C host codegen, CUDA stream access‑policy FFI hooks, and extensive cache, adapter, and test updates. (50 words)

Changes
Sequence Diagram(s)

sequenceDiagram
autonumber
participant User
participant JIT as tilelang.jit
participant Resolver as execution_backend
participant Adapter as TVMFFIKernelAdapter
participant TVM as tvm.runtime
User->>JIT: compile(func, target="cuda", execution_backend="auto")
JIT->>Resolver: resolve_execution_backend("auto", target)
Resolver-->>JIT: "tvm_ffi"
JIT->>Adapter: instantiate TVMFFIKernelAdapter(...)
Adapter->>Adapter: _process_dynamic_symbolic()
Adapter-->>JIT: return kernel object
User->>Adapter: call kernel(torch_tensors)
Adapter->>Adapter: get_current_device()/get_current_stream()
Adapter->>TVM: invoke executable(device_tensors)
TVM-->>Adapter: outputs
Adapter-->>User: torch tensors
sequenceDiagram
autonumber
participant MakePackedAPI
participant ArgBinder
participant Runtime as runtime.cc
MakePackedAPI->>ArgBinder: BindDLTensor(handle,...)
Note right of ArgBinder: insert is_null guard\nuse conditional loads to avoid NULL deref
ArgBinder-->>MakePackedAPI: bound, guarded tensor values
MakePackedAPI->>Runtime: def_packed(tl::tvm_tensormap_create_..., ...)
Runtime-->>MakePackedAPI: registration complete
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

Areas needing focused review:
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tilelang/jit/__init__.py (1)
46-65: `execution_backend` types and docs are out of sync with actual behavior

Across `compile`, `par_compile`, and `jit`:

- The `Literal[...]` annotations still mention `"dlpack"` and omit `"tvm_ffi"` (the new default) and `"torch"` (required in the metal path), e.g.:

      execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "tvm_ffi"

- Docstrings still say "Defaults to `\"cython\"`", while the code now defaults to `"tvm_ffi"`.

This will confuse users and static type checkers, and hides that `"torch"` is required for metal targets (see the `is_metal_target` assertion). Consider:

- Updating the Literals everywhere `execution_backend` appears to the actual supported set (e.g. including `"tvm_ffi"` and `"torch"` and dropping `"dlpack"` if it's really gone), and
- Fixing the docstrings to match the new default `"tvm_ffi"`.

Also applies to: 98-117, 169-173, 284-305, 298-327
tilelang/autotuner/tuner.py (1)
574-580: Minor issues in cache-saving condition and autotune docstring
The condition `if self.compile_args.execution_backend in ("torch"):` works but is unconventional: `("torch")` is just a string, so this is a substring check. For clarity (and to avoid surprises) it should be `== "torch"` or `in ("torch",)`; a short illustration follows.
"DLPack backend does not support cache saving to disk."but the condition now checks for
"torch". That’s misleading; update the text to mention the torch backend instead (or whatever backend is actually being special-cased).In the
autotunedocstring, the documented parameterexecution_backend : Literal["tvm_ffi", "ctypes", "cython"], optionaldoes not exist in the function signature. This should either be removed or the API extended to actually accept
execution_backend.Also applies to: 711-713
tilelang/jit/adapter/nvrtc/adapter.py (1)
93-104: Fix NVRTC get_kernel_source signature to match JITKernel usage

`JITKernel.get_kernel_source()` now calls `self.adapter.get_kernel_source(kernel_only=kernel_only)` for `"nvrtc"`, but `NVRTCKernelAdapter.get_kernel_source` takes no `kernel_only` parameter. This will raise a `TypeError` whenever `get_kernel_source()` is used with the NVRTC backend (e.g., PTX/SASS generation).

You can make NVRTC conform to the common adapter API by adding a `kernel_only` flag (even if it's ignored for now), keeping current behavior:

-    def get_kernel_source(self) -> str | None:
-        """Get the CUDA kernel source code.
+    def get_kernel_source(self, kernel_only: bool = False) -> str | None:
+        """Get the CUDA kernel source code.
+
+        Parameters
+        ----------
+        kernel_only : bool, optional
+            Ignored for NVRTC; only device source is available.

         Returns
         -------
         Optional[str]
             The kernel source code, or None if not available
         """
-        return self.device_kernel_source
+        return self.device_kernel_source

Also, in `from_database`, `host_kernel_source` and `pass_configs` are currently only stored and not used. That's fine for correctness, but if they're intentionally unused you might either wire them into reconstruction logic (for symmetry with other backends) or mark them with a leading underscore to silence linters.

Also applies to: 171-179
tilelang/jit/kernel.py (1)
602-603: Fix kernel_source property for from_database JITKernel instances

The `kernel_source` property still falls back to `self.adapter.kernel_global_source` when `self.artifact` is falsy:

    def kernel_source(self) -> str:
        return self.artifact.kernel_source if self.artifact else self.adapter.kernel_global_source

All adapters in this PR have moved away from `kernel_global_source` in favor of `host_kernel_source`/`device_kernel_source` and `get_kernel_source(...)`. For kernels created via `JITKernel.from_database`, `self.artifact` is `None`, so this property will raise `AttributeError` on `self.adapter.kernel_global_source`.

A minimal fix that matches the new adapter API would be:

     @property
     def kernel_source(self) -> str:
-        return self.artifact.kernel_source if self.artifact else self.adapter.kernel_global_source
+        if self.artifact:
+            return self.artifact.kernel_source
+        # Fall back to adapter, keeping the historical "kernel only" behavior.
+        return self.adapter.get_kernel_source(kernel_only=True)

This keeps existing behavior for freshly compiled kernels and makes `from_database` instances work with the updated adapters.
♻️ Duplicate comments (2)
examples/gdn/test_example_gdn_compilation.py (1)
200-201: Test harness inconsistency across the PR.

Same issue as in `test_example_blocksparse_gemm.py` — this test now bypasses the centralized test runner. Please see the earlier comment about ensuring consistency in test harness changes across the PR.

testing/python/dynamic/test_tilelang_dynamic_symbolic.py (1)

517-517: Test harness consistency note.

This test now uses `tilelang.testing.main()`, which is the opposite pattern from `test_example_blocksparse_gemm.py` and `test_example_gdn_compilation.py`. See the earlier comment about test harness consistency.
🧹 Nitpick comments (5)
examples/blocksparse_gemm/example_blocksparse_gemm.py (1)
169-169: Consider making kernel source printing optional.

The unconditional print of the entire kernel source may produce verbose console output. For better user experience, consider making this conditional on a `--verbose` or `--debug` flag.

Apply this diff to add a verbose flag:

 parser.add_argument(
     "--use_autotune", action="store_true", default=False, help="Whether to use autotune")
+parser.add_argument(
+    "--verbose", action="store_true", default=False, help="Print kernel source code")

Then make the print conditional:

     block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
     print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
-    print(kernel.get_kernel_source())
+    if args.verbose:
+        print(kernel.get_kernel_source())
16-17: Global cache disabling may leak outside this test module
`tilelang.disable_cache()` is a global switch, so once this runs, caching stays disabled for the rest of the process. If that's intentional (to avoid any cached kernels affecting debug-print behavior across the entire test run), this is fine; otherwise, consider scoping it (e.g., re‑enabling after the test or using a context helper, sketched below) so other tests keep exercising the caching path.
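One possible shape for such a scoped helper, assuming a re-enable counterpart exists (`tilelang.enable_cache()` is hypothetical here):

```python
# Sketch of a scoped cache-disable helper; enable_cache is an assumed/hypothetical
# counterpart to tilelang.disable_cache() and may not exist under this name.
from contextlib import contextmanager

import tilelang


@contextmanager
def cache_disabled():
    tilelang.disable_cache()
    try:
        yield
    finally:
        tilelang.enable_cache()  # hypothetical re-enable hook
```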
7-8: Stream/device helper thunks are reasonable; consider narrowing broad excepts or documenting them

The new `get_current_stream_functor` and `get_current_device_functor` helpers are a useful abstraction: they respect PyTorch's current device/stream at call time and degrade to CPU/0 when CUDA isn't available.

Both helpers swallow all `Exception`s and silently fall back:

    except Exception:
        ...

Given these are just helpers, that's defensible, but static analysis (BLE001) will complain. If you want to keep the behavior but quiet the linter, consider either:

- Catching the specific expected exceptions (e.g. `RuntimeError`, `AttributeError`), as in the sketch below, or
- Adding an explicit inline `# noqa: BLE001` with a short comment explaining why a broad catch is intentional.

Also applies to: 50-71, 72-88
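A rough sketch of the narrowed-exception variant; only the helper name comes from the PR, the internal structure here is an assumption:

```python
# Sketch only: narrow the broad `except Exception` to the errors torch actually
# raises when CUDA is unavailable or uninitialized.
import torch


def get_current_stream_functor():
    def current_stream() -> int:
        try:
            if torch.cuda.is_available():
                return torch.cuda.current_stream().cuda_stream
        except (RuntimeError, AttributeError):
            # CUDA present but not usable; fall through to the default stream.
            pass
        return 0  # default stream / CPU fallback

    return current_stream
```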
tilelang/jit/kernel.py (1)
54-76: Align execution_backend docs and verify non–tvm_ffi codegen gating
- The constructor now defaults `execution_backend` to `"tvm_ffi"` (Line 58), but the docstring still claims the default is `"cython"` (Line 76). That's misleading for users and should be updated.
- In `_compile_and_create_adapter`, both `enable_host_codegen` and `enable_device_compile` are true only when `execution_backend == "tvm_ffi"` (Lines 227‑230). The ctypes/cython/nvrtc branches still pass `artifact.host_mod` and `artifact.device_mod` into their adapters, so please confirm `tilelang.lower(...)` still produces suitable `host_mod`/`device_mod` for those backends under this gating.

If the non‑tvm_ffi backends do require host/device IR modules, you may want something like:

-        enable_host_codegen = execution_backend == "tvm_ffi"
-        enable_device_compile = execution_backend == "tvm_ffi"
+        enable_host_codegen = execution_backend in {"tvm_ffi", "ctypes", "cython", "nvrtc"}
+        enable_device_compile = execution_backend in {"tvm_ffi", "ctypes", "cython", "nvrtc"}

(or whichever subset actually needs them).

Also applies to: 227-236, 241-297
tilelang/cache/kernel_cache.py (1)
41-45: Tighten backend selection and exception logging in KernelCache

Functionally the new tvm_ffi caching path looks fine, but two small robustness issues stand out:

1. Backend source selection in `_load_kernel_from_disk`

   `_load_kernel_from_disk` takes an `execution_backend` argument (Line 351) but chooses `kernel_lib_path` based on `self.execution_backend` (Lines 376‑381). Since `self.execution_backend` is mutable state set in `_generate_key`, this is fragile if a single `KernelCache` instance is used with different backends concurrently. Consider basing the path selection on the method argument instead:

   -        if self.execution_backend == "nvrtc":
   +        if execution_backend == "nvrtc":
                kernel_lib_path = KERNEL_CUBIN_PATH
   -        elif self.execution_backend == "tvm_ffi":
   +        elif execution_backend == "tvm_ffi":
                kernel_lib_path = EXECUTABLE_PATH
            else:
                kernel_lib_path = KERNEL_LIB_PATH

2. Bare `Exception` catches with `logger.error`

   The various try/except blocks around disk I/O (`_save_kernel_to_disk` and `_load_kernel_from_disk`) catch bare `Exception` and log via `self.logger.error(...)` (e.g., Lines 278‑280, 294‑296, 397‑399, 405‑407). For debugging cache corruption or filesystem issues, it would be more useful to:

   - Narrow the exception type where reasonable (e.g., `OSError`/`IOError`), and/or
   - Use `logger.exception(...)` to preserve the traceback.

   For example:

   -        except Exception as e:
   -            self.logger.error(f"Error saving kernel source code to disk: {e}")
   +        except OSError:
   +            self.logger.exception("Error saving kernel source code to disk")

   Similar adjustments can be applied to the other I/O blocks.

These are non‑blocking but will make cache behavior easier to reason about and debug.

Also applies to: 69-79, 240-245, 270-299, 345-383, 392-407
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (31)
- 3rdparty/tvm (1 hunks)
- examples/blocksparse_gemm/example_blocksparse_gemm.py (1 hunks)
- examples/blocksparse_gemm/test_example_blocksparse_gemm.py (1 hunks)
- examples/gdn/example_chunk_o_bwd.py (0 hunks)
- examples/gdn/test_example_gdn_compilation.py (3 hunks)
- src/runtime/runtime.cc (2 hunks)
- src/target/codegen_cpp.cc (2 hunks)
- src/transform/arg_binder.cc (5 hunks)
- src/transform/lower_hopper_intrin.cc (1 hunks)
- src/transform/make_packed_api.cc (13 hunks)
- src/transform/simplify.cc (1 hunks)
- testing/python/debug/test_tilelang_debug_print.py (1 hunks)
- testing/python/dynamic/test_tilelang_dynamic_symbolic.py (1 hunks)
- testing/python/jit/test_tilelang_jit_gemm_ctypes.py (0 hunks)
- testing/python/language/test_tilelang_language_alloc.py (1 hunks)
- tilelang/autotuner/param.py (2 hunks)
- tilelang/autotuner/tuner.py (3 hunks)
- tilelang/cache/__init__.py (1 hunks)
- tilelang/cache/kernel_cache.py (11 hunks)
- tilelang/contrib/dlpack.py (0 hunks)
- tilelang/jit/__init__.py (4 hunks)
- tilelang/jit/adapter/__init__.py (1 hunks)
- tilelang/jit/adapter/base.py (2 hunks)
- tilelang/jit/adapter/ctypes/adapter.py (6 hunks)
- tilelang/jit/adapter/cython/adapter.py (6 hunks)
- tilelang/jit/adapter/dlpack.py (0 hunks)
- tilelang/jit/adapter/nvrtc/adapter.py (5 hunks)
- tilelang/jit/adapter/tvm_ffi.py (1 hunks)
- tilelang/jit/kernel.py (16 hunks)
- tilelang/profiler/__init__.py (1 hunks)
- tilelang/utils/tensor.py (0 hunks)
💤 Files with no reviewable changes (5)
- tilelang/utils/tensor.py
- testing/python/jit/test_tilelang_jit_gemm_ctypes.py
- examples/gdn/example_chunk_o_bwd.py
- tilelang/contrib/dlpack.py
- tilelang/jit/adapter/dlpack.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.080Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.080Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
- examples/blocksparse_gemm/test_example_blocksparse_gemm.py
- testing/python/debug/test_tilelang_debug_print.py
🧬 Code graph analysis (13)
tilelang/jit/adapter/base.py (5)
- tilelang/jit/adapter/tvm_ffi.py (2): func (195-262), get_kernel_source (311-316)
- tilelang/jit/adapter/ctypes/adapter.py (1): get_kernel_source (294-300)
- tilelang/jit/adapter/cython/adapter.py (1): get_kernel_source (384-391)
- tilelang/jit/adapter/nvrtc/adapter.py (1): get_kernel_source (171-179)
- tilelang/jit/kernel.py (1): get_kernel_source (421-432)

tilelang/jit/adapter/__init__.py (1)
- tilelang/jit/adapter/tvm_ffi.py (1): TVMFFIKernelAdapter (23-321)

examples/gdn/test_example_gdn_compilation.py (2)
- tilelang/env.py (1): disable_cache (275-276)
- examples/gdn/example_chunk_o_bwd.py (2): kernel (157-395), tilelang_chunk_o_bwd_dqkwg (115-397)

testing/python/language/test_tilelang_language_alloc.py (6)
- tilelang/jit/adapter/ctypes/adapter.py (1): get_kernel_source (294-300)
- tilelang/jit/adapter/tvm_ffi.py (1): get_kernel_source (311-316)
- tilelang/jit/adapter/base.py (1): get_kernel_source (93-97)
- tilelang/jit/adapter/cython/adapter.py (1): get_kernel_source (384-391)
- tilelang/jit/adapter/nvrtc/adapter.py (1): get_kernel_source (171-179)
- tilelang/jit/kernel.py (1): get_kernel_source (421-432)

testing/python/debug/test_tilelang_debug_print.py (2)
- tilelang/env.py (1): disable_cache (275-276)
- tilelang/jit/__init__.py (2): compile (46-95), compile (222-248)

tilelang/jit/adapter/nvrtc/adapter.py (1)
- tilelang/jit/adapter/wrapper.py (8): host_func (703-713), wrap (202-203), wrap (1203-1222), wrap (1230-1243), update_lib_code (622-671), update_lib_code (897-917), update_lib_code (1116-1135), update_lib_code (1168-1170)

examples/blocksparse_gemm/example_blocksparse_gemm.py (7)
- tilelang/jit/adapter/ctypes/adapter.py (1): get_kernel_source (294-300)
- tilelang/jit/adapter/tvm_ffi.py (1): get_kernel_source (311-316)
- tilelang/jit/adapter/base.py (1): get_kernel_source (93-97)
- tilelang/jit/adapter/cython/adapter.py (1): get_kernel_source (384-391)
- tilelang/jit/adapter/nvrtc/adapter.py (1): get_kernel_source (171-179)
- tilelang/jit/kernel.py (1): get_kernel_source (421-432)
- tilelang/jit/param.py (1): get_kernel_source (27-28)

tilelang/jit/adapter/cython/adapter.py (5)
- tilelang/jit/adapter/wrapper.py (7): wrap (202-203), wrap (1203-1222), wrap (1230-1243), update_lib_code (622-671), update_lib_code (897-917), update_lib_code (1116-1135), update_lib_code (1168-1170)
- tilelang/jit/adapter/ctypes/adapter.py (1): get_kernel_source (294-300)
- tilelang/jit/adapter/tvm_ffi.py (1): get_kernel_source (311-316)
- tilelang/jit/adapter/nvrtc/adapter.py (1): get_kernel_source (171-179)
- tilelang/jit/kernel.py (2): get_kernel_source (421-432), params (598-599)

tilelang/jit/adapter/ctypes/adapter.py (2)
- tilelang/jit/kernel.py (1): params (598-599)
- tilelang/jit/adapter/base.py (1): _legalize_result_idx (20-44)

tilelang/jit/kernel.py (5)
- tilelang/jit/adapter/nvrtc/adapter.py (3): NVRTCKernelAdapter (25-259), from_database (94-147), get_kernel_source (171-179)
- tilelang/jit/adapter/tvm_ffi.py (4): TVMFFIKernelAdapter (23-321), from_database (267-297), get_kernel_source (311-316), get_host_source (299-303)
- tilelang/jit/adapter/base.py (2): BaseKernelAdapter (10-100), get_kernel_source (93-97)
- tilelang/jit/adapter/ctypes/adapter.py (3): from_database (113-162), CtypesKernelAdapter (19-300), get_kernel_source (294-300)
- tilelang/jit/adapter/cython/adapter.py (2): from_database (149-208), get_kernel_source (384-391)

src/transform/make_packed_api.cc (2)
- src/target/codegen_cpp.cc (8): VisitExpr_ (166-179), VisitExpr_ (166-167), VisitExpr_ (375-414), VisitExpr_ (375-376), VisitExpr_ (452-455), VisitExpr_ (452-453), VisitExpr_ (457-460), VisitExpr_ (457-458)
- tilelang/language/tir/op.py (1): tvm_struct_get (414-436)

tilelang/jit/adapter/tvm_ffi.py (3)
- tilelang/utils/target.py (1): determine_target (62-123)
- tilelang/jit/adapter/base.py (6): BaseKernelAdapter (10-100), _legalize_result_idx (20-44), _post_init (99-100), _convert_torch_func (47-48), get_current_device_functor (73-88), get_kernel_source (93-97)
- tilelang/engine/param.py (1): KernelParam (12-103)

tilelang/cache/kernel_cache.py (2)
- tilelang/jit/kernel.py (5): export_library (609-626), kernel_source (602-603), get_host_source (434-439), get_kernel_source (421-432), from_database (143-183)
- tilelang/jit/adapter/tvm_ffi.py (4): func (195-262), get_host_source (299-303), get_kernel_source (311-316), from_database (267-297)
🪛 Ruff (0.14.4)
tilelang/jit/adapter/base.py
66-66: Do not catch blind exception: Exception
(BLE001)
85-85: Do not catch blind exception: Exception
(BLE001)
tilelang/jit/adapter/__init__.py
2-2: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/jit/adapter/nvrtc/adapter.py
103-103: Unused class method argument: pass_configs
(ARG003)
tilelang/jit/adapter/tvm_ffi.py
176-176: Do not catch blind exception: Exception
(BLE001)
199-201: Avoid specifying long messages outside the exception class
(TRY003)
221-221: Unpacked variable ref_id is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
233-235: Avoid specifying long messages outside the exception class
(TRY003)
251-253: Avoid specifying long messages outside the exception class
(TRY003)
277-277: Unused class method argument: compile_flags
(ARG003)
tilelang/cache/kernel_cache.py
278-278: Do not catch blind exception: Exception
(BLE001)
279-279: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
294-294: Do not catch blind exception: Exception
(BLE001)
295-295: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
397-397: Do not catch blind exception: Exception
(BLE001)
398-398: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
405-405: Do not catch blind exception: Exception
(BLE001)
406-406: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (20)
src/transform/arg_binder.cc (4)
120-122: LGTM: Clear error messaging for dtype mismatches.

The error message clearly indicates both expected and provided types, making debugging easier.

192-196: LGTM: Improved error messages avoid runtime expression dumping.

The error messages are now more user-friendly by avoiding TIR expression dumps and providing clear, actionable feedback about NULL pointers, ndim mismatches, and dtype incompatibility.

Also applies to: 220-223, 227-231

342-344: LGTM: More descriptive stride error message.

Explicitly mentioning "non-compact strides" makes the error message clearer for debugging.

99-119: Based on my comprehensive verification, I need to mark this review comment as incorrect.

The FP8 variant compatibility logic in the lambda is intentionally asymmetric and correct. The concern about bidirectional compatibility misunderstands the purpose of the `dtype_compatible` lambda. E4M3FN and E4M3FNUZ use different exponent biases and differ in the special values they can represent — these are not interchangeable formats.

The current logic deliberately allows binding when the expected type is the base variant (e.g., `float8_e4m3`), but rejects binding when expected is a variant (e.g., `float8_e4m3fn` expected, `float8_e4m3` provided). This asymmetry is intentional protection: conversion between variants like e4m3fn and e4m3fnuz requires explicit normalization, which cannot be silently handled during buffer binding. The codebase treats each FP8 variant as a distinct type code with different conversion requirements. Bidirectional compatibility would mask necessary type conversions and introduce silent correctness issues.

Likely an incorrect or invalid review comment.
3rdparty/tvm (1)

1-1: Verify TVM submodule commit is compatible with FFI migration changes.

The TVM commit (cdc2aced0d87cc6e5e24811bd964efd2ce2d0729) is valid and exists in the apache/tvm repository. However, the commit message indicates it focuses on CUDA function attribute refactoring and error handling in C host code generation. Confirm this commit integrates properly with the TVMFFIKernelAdapter and FFI backend adoption across the codebase, particularly verifying there are no breaking changes to the TVM FFI wrapper layer.

testing/python/language/test_tilelang_language_alloc.py (1)

154-154: LGTM: Kernel-only source retrieval is appropriate here.

The change to `kernel_only=True` correctly retrieves only the device kernel source, which is sufficient for validating initializer constants. This aligns with the PR's separation of host and device kernel sources.

tilelang/profiler/__init__.py (1)

278-278: LGTM: Direct tensor passing aligns with TVM FFI backend.

The removal of the `adapt_torch2tvm` conversion correctly reflects that the new TVM FFI backend can handle PyTorch tensors directly, simplifying the profiling path.

examples/gdn/test_example_gdn_compilation.py (2)

110-110: Cache disabling is appropriate for compilation tests.

Calling `tilelang.disable_cache()` ensures this test exercises the actual compilation path rather than retrieving cached results, which is correct for a compilation test.
122-127: Clarify the purpose of redundant kernel instantiations.

The same kernel is instantiated three times with identical parameters. Is this intentionally testing compilation idempotency or robustness, or is this leftover debug code?

Consider adding a comment explaining the intent, or removing the redundant instantiations if they're not needed:

-    kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
-                                        gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
-                                        block_DK, block_DV, threads, num_stages)
-    kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
-                                        gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
-                                        block_DK, block_DV, threads, num_stages)
-
+    # Test compilation idempotency by instantiating the kernel multiple times
+    for _ in range(3):
+        kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
+                                            gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
+                                            block_DK, block_DV, threads, num_stages)
tilelang/autotuner/param.py (2)

45-45: LGTM: Default backend updated to tvm_ffi.

The change from `"cython"` to `"tvm_ffi"` as the default execution backend aligns with the PR's objective of making TVM FFI the default execution backend.

211-211: LGTM: Consistent backend parameter update.

The signature change in `_load_kernel_from_disk` correctly mirrors the updated default backend throughout the codebase.

tilelang/jit/adapter/__init__.py (2)

2-2: LGTM: Adapter imports updated for TVM FFI backend.

The addition of `TVMFFIKernelAdapter` and removal of the legacy DLPack adapter correctly reflects the migration to the new TVM FFI execution backend. Note: the Ruff hint about the unused `noqa` directive is a false positive — `F401` suppressions in `__init__.py` files are intentional for re-exports.

6-6: LGTM: MetalKernelAdapter added.

The addition of `MetalKernelAdapter` expands backend support as part of the adapter refactoring.

tilelang/cache/__init__.py (1)

21-21: LGTM: Cache function updated to use tvm_ffi backend.

The default execution backend change from `"cython"` to `"tvm_ffi"` is consistent with the broader migration across the codebase and aligns with the PR's objectives.

examples/blocksparse_gemm/test_example_blocksparse_gemm.py (1)
9-10: All 6 files with commented `tilelang.testing.main()` follow an identical pattern.

Every single file (across both `testing/` and `examples/`) has replaced the centralized runner with a direct test function call in `__main__`. This uniformity suggests the pattern is intentional, not accidental, but raises a critical question: why are these 6 tests deliberately bypassing the centralized test harness while 100+ others use it?

The original review comment is valid. The inconsistency is real and systematic. However, the fact that all 6 deviate identically could indicate:

- Incomplete refactoring – this may be work-in-progress where some tests have been converted but others haven't
- Intentional exemption – these tests may require different handling for a documented reason
- Test discovery risk – if `tilelang.testing.main()` handles centralized registration/discovery for CI/CD, this bypass could cause tests to be skipped

The original review appropriately asks for clarification. The file under review is part of a systematic pattern, not an isolated anomaly.

Rewritten Review Comment:

Clarify the purpose of bypassing `tilelang.testing.main()` across these 6 tests.

This file is one of exactly 6 tests (out of 100+) that replace `tilelang.testing.main()` with direct function invocation. All 6 follow an identical pattern, suggesting intent rather than accident. Is this a deliberate refactoring in progress, or are these tests intentionally exempt from centralized test discovery? Without clarification, this risks CI/CD test discovery issues and maintenance confusion. If intentional, document the reason; if unintended, restore the centralized runner to match the project standard.

src/transform/simplify.cc (1)
243-276: Optional buffer-parameter pruning logic looks sound

The new `simplify_arguments` path correctly:

- Recomputes `used_buffers_` from the (already simplified) `func`.
- Retains buffer params when they're used in the body or referenced from other buffer definitions (`used_in_buffer_def_`).
- Keeps scalar params unconditionally.
- Only rebuilds the `PrimFunc` when a parameter was actually dropped.

This matches the intent of removing only truly unused buffer parameters without disturbing scalar arguments.

src/target/codegen_cpp.cc (1)

206-216: TVMFFIAny migration in codegen is consistent

Using `TVMFFIAny` for `ret_val`, stack allocation, and the casts passed into `TVMFuncCall` is consistent with the FFI aliases, and the `static_assert` on `alignof(TVMFFIAny)` vs `DLTensor` should keep the `tvm_stack_alloca("array", ...)` reinterpret-cast safe. I don't see correctness issues in these changes. Please make sure the updated codegen paths are covered by an FFI-based end‑to‑end test (e.g., a small packed function call exercising `tvm_stack_alloca` and return values) to catch any ABI/alignment regressions early.

Also applies to: 231-243, 382-398
src/runtime/runtime.cc (1)

91-109: Canonical FFI symbol registration is good; double‑check logging macro

Registering the TMA helpers with the canonical names (`tl::tvm_tensormap_create_tiled`/`tl::tvm_tensormap_create_im2col`) instead of string literals is a nice improvement for API consistency. The failure path now uses `LOG_FATAL << ...`. Please verify that `LOG_FATAL` is a valid logging macro in this project; most TVM-style code uses `LOG(FATAL) << ...`. If `LOG_FATAL` isn't defined, this will be a compile error and should be changed back to `LOG(FATAL)`.

Also applies to: 185-201

tilelang/autotuner/tuner.py (1)

140-167: Updated default execution_backend for autotuner compile args looks consistent

Switching `set_compile_args` to `execution_backend: Literal["tvm_ffi", "ctypes", "cython"] = "tvm_ffi"` matches the broader move to `tvm_ffi` as the default and aligns with the updated `CompileArgs` typing. No issues here.

tilelang/jit/adapter/cython/adapter.py (1)

52-54: Host/device source split in Cython adapter looks consistent

The new `host_kernel_source`/`device_kernel_source` fields, initialization via `wrapper.wrap(self.get_kernel_source(kernel_only=True))`, and `get_kernel_source(kernel_only=...)` behavior all line up with the updated `JITKernel` and `KernelCache` expectations. No issues from a correctness standpoint.

Also applies to: 81-97, 121-128, 148-165, 384-391
  PrimExpr cond = (v_type_code == expect_code && v_type_bits == expect_bits &&
                   v_type_lanes == expect_lanes);

  // Allow float8_e4m3 to match float8_e4m3fn/float8_e4m3fnuz at runtime.
  if (buffer->dtype.is_float8_e4m3()) {
    PrimExpr code_e4m3 = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3);
    PrimExpr code_e4m3fn = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fn);
    PrimExpr code_e4m3fnuz =
        IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fnuz);
    PrimExpr code_match =
        (v_type_code == code_e4m3 || v_type_code == code_e4m3fn ||
         v_type_code == code_e4m3fnuz);
    cond = cond || (code_match && v_type_bits == expect_bits &&
                    v_type_lanes == expect_lanes);
  }
  // Allow float8_e5m2 to match float8_e5m2fnuz at runtime.
  if (buffer->dtype.is_float8_e5m2()) {
    PrimExpr code_e5m2 = IntImm(DataType::UInt(8), DataType::kFloat8_e5m2);
    PrimExpr code_e5m2fnuz =
        IntImm(DataType::UInt(8), DataType::kFloat8_e5m2fnuz);
    PrimExpr code_match =
        (v_type_code == code_e5m2 || v_type_code == code_e5m2fnuz);
    cond = cond || (code_match && v_type_bits == expect_bits &&
                    v_type_lanes == expect_lanes);
  }
  // Allow bool to match int8/uint8 at runtime, and also kDLBool(code=6).
  if (buffer->dtype.is_bool()) {
    PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt);
    PrimExpr code_uint = IntImm(DataType::UInt(8), DataType::kUInt);
    PrimExpr code_kdlbool = IntImm(DataType::UInt(8), 6);
    PrimExpr bits8 = IntImm(DataType::UInt(8), 8);
    PrimExpr bits1 = IntImm(DataType::UInt(8), 1);
    PrimExpr lanes_ok = (v_type_lanes == expect_lanes);
    PrimExpr int8_ok =
        (v_type_code == code_int && v_type_bits == bits8 && lanes_ok);
    PrimExpr uint8_ok =
        (v_type_code == code_uint && v_type_bits == bits8 && lanes_ok);
    // Some frontends may tag bool tensors as kDLBool(code=6), commonly with
    // bits=8 or bits=1.
    PrimExpr kdlbool8_ok =
        (v_type_code == code_kdlbool && v_type_bits == bits8 && lanes_ok);
    PrimExpr kdlbool1_ok =
        (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
    // Also accept any dtype whose bitwidth=1, regardless of code, to be
    // defensive.
    PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
    cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
  }
Review overly permissive bool compatibility logic.
The bool compatibility logic at lines 286-288 accepts any dtype with 1-bit width regardless of type code:
    PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
    cond = cond || ... || bit1_ok;

This defensive check is overly broad and could mask type errors by accepting exotic 1-bit types (e.g., hypothetical float1 or custom types) that may not have correct bool semantics. Consider restricting this to only the explicitly listed type codes (int8, uint8, kDLBool with 8/1 bits) rather than accepting arbitrary 1-bit types.
Apply this diff to remove the overly permissive check:
PrimExpr kdlbool1_ok =
(v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
- // Also accept any dtype whose bitwidth=1, regardless of code, to be
- // defensive.
- PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
- cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
+ cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok;

🤖 Prompt for AI Agents
In src/transform/arg_binder.cc around lines 242 to 289, the current bool
compatibility logic defines and uses a permissive bit1_ok check that accepts any
1-bit dtype regardless of type code; remove that overly broad check by deleting
the PrimExpr bit1_ok definition and stop OR-ing bit1_ok into cond so only the
explicit allowed codes (int8, uint8, and kDLBool with bits=8 or bits=1) combined
with lanes_ok are accepted.
   Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(),
-                         {StringImm("arg_value"), 16});
+                         {StringImm("tvm_ffi_any"), 16});
🧩 Analysis chain
Verify consistency and allocation size across the codebase.
The type string change from "arg_value" to "tvm_ffi_any" aligns with the PR's FFI migration. However, please ensure that:
- All `tvm_stack_alloca` calls have been consistently updated across the codebase
- The size of `TVMFFIAny` is 8 bytes (same as `TVMValue`) to maintain the 128-byte allocation (16 elements × 8 bytes)
Run the following script to check for any remaining "arg_value" usages in stack allocations:
Additionally, please search for documentation on TVMFFIAny size to confirm the allocation correctness:
🏁 Script executed:
#!/bin/bash
# Search for any remaining tvm_stack_alloca calls that might still use "arg_value"
rg -nP 'tvm_stack_alloca\s*\(' -A2 -B1

Length of output: 1595
🌐 Web query:
What is the size of TVMFFIAny in TVM FFI?
💡 Result:
TVMFFIAny is 16 bytes (128 bits) in size. [1][2]
I need to verify the intended allocation size for the TensorMap descriptor before finalizing the assessment.
Also, let me search for more information about TVMFFIAny and TVMValue size comparison:
🏁 Script executed:
#!/bin/bash
# Search for TensorMap descriptor size references and documentation
rg -nP 'TensorMap|tvm_ffi_any' -C3 | head -100

Length of output: 6827
🏁 Script executed:
#!/bin/bash
# Check for any size constants or comments related to descriptor allocation
rg -nP '128|256' src/transform/lower_hopper_intrin.cc -C2

Length of output: 321
🌐 Web query:
TVMFFIAny vs TVMValue size bytes TVM
💡 Result:
- TVMFFIAny: 16 bytes (128-bit). [1]
- TVMValue: 8 bytes on typical 64-bit platforms (it's a union whose payload fits in 64 bits). [2]
Sources: TVM-FFI docs for TVMFFIAny and TVM C API TVMValue. [1] https://tvm.apache.org/ffi/reference/cpp/generated/structTVMFFIAny.html [2] https://tvm.apache.org/docs/reference/api/doxygen/unionTVMValue.html
🌐 Web query:
TensorMap descriptor size bytes TVM CUDA Hopper
💡 Result:
The TensorMap descriptor is 1024 bits = 128 bytes. PTX documents the tensor-map object size as 1024 bits (b1024). The CUDA docs also describe tensor-map objects and alignment requirements (CUDA APIs require the tensorMap address alignment; some docs state 64‑byte alignment). [1][2][3]
Sources:
- PTX / Hopper ISA (tensormap .b1024 = 1024 bits). [1]
- CUDA Driver API tensor-map documentation (descriptor and alignment notes). [2]
- NVIDIA patent / implementation notes (describes 128‑byte opaque tensor map). [3]
Critical allocation size error: Change 16 elements to 8 elements for TVMFFIAny migration.
The migration from TVMValue (8 bytes) to TVMFFIAny (16 bytes) failed to account for the size difference. TVMFFIAny is 16 bytes while TVMValue is 8 bytes, and the TensorMap descriptor requires 128 bytes (1024 bits).
The current code allocates:
- 16 elements × 16 bytes = 256 bytes ❌
It should allocate:
- 8 elements × 16 bytes = 128 bytes ✓
Fix at src/transform/lower_hopper_intrin.cc:32:
{StringImm("tvm_ffi_any"), 8} // Changed from 16 to 8
This double over-allocation will cause memory corruption and crashes during TensorMap descriptor initialization on Hopper GPUs.
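To make the arithmetic explicit, a small standalone check (sizes taken from the figures quoted above, not from project headers):

```python
# Illustration only: the allocation arithmetic behind the suggested fix.
TVM_FFI_ANY_BYTES = 16   # TVMFFIAny per TVM-FFI docs
TVM_VALUE_BYTES = 8      # legacy TVMValue on 64-bit platforms
TENSORMAP_BYTES = 128    # CUtensorMap descriptor is 1024 bits

assert 16 * TVM_VALUE_BYTES == TENSORMAP_BYTES        # old code: 16 x 8  = 128
assert 16 * TVM_FFI_ANY_BYTES == 2 * TENSORMAP_BYTES  # migrated: 16 x 16 = 256
assert 8 * TVM_FFI_ANY_BYTES == TENSORMAP_BYTES       # fixed:     8 x 16 = 128
```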
🤖 Prompt for AI Agents
In src/transform/lower_hopper_intrin.cc around lines 31 to 32, the allocation
call wrongly requests 16 elements of TVMFFIAny (16 bytes each) producing 256
bytes instead of the required 128 bytes for the TensorMap descriptor; change the
element count from 16 to 8 so the Call allocates {StringImm("tvm_ffi_any"), 8}
(8 × 16 = 128 bytes) to match the expected descriptor size and avoid
over-allocation and potential memory corruption.
    def get_kernel_source(self, kernel_only: bool = True) -> str:
        if kernel_only:
            return self.mod.imports[0].inspect_source()
        else:
            return self.mod.inspect_source() + "\n\n" + self.mod.imports[0].inspect_source()
🧩 Analysis chain
Verify get_kernel_source semantics and adapter interface consistency
Changing the base adapter to:
    def get_kernel_source(self, kernel_only: bool = True) -> str:
        if kernel_only:
            return self.mod.imports[0].inspect_source()
        else:
            return self.mod.inspect_source() + "\n\n" + self.mod.imports[0].inspect_source()

is consistent with the idea of "kernel-only vs host+device" sources, but two things are worth double‑checking:
1. Module structure assumptions: this assumes every adapter's `mod` always has at least one import and that `imports[0]` is the device module. If that's not guaranteed by your codegen, this can raise or return the wrong source. If there are backends where `mod.imports` may be empty or ordered differently, they should override `get_kernel_source` explicitly.

2. Adapter interface alignment: `JITKernel.get_kernel_source` (in `tilelang/jit/kernel.py`) calls `self.adapter.get_kernel_source(kernel_only=kernel_only)` for backends including `"ctypes"`, `"cython"`, `"nvrtc"`, and `"tvm_ffi"`. Ensure every adapter implements a compatible signature:
   - ctypes/cython/tvm_ffi already provide `get_kernel_source(kernel_only: bool = False)`.
   - The nvrtc adapter snippet you provided currently has `def get_kernel_source(self) -> str | None:` with no `kernel_only` parameter. If that's still the case in this PR, it will raise `TypeError: get_kernel_source() got an unexpected keyword argument 'kernel_only'` when `execution_backend="nvrtc"`.
Either update the nvrtc adapter (and any other stragglers) to accept kernel_only: bool = False, or special‑case "nvrtc" in JITKernel.get_kernel_source so it calls the old signature.
I'll verify the critical concerns identified in the review comment by examining the adapter implementations and their signatures.
Let me search for the adapter files in the repository structure:
Let me be more specific and search for the key files mentioned in the review:
Let me take a different approach and search for the get_kernel_source method across the codebase:
Let me also check for the JITKernel class and its get_kernel_source implementation:
Let me search more broadly for adapter-related files:
Now let me examine the specific files mentioned in the review:
🏁 Script executed:
#!/bin/bash
# Find adapter implementations and inspect get_kernel_source signatures
find . -type f -name "*.py" -path "*/adapter/*" | head -20

Length of output: 524
🏁 Script executed:
#!/bin/bash
# Locate adapter files and get_kernel_source method implementations
fd -t f "\.py$" "adapter" | head -30

Length of output: 94
🏁 Script executed:
#!/bin/bash
# Look for base.py, kernel.py and various adapter implementations
fd -t f "base\.py"
fd -t f "kernel\.py"
fd -t f "nvrtc.*\.py"
fd -t f "ctypes.*\.py"

Length of output: 281
🏁 Script executed:
#!/bin/bash
# Search for get_kernel_source method definitions
rg -n "def get_kernel_source" --type py

Length of output: 646
🏁 Script executed:
#!/bin/bash
# Search for JITKernel class and how it calls get_kernel_source
rg -n -A 10 "class JITKernel" --type py
rg -n -B 2 -A 5 "\.get_kernel_source\(" --type py

Length of output: 74247
🏁 Script executed:
#!/bin/bash
# Find all Python files in jit/adapter directory
fd -t f "\.py$" "jit/adapter"
# Also find the jit directory structure
find . -type d -name "jit" 2>/dev/null
find . -type d -name "adapter" 2>/dev/null

Length of output: 157
🏁 Script executed:
#!/bin/bash
# Examine the base adapter file
cat "tilelang/jit/adapter/base.py" 2>/dev/null || echo "File not found in standard path"
# Check directory structure
ls -la tilelang/jit/adapter/ 2>/dev/null || ls -la tilelang/ 2>/dev/null || find . -maxdepth 3 -type d -name jit

Length of output: 4699
Let me examine the nvrtc adapter and the JITKernel implementation in detail:
🏁 Script executed:
#!/bin/bash
# Get nvrtc adapter implementation
cat tilelang/jit/adapter/nvrtc/adapter.py | head -200

Length of output: 7979
🏁 Script executed:
#!/bin/bash
# Check torch adapter
cat tilelang/jit/adapter/torch/metal.py | grep -A 10 "get_kernel_source" || echo "No get_kernel_source in torch adapter"
# Check JITKernel.get_kernel_source implementation
sed -n '421,435p' tilelang/jit/kernel.py

Length of output: 546
🏁 Script executed:
#!/bin/bash
# Check ctypes adapter get_kernel_source
sed -n '294,310p' tilelang/jit/adapter/ctypes/adapter.py
# Check cython adapter get_kernel_source
sed -n '384,400p' tilelang/jit/adapter/cython/adapter.py
# Check tvm_ffi get_kernel_source
sed -n '311,327p' tilelang/jit/adapter/tvm_ffi.py

Length of output: 1147
Fix the nvrtc adapter signature to accept the kernel_only parameter
The review comment correctly identifies a critical interface mismatch. The nvrtc adapter has an incompatible signature that will cause TypeError at runtime.
Issues confirmed:
- nvrtc adapter signature mismatch: The method at `tilelang/jit/adapter/nvrtc/adapter.py:171` is `def get_kernel_source(self) -> str | None:` without the `kernel_only` parameter. However, `JITKernel.get_kernel_source` at `tilelang/jit/kernel.py:430-431` calls `self.adapter.get_kernel_source(kernel_only=kernel_only)` for the nvrtc backend. This will raise `TypeError: get_kernel_source() got an unexpected keyword argument 'kernel_only'` when using `execution_backend="nvrtc"`.
- Base adapter module structure assumption: The base implementation accesses `self.mod.imports[0]` without guards. While child adapters (ctypes, cython, tvm_ffi) all override with their own implementations using stored sources, the nvrtc adapter relies directly on `self.device_kernel_source` with no `self.mod` structure. This inconsistency suggests the module import structure isn't consistently guaranteed.
- Default parameter inconsistency: The base adapter defaults to `kernel_only=True`, but all working child adapters (ctypes, cython, tvm_ffi) override with `kernel_only=False`, indicating the base implementation path is never actually exercised.

Required fix: Update `tilelang/jit/adapter/nvrtc/adapter.py:171` to accept the `kernel_only` parameter consistent with other adapters (likely `kernel_only: bool = False`), and align its implementation logic with how it stores kernel sources; a sketch follows.
🤖 Prompt for AI Agents
In tilelang/jit/adapter/nvrtc/adapter.py around line 171, the nvrtc adapter's
get_kernel_source signature lacks the kernel_only parameter and will raise a
TypeError when called with kernel_only by JITKernel; change the signature to
accept kernel_only: bool = False and implement branching: if kernel_only return
only the device/kernel source (self.device_kernel_source or equivalent), else
return the full combined source (host + device or module + imports) consistent
with other adapters; also defensively handle missing attributes (fall back to
empty string) so the method never raises due to absent self.mod or sources.
    host_kernel_source: str | None = None
    device_kernel_source: str | None = None
    lib: ctypes.CDLL | None = None  # Compiled library handle
Populate host_kernel_source for ctypes backend to keep KernelCache working
In CtypesKernelAdapter.__init__, you now accept host_kernel_source / device_kernel_source, but for the normal compile path only device_kernel_source is provided from JITKernel. The code does:
    self.host_kernel_source = host_kernel_source
    self.device_kernel_source = device_kernel_source
    ...
    self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
    ...
    self.lib_generator.update_lib_code(self.wrapped_source)

`get_kernel_source(kernel_only=True)` returns `self.device_kernel_source`, so `wrapped_source` is correctly derived. However, `host_kernel_source` remains `None` in this path.
Later, KernelCache._save_kernel_to_disk persists the host source with:
    KernelCache._safe_write_file(
        host_kernel_path, "w",
        lambda file: file.write(kernel.adapter.get_kernel_source()))

Since `get_kernel_source()` defaults to `kernel_only=False`, for ctypes it returns `self.host_kernel_source`, i.e. `None` for compiled kernels. Writing `None` to the file will raise a `TypeError`.
You can fix this by populating host_kernel_source from the wrapped source in __init__:
- self.wrapper.assign_host_module(host_mod)
- self.wrapper.assign_device_module(device_mod)
- self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
+ self.wrapper.assign_host_module(host_mod)
+ self.wrapper.assign_device_module(device_mod)
+ self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
+ # Keep a copy of the wrapped host source for inspection/caching.
+ self.host_kernel_source = self.wrapped_sourceThis keeps existing behavior for the compiled path, while the from_database path still overrides host_kernel_source and wrapped_source as needed.
Also applies to: 52-70, 99-106, 112-130, 294-300
        for i, param in enumerate(params):
            if param in buffer_map:
                buffer = buffer_map[param]
                for j, shape in enumerate(buffer.shape):
                    if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
                            (shape not in params)):
                        dynamic_symbolic_map[shape] = (0, i, j)
        for i, param in enumerate(params):
            if param in buffer_map:
                buffer = buffer_map[param]
                for j, stride in enumerate(buffer.strides):
                    if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and
                            (stride not in params)):
                        dynamic_symbolic_map[stride] = (1, i, j)
        return dynamic_symbolic_map

    def _convert_torch_func(self) -> Callable[..., Any]:
        # Capture thunks that reflect Torch's current stream and device.
Dynamic shape symbols coming from PrimFunc params make outputs zero-rank
When a PrimFunc carries dynamic dims via explicit scalar params (e.g. n or m), those tir.Var symbols appear in both func.params and the buffer shapes. The guard (shape not in params) prevents _process_dynamic_symbolic from recording them, so _convert_torch_func never finds a match. The subsequent loop silently omits the dimension, leaving shape empty and tripping the “0-dimensional tensors are not supported” error for perfectly valid kernels. This regresses any workload that models dynamic extents with scalar parameters—which is how most of our TIR schedules are emitted.
Please keep those symbols in the map (or otherwise resolve them) and fail loudly if a tir.Var cannot be resolved. One fix is to drop the shape not in params filter and tighten the call site to raise when a symbol is still missing:
- for j, shape in enumerate(buffer.shape):
- if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
- (shape not in params)):
- dynamic_symbolic_map[shape] = (0, i, j)
+ for j, shape in enumerate(buffer.shape):
+ if isinstance(shape, tir.Var) and shape not in dynamic_symbolic_map:
+ dynamic_symbolic_map[shape] = (0, i, j)
…
- for s in param_shapes[i]:
- if isinstance(s, tir.Var):
- for key in dynamic_symbolic_map:
- if (str(s) == str(key)):
- ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[key]
- shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
- else:
- shape.append(s)
+ for s in param_shapes[i]:
+ if isinstance(s, tir.Var):
+ if s not in dynamic_symbolic_map:
+ raise ValueError(
+ f"{global_symbol}: cannot resolve dynamic dim '{s}' at output index {i}; "
+ "bind it to an input buffer or scalar."
+ )
+ ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[s]
+ ICHECK_EQ(ref_id, 0) << "stride symbols not yet supported here";
+ shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
+ else:
+ shape.append(s)Without this, dynamic kernels misbehave immediately.
Also applies to: 217-236
- Changed default execution backend from "cython" to "auto" in multiple locations to allow automatic selection based on the target.
- Expanded the list of supported execution backends to include "torch" and "nvrtc" across various classes and functions.
- Enhanced backend resolution logic in `KernelCache` and `AutoTuner` to ensure appropriate backend selection based on the target (see the sketch below).
- Updated documentation to reflect changes in execution backend options and their defaults.
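For orientation, a rough sketch of how an "auto" resolution step like the one in `tilelang/jit/execution_backend.py` could behave; the exact signatures, backend lists, and ordering are assumptions, not the real implementation:

```python
# Hedged sketch of "auto" backend resolution; the real code lives in
# tilelang/jit/execution_backend.py and may differ in names and ordering.
def allowed_backends_for_target(target: str) -> tuple[str, ...]:
    if "metal" in target:
        return ("torch",)  # the metal path requires the torch backend
    return ("tvm_ffi", "cython", "ctypes", "nvrtc")


def resolve_execution_backend(execution_backend: str, target: str) -> str:
    allowed = allowed_backends_for_target(target)
    if execution_backend == "auto":
        return allowed[0]  # e.g. "tvm_ffi" for CUDA-like targets
    if execution_backend not in allowed:
        raise ValueError(
            f"Backend {execution_backend!r} is not supported for target {target!r}")
    return execution_backend
```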
Actionable comments posted: 4
🧹 Nitpick comments (5)
tilelang/jit/adapter/nvrtc/adapter.py (1)
172-180: Consider clarifying `get_kernel_source` scope.

Now that kernel sources are split into host and device, `get_kernel_source()` returns only `device_kernel_source`. The method name no longer fully describes its behavior. Options:

- Rename to `get_device_kernel_source()` for clarity
- Return both sources as a tuple
- Add a docstring note that it returns device source only
tilelang/cache/kernel_cache.py (4)
291-298: Consider using `logging.exception` for better diagnostics.

The error handling correctly catches exceptions, but as static analysis suggests, using `self.logger.exception` instead of `error` would automatically include the traceback, making debugging easier. Apply this diff:

     except Exception as e:
-        self.logger.error(f"Error saving kernel source code to disk: {e}")
+        self.logger.exception("Error saving kernel source code to disk")

302-314: Consider using `logging.exception` for better diagnostics.

Same as the previous segment — using `self.logger.exception` would provide more diagnostic context. Apply this diff:

     except Exception as e:
-        self.logger.error(f"Error saving host kernel source code to disk: {e}")
+        self.logger.exception("Error saving host kernel source code to disk")

411-425: Consider using `logging.exception` for better diagnostics.

As with the save methods, using `self.logger.exception` would include tracebacks, making it easier to diagnose file I/O issues. Apply this diff:

     except Exception as e:
-        self.logger.error(f"Error loading kernel source code from disk: {e}")
+        self.logger.exception("Error loading kernel source code from disk")
     try:
         # ...
     except Exception as e:
-        self.logger.error(f"Error loading host kernel source code from disk: {e}")
+        self.logger.exception("Error loading host kernel source code from disk")
451-452: Optional: Address the TODO for better diagnostics.

The TODO comment suggests reporting why the kernel load failed. This would help users understand cache misses.
Would you like me to generate code that logs which of the three required components (host_kernel_source, device_kernel_source, kernel_params) failed to load?
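A hedged sketch of what such logging could look like; the three component names come from the question above, while the helper name and surrounding structure are assumptions:

```python
# Sketch only: report which cached component failed to load on a cache miss.
def _log_missing_components(logger, host_kernel_source, device_kernel_source, kernel_params):
    missing = [
        name
        for name, value in (
            ("host_kernel_source", host_kernel_source),
            ("device_kernel_source", device_kernel_source),
            ("kernel_params", kernel_params),
        )
        if value is None
    ]
    if missing:
        logger.warning("Kernel cache miss: failed to load %s", ", ".join(missing))
```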
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- tilelang/autotuner/param.py (3 hunks)
- tilelang/autotuner/tuner.py (4 hunks)
- tilelang/cache/__init__.py (1 hunks)
- tilelang/cache/kernel_cache.py (14 hunks)
- tilelang/jit/__init__.py (9 hunks)
- tilelang/jit/adapter/nvrtc/adapter.py (5 hunks)
- tilelang/jit/execution_backend.py (1 hunks)
- tilelang/jit/kernel.py (16 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tilelang/autotuner/param.py
🧰 Additional context used
🧬 Code graph analysis (6)
tilelang/jit/adapter/nvrtc/adapter.py (3)
- tilelang/jit/adapter/nvrtc/wrapper.py (3): host_func (258-262), host_func (265-267), update_lib_code (498-555)
- tilelang/jit/adapter/wrapper.py (7): host_func (525-535), wrap (134-135), wrap (818-837), wrap (845-859), update_lib_code (444-493), update_lib_code (731-750), update_lib_code (783-785)
- tilelang/jit/adapter/nvrtc/libgen.py (1): NVRTCLibraryGenerator (40-235)

tilelang/cache/kernel_cache.py (8)
- tilelang/jit/kernel.py (7): JITKernel (31-756), out_idx (596-597), export_library (611-628), kernel_source (604-605), get_host_source (436-441), get_kernel_source (423-434), from_database (143-183)
- tilelang/utils/target.py (1): determine_target (62-123)
- tilelang/jit/execution_backend.py (2): resolve_execution_backend (62-100), allowed_backends_for_target (25-55)
- tilelang/jit/adapter/tvm_ffi.py (4): func (195-262), get_host_source (299-303), get_kernel_source (311-316), from_database (267-297)
- tilelang/autotuner/param.py (2): _load_kernel_from_disk (205-273), _save_kernel_to_disk (146-203)
- tilelang/jit/adapter/nvrtc/adapter.py (2): get_kernel_source (172-180), from_database (95-148)
- tilelang/jit/adapter/ctypes/adapter.py (3): get_kernel_source (294-300), libpath (280-282), from_database (113-162)
- tilelang/jit/adapter/base.py (1): get_kernel_source (93-97)

tilelang/jit/kernel.py (7)
- tilelang/jit/adapter/tvm_ffi.py (4): TVMFFIKernelAdapter (23-321), from_database (267-297), get_kernel_source (311-316), get_host_source (299-303)
- tilelang/jit/adapter/torch/metal.py (1): MetalKernelAdapter (14-75)
- tilelang/engine/param.py (1): KernelParam (12-103)
- tilelang/jit/adapter/base.py (2): BaseKernelAdapter (10-100), get_kernel_source (93-97)
- tilelang/jit/adapter/nvrtc/adapter.py (2): from_database (95-148), get_kernel_source (172-180)
- tilelang/jit/adapter/ctypes/adapter.py (3): from_database (113-162), CtypesKernelAdapter (19-300), get_kernel_source (294-300)
- tilelang/jit/adapter/cython/adapter.py (2): from_database (149-208), get_kernel_source (384-391)

tilelang/jit/execution_backend.py (2)
- tilelang/language/ast/ir.py (1): target (1682-1713)
- tilelang/jit/__init__.py (3): jit (293-294), jit (298-310), jit (313-382)

tilelang/jit/__init__.py (2)
- tilelang/jit/execution_backend.py (2): resolve_execution_backend (62-100), allowed_backends_for_target (25-55)
- tilelang/language/ast/ir.py (1): target (1682-1713)

tilelang/autotuner/tuner.py (5)
- tilelang/utils/target.py (1): determine_target (62-123)
- tilelang/jit/__init__.py (3): jit (293-294), jit (298-310), jit (313-382)
- tilelang/jit/execution_backend.py (1): resolve_execution_backend (62-100)
- tilelang/autotuner/param.py (1): CompileArgs (32-75)
- tilelang/jit/kernel.py (1): out_idx (596-597)
🪛 Ruff (0.14.4)
tilelang/jit/adapter/nvrtc/adapter.py
104-104: Unused class method argument: pass_configs
(ARG003)
tilelang/cache/kernel_cache.py
297-297: Do not catch blind exception: Exception
(BLE001)
298-298: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
313-313: Do not catch blind exception: Exception
(BLE001)
314-314: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
416-416: Do not catch blind exception: Exception
(BLE001)
417-417: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
424-424: Do not catch blind exception: Exception
(BLE001)
425-425: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
tilelang/jit/execution_backend.py
51-53: try-except-pass detected, consider logging the exception
(S110)
51-51: Do not catch blind exception: Exception
(BLE001)
90-92: Avoid specifying long messages outside the exception class
(TRY003)
96-98: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (4)
tilelang/jit/adapter/nvrtc/adapter.py (1)
30-46: LGTM: Clean refactoring todevice_kernel_source.The rename from
kernel_global_sourcetodevice_kernel_sourceclarifies intent and aligns with the host/device kernel separation introduced in this PR.Also applies to: 77-77, 80-80
tilelang/cache/kernel_cache.py (3)
16-16: LGTM: Import and constants for TVM FFI support.The new
Executableimport and path constants (DEVICE_KERNEL_PATH,HOST_KERNEL_PATH,EXECUTABLE_PATH) properly support the TVM FFI backend's separation of device and host kernel sources.Also applies to: 22-24
140-157: LGTM: Backend resolution and target normalization.The new logic properly normalizes the target and resolves the execution backend before caching. The conditional logging avoids duplicates when the backend is explicitly specified and matches the resolved value.
259-263: LGTM: Atomic executable write.The new
_safe_write_executablemethod correctly implements atomic writes using a temporary file andos.replace, consistent with the existing_safe_write_filepattern.
```diff
     _lock = threading.Lock()  # For thread safety
     _memory_cache = {}  # In-memory cache dictionary
-    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython"
+    execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi"
```
Critical: Race condition on shared execution_backend class variable.

The class variable `self.execution_backend` is modified by concurrent threads (line 94) and read without synchronization in save/load methods (lines 305, 319, 329, 338, 395). Since the lock only protects cache access but not compilation, two threads compiling kernels with different backends can corrupt the cache.

Example scenario:
- Thread A compiles with backend="ctypes", sets `self.execution_backend = "ctypes"`
- Thread B compiles with backend="tvm_ffi", overwrites `self.execution_backend = "tvm_ffi"`
- Thread A saves its ctypes kernel using tvm_ffi paths → wrong library type saved

Apply this diff to use the local execution_backend parameter instead:
```diff
 def _save_kernel_to_disk(self,
                          key: str,
                          kernel: JITKernel,
+                         execution_backend: str,
                          func: Callable = None,
                          verbose: bool = False):
```

```diff
-        if self.execution_backend == "tvm_ffi":
+        if execution_backend == "tvm_ffi":
             KernelCache._safe_write_file(
                 host_kernel_path, "w",
                 lambda file: file.write(kernel.adapter.get_host_source()))
```

```diff
-        if self.execution_backend == "nvrtc":
+        if execution_backend == "nvrtc":
             kernel_lib_path = KERNEL_CUBIN_PATH
-        elif self.execution_backend == "tvm_ffi":
+        elif execution_backend == "tvm_ffi":
             kernel_lib_path = EXECUTABLE_PATH
         else:
```

```diff
-        if self.execution_backend == "nvrtc":
+        if execution_backend == "nvrtc":
             src_lib_path = kernel.adapter.libpath
             # ... nvrtc handling ...
-        elif self.execution_backend == "tvm_ffi":
+        elif execution_backend == "tvm_ffi":
             executable = kernel.adapter.executable
             # ... tvm_ffi handling ...
```

```diff
-        if self.execution_backend == "nvrtc":
+        if execution_backend == "nvrtc":
             kernel_lib_path = KERNEL_CUBIN_PATH
-        elif self.execution_backend == "tvm_ffi":
+        elif execution_backend == "tvm_ffi":
             kernel_lib_path = EXECUTABLE_PATH
```

Update the call site at line 217:

```diff
-        self._save_kernel_to_disk(key, kernel, func, verbose)
+        self._save_kernel_to_disk(key, kernel, execution_backend, func, verbose)
```

Also remove the assignment at line 94 in `_generate_key`:

```diff
-        self.execution_backend = execution_backend
```

Also applies to: 94-94, 305-312, 319-324, 329-349, 395-400
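A minimal standalone sketch of the hazard and the fix — not the project's actual cache code: shared mutable state races, while a value passed as an argument travels with the call:

```python
class UnsafeCache:
    """Backend stored on the shared object: a concurrent save() can overwrite it."""

    def save(self, backend):
        self.execution_backend = backend            # another thread may overwrite this...
        return f"cache/{self.execution_backend}"    # ...before it is read back here


class SafeCache:
    """Backend passed as an argument: each call sees only its own value."""

    def save(self, backend):
        return f"cache/{backend}"


print(SafeCache().save("ctypes"))  # cache/ctypes, regardless of other threads
```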
```python
        elif self.execution_backend == "tvm_ffi":
            executable = kernel.adapter.executable
            if verbose:
                self.logger.debug(f"Saving kernel executable to file: {executable}")
            KernelCache._safe_write_executable(executable, kernel_lib_path)
```
Fix: Log the file path, not the object.

Line 341 logs the executable object rather than the destination path, which will produce unhelpful output like `<tvm.runtime.Module object>`. Apply this diff:

```diff
         elif self.execution_backend == "tvm_ffi":
             executable = kernel.adapter.executable
             if verbose:
-                self.logger.debug(f"Saving kernel executable to file: {executable}")
+                self.logger.debug(f"Saving kernel executable to file: {kernel_lib_path}")
             KernelCache._safe_write_executable(executable, kernel_lib_path)
```
🤖 Prompt for AI Agents
In tilelang/cache/kernel_cache.py around lines 338 to 342, the debug log
currently interpolates and prints the `executable` object (e.g.
`<tvm.runtime.Module object>`) instead of the destination file path; change the
log to reference the `kernel_lib_path` (or include both `kernel_lib_path` and a
short description) so the debug output shows the file being written to, then
call KernelCache._safe_write_executable as before.
```python
                      host_kernel_source: str,
                      device_kernel_source: str,
```
Remove unused parameters from `from_database`.

The following parameters are accepted but never used:
- `host_kernel_source` (line 100): stored at line 109 but never referenced afterward
- `pass_configs` (line 104): never stored or used, confirmed by static analysis

In contrast, `__init__` actively uses `pass_configs` (line 74: `self.wrapper.assign_pass_configs(pass_configs)`). Since `from_database` loads from a prebuilt library and doesn't instantiate a wrapper, these parameters serve no purpose.

If these are retained for API symmetry or future extension, consider adding a comment explaining why. Otherwise, remove them:

```diff
     @classmethod
     def from_database(cls,
                       params: list[KernelParam],
                       result_idx: list[int],
                       target: str,
                       func_or_mod: tir.PrimFunc | tvm.IRModule,
-                      host_kernel_source: str,
                       device_kernel_source: str,
                       kernel_lib_path: str,
                       verbose: bool = False,
-                      pass_configs: dict[str, Any] | None = None,
                       compile_flags: list[str] | None = None):
```

And remove the corresponding storage at line 109:

```diff
         adapter.params = params
         adapter.result_idx = adapter._legalize_result_idx(result_idx)
-        adapter.host_kernel_source = host_kernel_source
         adapter.device_kernel_source = device_kernel_source
```

Also applies to: 104-104
🤖 Prompt for AI Agents
In tilelang/jit/adapter/nvrtc/adapter.py around lines 100-101, from_database
currently accepts unused parameters host_kernel_source and pass_configs
(host_kernel_source is stored at line 109 but never used; pass_configs is
neither stored nor used). Remove these two parameters from the from_database
signature, delete the assignment that stores host_kernel_source at line 109, and
remove any no-op handling of pass_configs; if API symmetry or future use is
desired instead of removal, add a brief comment next to the signature explaining
they are intentionally unused. Also update any call sites to stop passing those
arguments.
```python
        from tilelang.jit.adapter.nvrtc import is_nvrtc_available  # lazy
        if not is_nvrtc_available and "nvrtc" in allowed:
            allowed = [b for b in allowed if b != "nvrtc"]
```
Call `is_nvrtc_available()` when filtering NVRTC.

`is_nvrtc_available` is a callable; without invoking it we never drop NVRTC from `allowed`, so `resolve_execution_backend` will happily return "nvrtc" even when that backend can't run and later compilation will crash. Please call the helper before negating it.

```diff
-    from tilelang.jit.adapter.nvrtc import is_nvrtc_available  # lazy
-    if not is_nvrtc_available and "nvrtc" in allowed:
+    from tilelang.jit.adapter.nvrtc import is_nvrtc_available  # lazy
+    if not is_nvrtc_available() and "nvrtc" in allowed:
```
+ if not is_nvrtc_available() and "nvrtc" in allowed:🤖 Prompt for AI Agents
In tilelang/jit/execution_backend.py around lines 48 to 50, the code checks
is_nvrtc_available without calling it, so NVRTC is never filtered out; change
the condition to call the helper (i.e., use is_nvrtc_available()) and then
remove "nvrtc" from allowed when it returns False — keep the lazy import but
invoke the callable in the if-statement and filter the allowed list accordingly.
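A quick illustration of the pitfall with a stand-in function: a function object is always truthy, so negating it without calling it never filters anything:

```python
def is_nvrtc_available():
    return False  # stand-in for the real availability check


print(not is_nvrtc_available)    # False: the function object itself is truthy
print(not is_nvrtc_available())  # True: the call returns False, so the filter applies
```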
A few benchmark results:

GEMM — average results: …

MLA — tvm_ffi: 20.63 µs
Hzfengsy
left a comment
Many unexpected changes need to be fixed
- Updated `ExtractFuncInfo` in `rt_mod_cuda.cc` and `rt_mod_hip.cc` to map boolean argument types to int32, ensuring compatibility with device runtime.
- Refactored `BindDLTensor` in `arg_binder.cc` to improve null handling and validation checks for DLTensor parameters, utilizing expression-level guards to prevent dereferencing null pointers.
- Enhanced error checking for buffer shape, strides, and data fields, ensuring robust handling of optional inputs and maintaining consistency across various checks.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/jit/kernel.py (1)
606-607: Fix AttributeError: change `kernel_global_source` to `device_kernel_source`.

The `kernel_source` property at line 607 accesses `adapter.kernel_global_source`, which does not exist on any adapter currently used by this code (ctypes, cython, nvrtc, tvm_ffi). All these adapters have `device_kernel_source` instead. This will raise an `AttributeError` when `artifact` is `None`.

```diff
 @property
 def kernel_source(self) -> str:
-    return self.artifact.kernel_source if self.artifact else self.adapter.kernel_global_source
+    return self.artifact.kernel_source if self.artifact else self.adapter.device_kernel_source
```
♻️ Duplicate comments (2)
tilelang/jit/adapter/nvrtc/adapter.py (1)
94-148: Unused parameters remain in `from_database`.

As noted in the previous review, `host_kernel_source` (line 100) is stored at line 109 but never referenced afterward, and `pass_configs` (line 104) is neither stored nor used. This remains unresolved.

src/transform/arg_binder.cc (1)

280-301: Review overly permissive bool compatibility logic.

The `bit1_ok` check at lines 299-300 accepts any dtype with 1-bit width regardless of type code, which could mask type errors by accepting exotic 1-bit types without proper bool semantics. The explicitly allowed type codes (int8, uint8, kDLBool with bits 8 or 1) should be sufficient. As noted in the previous review, consider removing the overly broad `bit1_ok` check:

```diff
   PrimExpr kdlbool1_ok = (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
-  // Also accept any dtype whose bitwidth=1, regardless of code, to be
-  // defensive.
-  PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
-  cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
+  cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok;
```
🧹 Nitpick comments (2)
src/transform/inject_assumes.cc (1)
67-71: Consider adding observability when skipping assumptions to aid debugging.

The guard correctly prevents injecting assumptions that reference undefined variables, which would create invalid TIR. However, silently skipping with `continue` makes it difficult to debug when assumptions are unexpectedly omitted during the compilation pipeline. Adding a debug log when assumptions are skipped would help developers understand why certain assumptions aren't injected and whether this affects downstream optimizations:

```cpp
if (!tvm::tir::UndefinedVars(simplified).empty()) {
  DLOG(INFO) << "Skipping assumption injection for expression " << e.expr
             << " due to undefined variables in simplified form: " << simplified;
  continue;
}
```

Also verify that this behavior change (silently skipping assumptions with undefined variables) doesn't negatively impact optimization passes or correctness guarantees that depend on these assumptions being consistently injected.
src/transform/make_packed_api.cc (1)
431-448: Add logging for undefined variable fallback to improve observability during dynamic shape debugging.

The fallback behavior binding undefined variables to zero is intentional and well-documented. However, LOG(WARNING) is a standard pattern used throughout src/transform/ for similar fallback scenarios. Adding logging here would help detect and debug unexpected undefined variables when dynamic shapes or injected assumptions cause this code path to execute.

```diff
   if (ffi::Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
       undefined.size() != 0) {
+    LOG(WARNING) << "Function " << name_hint << " has " << undefined.size()
+                 << " undefined variables; binding to zero as fallback";
     for (const Var &v : undefined) {
       DataType t = v->dtype;
       PrimExpr zero = make_zero(t);
       seq_init.push_back(LetStmt(v, zero, nop));
     }
   }
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- src/target/rt_mod_cuda.cc (1 hunks)
- src/target/rt_mod_hip.cc (1 hunks)
- src/transform/arg_binder.cc (9 hunks)
- src/transform/inject_assumes.cc (2 hunks)
- src/transform/make_packed_api.cc (13 hunks)
- tilelang/engine/phase.py (1 hunks)
- tilelang/jit/adapter/nvrtc/adapter.py (5 hunks)
- tilelang/jit/kernel.py (16 hunks)
✅ Files skipped from review due to trivial changes (1)
- tilelang/engine/phase.py
🧰 Additional context used
🧬 Code graph analysis (4)
src/transform/arg_binder.cc (1)
tilelang/language/ast/ir.py (3)
LetStmt(880-908)handle(1467-1497)decl_buffer(1137-1205)
tilelang/jit/adapter/nvrtc/adapter.py (6)
tilelang/jit/adapter/nvrtc/wrapper.py (3)
host_func(258-262)host_func(265-267)update_lib_code(498-555)tilelang/jit/adapter/wrapper.py (7)
host_func(525-535)wrap(134-135)wrap(818-837)wrap(845-859)update_lib_code(444-493)update_lib_code(731-750)update_lib_code(783-785)tilelang/jit/adapter/nvrtc/libgen.py (1)
NVRTCLibraryGenerator(40-235)tilelang/jit/kernel.py (2)
params(602-603)get_kernel_source(423-434)tilelang/jit/adapter/base.py (1)
get_kernel_source(93-97)tilelang/jit/adapter/ctypes/adapter.py (1)
get_kernel_source(294-300)
src/transform/make_packed_api.cc (2)
src/target/codegen_cpp.cc (8)
VisitExpr_(166-179)VisitExpr_(166-167)VisitExpr_(375-414)VisitExpr_(375-376)VisitExpr_(452-455)VisitExpr_(452-453)VisitExpr_(457-460)VisitExpr_(457-458)tilelang/language/tir/op.py (1)
tvm_struct_get(414-436)
tilelang/jit/kernel.py (5)
tilelang/jit/adapter/tvm_ffi.py (4)
TVMFFIKernelAdapter(23-321)from_database(267-297)get_kernel_source(311-316)get_host_source(299-303)tilelang/jit/adapter/base.py (2)
BaseKernelAdapter(10-100)get_kernel_source(93-97)tilelang/jit/adapter/nvrtc/adapter.py (2)
from_database(95-148)get_kernel_source(172-183)tilelang/jit/adapter/ctypes/adapter.py (3)
from_database(113-162)CtypesKernelAdapter(19-300)get_kernel_source(294-300)tilelang/jit/adapter/cython/adapter.py (2)
from_database(149-208)get_kernel_source(384-391)
🪛 Ruff (0.14.4)
tilelang/jit/adapter/nvrtc/adapter.py
104-104: Unused class method argument: pass_configs
(ARG003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build SDist
🔇 Additional comments (20)
src/transform/inject_assumes.cc (1)
7-7: LGTM!The include is necessary for the
UndefinedVarsfunction used in the new guard logic below.tilelang/jit/adapter/nvrtc/adapter.py (2)
30-92: LGTM!The rename from
kernel_global_sourcetodevice_kernel_sourceis consistent with the broader refactoring to separate host and device kernel sources. All usages are correctly updated.
172-183: LGTM!The
kernel_onlyparameter correctly returnsdevice_kernel_sourcewhenTrueandhost_funcwhenFalse, aligning with the refactoring pattern across other adapters.tilelang/jit/kernel.py (5)
18-18: LGTM!The addition of
TVMFFIKernelAdapterand the default backend change to"tvm_ffi"align with the PR objectives. According to the benchmark results, tvm_ffi shows 1.41-2.10× speedup over other backends.Also applies to: 54-110
142-183: LGTM!The signature changes to accept
host_kernel_sourceanddevice_kernel_source(replacingkernel_global_source) are consistent with the adapter refactoring across the codebase.
228-229: LGTM!The logic correctly enables host codegen and device compile for the
tvm_ffibackend, and all adapter instantiations properly usedevice_kernel_source. TheTVMFFIKernelAdapterreceives all required parameters includingrt_mod,host_mod,device_mod, anddevice_kernel_source.Note: Line 307 still uses
kernel_global_sourcefor the Metal adapter, which may be intentional if that adapter hasn't been refactored yet.Also applies to: 243-316
318-384: LGTM!The
_create_adapter_from_databasemethod correctly handles thetvm_ffibackend and consistently passeshost_kernel_sourceanddevice_kernel_sourceto all adapter constructors.
423-443: LGTM!The
get_kernel_sourceandget_host_sourcemethods correctly includetvm_ffiin the backend checks and delegate to the adapter methods. The assertion at line 442 is good defensive programming.src/target/rt_mod_hip.cc (1)
38-42: LGTM!The boolean-to-int32 mapping for device runtime arguments is correctly implemented and well-documented. This aligns with the identical change in
rt_mod_cuda.ccand addresses device runtime limitations with bool types.src/target/rt_mod_cuda.cc (1)
27-31: LGTM!The boolean parameter handling is correctly implemented and consistent with the HIP version in
rt_mod_hip.cc.src/transform/arg_binder.cc (4)
99-122: LGTM!The relaxed dtype compatibility logic in
BindBufferis well-structured and appropriately handles:
- FP8 E4M3 variants binding together for mixed precision support
- FP8 E5M2 variants compatibility
- Bool binding from int8/uint8 with matching lanes
The error message at lines 121-122 is also improved to clearly show both expected and actual dtypes.
192-230: LGTM!The NULL DLTensor* handling is correctly implemented:
- The
is_nullcheck is cached in a Let-bound variable (lines 196-200) to avoid repeated evaluations- All dereferences are guarded with
if_then_elseto return safe defaults when handle is NULL- Runtime assertions use
Or(is_null, ...)to short-circuit validation when the handle is NULLThis enables optional tensor inputs while maintaining safety and clear error messages.
238-255: LGTM!The dtype field validation correctly guards all DLTensor field accesses with
if_then_else(Not(is_null), ...)to avoid dereferencing NULL handles. Default values are returned when the handle is NULL, ensuring safe evaluation.
310-523: LGTM!The NULL-guarded field handling throughout the rest of
BindDLTensoris consistently and correctly implemented:
- All DLTensor field accesses use
if_then_else(Not(is_null), ..., safe_default)to avoid dereferencing NULL- All assertions use
Or(is_null, ...)to short-circuit validation when handle is NULL- The data pointer NULL check (lines 509-514) correctly exempts size-0 arrays and NULL handles
This comprehensive NULL safety enables optional tensor inputs while maintaining validation for non-NULL cases.
src/transform/make_packed_api.cc (6)
52-128: LGTM!The
ReturnRewriterrefactor correctly transitions from the legacy dual-buffer approach to FFI-centric struct packing:
- Constructor simplified to single
ret_varparameter (line 52)ConvertForFFIproperly handles bool by mapping tokTVMFFIBooland casting to Int(64) (lines 89-92)WriteToOutcorrectly packs three fields (type_index,zero_padding,union_value) usingtvm_struct_set(lines 111-125)The implementation aligns with the TVM FFI conventions and handles all relevant primitive types.
136-177: LGTM!The
SubroutineCallRewritercorrectly migrates to FFI-centric types (ffi::Map,ffi::String,ffi::Array,ffi::GetRef,ffi::Optional) without altering the core logic. The changes are consistent and maintain the existing functionality.
247-277: LGTM!The new packed API signature correctly defines the FFI-style parameters:
v_self_handlefor module contextv_packed_argsfor packed argument arrayv_num_packed_argsfor argument countv_resultfor return value storageThe
f_load_arg_valuelambda correctly loads arguments from thekTVMFFIAnyUnionValuefield and handles type conversion viaAPIType.
290-293: LGTM!The NULL pointer check for
v_packed_argsis appropriately guarded bynum_args > 0, avoiding unnecessary assertions when the function takes no arguments.
300-371: LGTM!The argument type checking and loading logic is correctly implemented:
- Type index extraction from
kTVMFFIAnyTypeIndexfield (lines 304-310)- Handle type properly adjusts Tensor offsets by 24 bytes to account for
TVMFFIObjectheader (lines 324-331)- Bool/int/float type compatibility is appropriately relaxed (bool accepts int, int accepts bool, float accepts int/bool with runtime conversion via
Select)- Deferred binding pattern (lines 366-371) correctly avoids use-before-def issues
The implementation aligns with FFI conventions and maintains backward compatibility.
374-420: LGTM!The final function construction is correct:
- Signature with 4 FFI-style parameters (lines 374-377)
- Proper binding order: variables before buffers (lines 386-394)
- Global symbol prefixed with
ffi::symbol::tvm_ffi_symbol_prefix(lines 395-401)ReturnRewritercorrectly invoked withv_result(line 403)- Device context handling preserved (lines 407-420)
```python
        if execution_backend == "tvm_ffi":
            # Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack.
```
Update outdated comment.
The comment mentions "DLPack" but this PR introduces TVM FFI as the execution backend. Update the comment to reflect the current implementation.
Apply this diff:

```diff
-    if execution_backend == "tvm_ffi":
-        # Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack.
+    if execution_backend == "tvm_ffi":
+        # Use TVMFFIKernelAdapter with TVM FFI for efficient kernel execution.
```
🤖 Prompt for AI Agents
In tilelang/jit/kernel.py around lines 241 to 242, the inline comment
incorrectly references "DLPack" while the code path is for the "tvm_ffi"
execution backend; update the comment to describe that this branch uses the TVM
FFI adapter for interoperability (e.g., with frameworks like PyTorch) via TVM's
FFI interface rather than DLPack, and replace the outdated wording with a
concise comment reflecting the current TVM FFI implementation.
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- src/transform/make_packed_api.cc (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/make_packed_api.cc (2)
src/target/codegen_cpp.cc (8)
VisitExpr_(166-179)VisitExpr_(166-167)VisitExpr_(375-414)VisitExpr_(375-376)VisitExpr_(452-455)VisitExpr_(452-453)VisitExpr_(457-460)VisitExpr_(457-458)tilelang/language/tir/op.py (1)
tvm_struct_get(414-436)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build SDist
🔇 Additional comments (13)
src/transform/make_packed_api.cc (13)
23-23: LGTM!The new includes support the FFI migration. The
ffi/extra/module.hheader provides FFI infrastructure, andunordered_setis a standard addition for variable tracking.Also applies to: 36-36
52-52: LGTM!The simplified constructor signature is consistent with the FFI migration, where a single result variable points to the FFI return structure.
89-92: LGTM!The boolean type handling follows the established pattern for other primitive types, correctly mapping to
kTVMFFIBooland converting to Int(64) for FFI representation.
109-128: LGTM!The migration from buffer-based to FFI struct-based return handling is implemented correctly. The method properly initializes all FFI structure fields (type_index, zero_padding, value) in the correct sequence.
136-175: LGTM!The systematic migration to
ffi::namespace types (Map, Optional, String, Array, GetRef) is consistent and maintains the existing logic while aligning with the FFI infrastructure.
181-188: LGTM!The simplified helper functions provide clean, type-safe assertion creation with string messages. The null-pointer check correctly uses the
builtin::isnullptr()intrinsic.
247-250: LGTM!The new variables align with FFI calling conventions.
v_self_handleprovides the self-reference handle, andv_resultis correctly typed as a void pointer for writing the FFI return structure.
290-293: LGTM!The defensive null check for
v_packed_argsprevents potential null pointer dereferences when processing arguments. The conditional check (only whennum_args > 0) is an appropriate optimization.
332-365: Verify the type coercion policy aligns with requirements.The implementation allows flexible type conversions:
- Boolean parameters accept both
kTVMFFIBoolandkTVMFFIInt- Integer parameters accept both
kTVMFFIIntandkTVMFFIBool- Float parameters accept
kTVMFFIFloat,kTVMFFIInt, andkTVMFFIBoolWhile this provides compatibility for dynamic language interop, it reduces type safety and could mask errors where incorrect types are passed. Ensure this flexibility is intentional and documented as part of the FFI calling convention.
374-377: LGTM!The function signature correctly implements the FFI packed calling convention with the standard four parameters: self handle, packed arguments, argument count, and result pointer. The comment clearly documents the signature.
390-401: LGTM!The buffer binding correctly handles DLTensor parameters with type checking, and the function attributes are properly updated for FFI conventions. The symbol prefix ensures correct FFI naming.
447-448: LGTM!The buffer map is correctly cleared since the FFI calling convention uses packed arguments instead of buffer parameters. Setting
ret_typetoInt(32)aligns with the packed function error code convention (returning 0 on success).
324-326: Code is correct; TVMFFIObject size assertion is properly designed.The static_assert at line 325 is the correct approach for cross-platform safety. TVMFFIObject uses a stable, documented C ABI header with a 24-byte intrusive header on common 64-bit platforms, and the assertion will fail at compile-time if the assumption is incorrect on any target platform. No changes needed.
```diff
 auto f_load_arg_value = [&](DataType arg_type, int i) {
   ffi::Array<PrimExpr> call_args{
       v_packed_args, IntImm(DataType::Int(32), i),
-      IntImm(DataType::Int(32), builtin::kTVMValueContent)};
+      IntImm(DataType::Int(32), builtin::kTVMFFIAnyUnionValue)};
   // load 64 bit version
-  DataType api_type = APIType(t);
+  DataType api_type = APIType(arg_type);
   PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
   // cast to the target version.
-  if (api_type != t) {
-    res = Cast(t, res);
+  if (api_type != arg_type) {
+    res = Cast(arg_type, res);
   }
   return res;
 };
```
🧩 Analysis chain
Verify APIType is available.
The FFI argument loading logic is correct. However, ensure that the APIType function (line 269) is properly defined in the included headers.
Run the following script to verify APIType is defined:
🏁 Script executed:
#!/bin/bash
# Description: Verify APIType function is defined in the codebase
# Search for APIType definition
rg -nP '\bAPIType\s*\(' --type=cpp -C3

Length of output: 588
🏁 Script executed:
#!/bin/bash
# Search for APIType definition (not just calls)
rg -nP 'DataType\s+APIType\s*\(' --type=cpp --type=h
echo "---"
# Check includes at top of make_packed_api.cc
head -50 src/transform/make_packed_api.cc | grep -E '#include|namespace'

Length of output: 620
🏁 Script executed:
#!/bin/bash
# Search in arg_binder.h specifically
echo "=== Searching arg_binder.h ==="
rg -nP 'APIType' src/transform/arg_binder.h
echo ""
echo "=== Broader search for APIType definition ==="
# Look for APIType in any context
rg -nP 'APIType' --type=cpp --type=h | grep -E 'inline|DataType|^[^:]*:.*APIType\s*='

Length of output: 157
🏁 Script executed:
#!/bin/bash
# Check entire make_packed_api.cc for APIType definition or macro
echo "=== All APIType occurrences in file ==="
rg -n 'APIType' src/transform/make_packed_api.cc
echo ""
echo "=== Search for #define APIType ==="
rg -n '#define.*APIType' --type=cpp --type=h
echo ""
echo "=== Check full context of make_packed_api.cc around line 269 ==="
sed -n '1,100p' src/transform/make_packed_api.cc | tail -40

Length of output: 1493
🏁 Script executed:
#!/bin/bash
# Search for APIType in all header files (especially tvm includes)
echo "=== Search all headers for APIType ==="
find . -name "*.h" -type f | xargs grep -l 'APIType' 2>/dev/null
echo ""
echo "=== Direct search in builtin.h ==="
rg -n 'APIType' ./include/tvm/tir/builtin.h 2>/dev/null || echo "File not found, searching differently..."
echo ""
echo "=== All APIType matches across entire repo ==="
rg -r 'APIType' --type=cpp --type=h . 2>/dev/null | head -20

Length of output: 6300
🏁 Script executed:
#!/bin/bash
# Get broader context around the lambda function
echo "=== Lines 250-290 of make_packed_api.cc ==="
sed -n '250,290p' src/transform/make_packed_api.cc
echo ""
echo "=== Search for similar type conversion patterns ==="
rg -n 'Cast\(.*DataType' src/transform/make_packed_api.cc | head -10

Length of output: 1996
🏁 Script executed:
#!/bin/bash
# Final confirmation - check if APIType might be in global namespace or as extern
echo "=== Check for extern or forward declaration of APIType ==="
rg -n 'extern.*APIType|forward.*APIType|APIType.*extern' --type=cpp --type=h
echo ""
echo "=== Search in entire repo for any APIType definition pattern ==="
rg -n 'APIType\s*=' --type=cpp --type=h
echo ""
echo "=== Check namespace around make_packed_api.cc ==="
sed -n '1,50p' src/transform/make_packed_api.cc | tail -20

Length of output: 578
APIType function is undefined and will cause a compilation error.
The function APIType(arg_type) called at line 269 is not defined anywhere in the codebase. A comprehensive search found no definition in headers, included files, or via macro. This will prevent the code from compiling.
You need to implement the APIType function or import it from an external library. Based on the usage pattern (converting a target DataType to its FFI/API representation), it likely belongs in arg_binder.h or a related utilities header.
🤖 Prompt for AI Agents
In src/transform/make_packed_api.cc around lines 264 to 276, the call to
APIType(arg_type) is undefined and will cause a compile error; implement APIType
to map a target DataType to its FFI/API representation (e.g., promote
integer/float bitwidths to the 64-bit FFI-backed type and preserve
pointer/handle types), add its declaration to an appropriate header (suggest
arg_binder.h or a new util header) and include that header in
make_packed_api.cc, or alternatively import the correct existing header if the
function already exists elsewhere; ensure the function signature accepts
(DataType) and returns DataType, and handle all DataType kinds used here
(int/uint/float of various widths and any union/handle types) so Caller can
safely cast back to arg_type after Call.
Actionable comments posted: 1
♻️ Duplicate comments (1)
src/transform/arg_binder.cc (1)
286-301: Overly permissive 1-bit dtype check remains from previous review.

The `bit1_ok` check at lines 297-300 accepts any dtype with 1-bit width regardless of type code, which could mask type errors by allowing exotic 1-bit types without proper bool semantics. This concern was raised in the previous review.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- src/transform/arg_binder.cc (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/arg_binder.cc (1)
tilelang/language/ast/ir.py (3)
LetStmt (880-908), handle (1467-1497), decl_buffer (1137-1205)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build SDist
🔇 Additional comments (2)
src/transform/arg_binder.cc (2)
99-122: LGTM—dtype compatibility relaxation is well-structured.

The `dtype_compatible` lambda appropriately relaxes dtype checks to support:
- FP8 E4M3 variants binding together (e4m3, e4m3fn, e4m3fnuz)
- FP8 E5M2 variants (e5m2, e5m2fnuz)
- Bool binding from int8/uint8 with matching lanes

This defensive binding improves interoperability for mixed-precision and cross-framework usage. Error messages clearly state expected vs. actual dtype.

192-200: Efficient NULL guard pattern—caches check in a Let-bound variable.

Binding the NULL check to `is_null_var` prevents redundant `isnullptr(handle)` calls throughout the validation logic. This improves generated code quality.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/transform/make_packed_api.cc (1)
312-331: Update comment: actual offset is 24 bytes, not 16 bytes.

The comment on line 322 incorrectly states the offset "always equals 16 bytes," but the code applies a 24-byte offset (line 325: `static_assert(object_cell_offset == 24)`). Update the comment to match the actual offset value. Suggested fix:

```cpp
// if type_index is Tensor, we need to add the offset of the TVMFFIObject
// header which always equals 24 bytes, this ensures that T.handle always
// shows up as a DLTensor*
```
♻️ Duplicate comments (2)
src/transform/make_packed_api.cc (2)
264-276: Critical: `APIType` function is undefined.

This issue was previously flagged. The `APIType(arg_type)` call at line 269 will cause a compilation error as the function is not defined anywhere in the codebase.

431-433: Reconsider commenting out undefined variable checks.

This issue was previously flagged. Disabling the undefined variable check could mask genuine bugs and reduce type safety. Consider adding a diagnostic warning or making this behavior configurable rather than silently allowing undefined variables.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- src/transform/make_packed_api.cc (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/make_packed_api.cc (4)
src/transform/storage_rewrite.cc (9)
dtype(712-718)dtype(712-712)func_ptr(1873-1910)func_ptr(1873-1873)i(720-860)var(1153-1176)var(1153-1153)buffer(1425-1458)buffer(1426-1427)src/target/codegen_cpp.cc (8)
VisitExpr_(166-179)VisitExpr_(166-167)VisitExpr_(375-414)VisitExpr_(375-376)VisitExpr_(452-455)VisitExpr_(452-453)VisitExpr_(457-460)VisitExpr_(457-458)src/target/codegen_cuda.cc (4)
VisitExpr_(902-1149)VisitExpr_(902-902)VisitExpr_(1151-1172)VisitExpr_(1151-1151)tilelang/language/tir/op.py (1)
tvm_struct_get(414-436)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build SDist
🔇 Additional comments (13)
src/transform/make_packed_api.cc (13)
52-52: LGTM: Simplified constructor signature.The removal of the
ret_tcodeparameter aligns with the FFI-centric approach where type information is now embedded in the FFI structure.
89-92: LGTM: Boolean support added correctly.The boolean type is properly mapped to
kTVMFFIBooland cast toInt(64)for FFI storage, consistent with the handling of other primitive types.
109-128: LGTM: FFI structure packing implementation.The
WriteToOutmethod correctly implements the FFI return value protocol by setting the type index, zero padding, and union value fields usingtvm_struct_setcalls.
136-177: LGTM: Consistent FFI type migration.The
SubroutineCallRewriterclass has been properly updated to use FFI-namespaced types (ffi::Map,ffi::String,ffi::Optional,ffi::Array,ffi::GetRef), maintaining consistency with the overall FFI migration.
181-188: LGTM: Helper functions properly defined.Both
MakeAssertEQandMakeAssertNotNullare correctly implemented with appropriate inline specifiers and proper use of TIR primitives.
247-250: LGTM: FFI signature variables added.The addition of
v_self_handleand the updated type forv_resultcorrectly reflect the new FFI function signature requirements.
290-293: LGTM: Defensive null pointer check added.The null check for
v_packed_argsis appropriate and only performed whennum_args > 0, improving error messages for null pointer issues.
332-365: LGTM: Comprehensive type checking and coercion logic.The argument processing correctly handles boolean, integer, and float types with appropriate type index checks and conversions. The float handling is particularly flexible, allowing implicit conversion from integer or boolean when needed.
374-377: LGTM: Function signature updated for FFI.The new signature correctly reflects the FFI calling convention with self handle, packed arguments array, argument count, and result pointer.
390-394: LGTM: Buffer binding logic preserved.The buffer binding correctly uses
BindDLTensorand declares buffers with the updated FFI-based argument handling.
395-402: LGTM: Global symbol prefixing for FFI registration.The function attributes are correctly updated with the FFI symbol prefix, appropriate calling convention, and target host, aligning with the FFI registration requirements.
403-403: LGTM: ReturnRewriter instantiation updated.The call correctly uses the simplified constructor signature with only the result variable.
435-436: LGTM: Function metadata correctly finalized.The buffer map is properly cleared and the return type is correctly set to
Int(32)for error code returns, consistent with the packed API convention.
```cpp
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_set>
```
🧩 Analysis chain
Verify that <unordered_set> is required.
The include was added but it's not immediately clear from the visible changes whether std::unordered_set is actually used in this file.
Run the following script to confirm usage:
🏁 Script executed:
#!/bin/bash
# Description: Check if unordered_set is used in make_packed_api.cc
rg -n '\bunordered_set\b' src/transform/make_packed_api.cc

Length of output: 27
Remove unused <unordered_set> include at line 36.
The header is included but std::unordered_set is not used anywhere in the file. Removing this unused include will clean up unnecessary dependencies.
🤖 Prompt for AI Agents
In src/transform/make_packed_api.cc at line 36, the include statement for
<unordered_set> is unused; remove the #include <unordered_set> line to eliminate
the unnecessary dependency and rebuild to ensure no missing symbols.
- Improved null handling and validation checks in `BindDLTensor`, ensuring safe dereferencing of pointers.
- Enhanced consistency checks for buffer shape, strides, and data fields, utilizing expression-level guards.
- Updated `MakePackedAPI` to maintain code clarity and consistency in argument handling.
- Minor adjustments in test files to streamline kernel execution and improve readability.
Actionable comments posted: 0
♻️ Duplicate comments (2)
src/transform/make_packed_api.cc (1)
264-276: Re-check that `APIType` is defined and included for FFI loads.

`f_load_arg_value` depends on `APIType(arg_type)` (line 269) to map parameter dtypes to the 64-bit FFI API dtype before calling `tvm_struct_get`. In a previous revision, `APIType` was not defined anywhere and caused a compile-time issue; the call site is still present here. Please verify that:

- `APIType(DataType)` is now implemented (or imported) with the expected signature and behavior, and
- The appropriate header is included so this call compiles.

If it isn't, you'll need to either define `APIType` (likely alongside other API-level dtype utilities) or inline the mapping logic used during loading.

```bash
#!/bin/bash
# Verify that APIType(DataType) is defined and visible to this TU
rg -nP '\bAPIType\s*\(' --type=cpp --type=h -C3
echo "----"
rg -nP 'DataType\s+APIType\s*\(' --type=cpp --type=h -C3
```

src/transform/arg_binder.cc (1)

203-205: Null-handle path still dereferences the DLTensor.

The new NULL-handling still calls `TVMArrayGet`/`BufferLoad` before checking `is_null`, e.g. `PrimExpr v_ndim = TVMArrayGet(...)` (Line 203), `raw_shape_val = BufferLoad(...)` (Lines 336-341), and the stride loaders (Lines 412-475). In TIR these calls execute eagerly; wrapping the result in `if_then_else`/`Or` does not short-circuit, so a NULL handle will segfault exactly as before. Please move every struct/Buffer load into statement-level guards (`IfThenElse(Not(is_null), …)`) or skip the loops entirely when the handle is NULL. This is the same issue coderabbitai[bot] highlighted earlier, but it remains unfixed.

Also applies to: 336-352, 412-475
🧹 Nitpick comments (7)
testing/python/jit/test_tilelang_jit_nvrtc.py (3)
367-367: Consider renaming to reflect the tvm_ffi backend.

The function name `run_nvrtc_dynamic_shape` and its test wrapper `test_nvrtc_dynamic_shape` (line 395) suggest NVRTC backend testing, but the implementation now uses "tvm_ffi". This naming inconsistency may confuse developers reviewing or debugging tests. Consider renaming the function to better reflect its purpose, e.g., `run_tvm_ffi_dynamic_shape`, or add a comment explaining that this test validates dynamic shape support using the new default backend.

368-368: Consider removing or conditionalizing the print statement.

The host source print appears to be diagnostic output for debugging the tvm_ffi backend migration. In production tests, this clutters output and should either be removed or made conditional (e.g., gated by an environment variable or debug flag).

586-589: Clean up commented code or document the transition.

The commented-out `tilelang.testing.main()` suggests a transitional state. If the shift to direct test invocation with explicit cache disabling is intentional and permanent, remove the commented line. If this is temporary for debugging the tvm_ffi migration, add a TODO comment explaining the plan to restore the automatic test runner.

Note: The direct invocation only exercises one test case (`test_nvrtc_dynamic_shape`) rather than all tests in the module, which may reduce test coverage during local development runs.

testing/python/jit/test_tilelang_jit_nullptr.py (1)

93-94: Consider removing or conditionalizing the debug print statement.

The kernel invocation pattern is correct. However, the unconditional `print(kernel.get_host_source())` statement appears to be debug code that will add noise to test output. Consider removing this line or making it conditional:

```diff
-    print(kernel.get_host_source())
```

Or make it conditional on an environment variable or debug flag if source inspection is needed during development (see the sketch below).
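A minimal sketch of the conditional variant; the environment variable name `TILELANG_DEBUG_SOURCE` is hypothetical and used only for illustration:

```python
import os


def maybe_dump_host_source(kernel):
    # Only print the generated host source when explicitly requested.
    if os.environ.get("TILELANG_DEBUG_SOURCE"):
        print(kernel.get_host_source())
```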
src/transform/make_packed_api.cc (3)
23-36: Clarify STL includes:<unordered_set>looks unused whilestd::unordered_mapis used.This file uses
std::unordered_map(line 258) but only includes<unordered_set>. It likely compiles via transitive includes, but for correctness and clarity you probably want:
- Replace
<unordered_set>with<unordered_map>, and- Drop
<unordered_set>entirely unless you actually usestd::unordered_setelsewhere.
52-52: ReturnRewriter FFI packing looks consistent but only supports scalar (non-handle) values.The new
ConvertForFFI/WriteToOutpath correctly mapsbool/int/uint/float/void intoTVMFFIAnyusingkTVMFFI*type indices andtvm_struct_setwrites, and theret(0)convention is preserved. However:
ConvertForFFIcurrently only handlesis_bool/is_int/is_uint/is_float/is_void; anyhandleor other dtype passed totir.retwill triggerLOG(FATAL).- If
tir.retis guaranteed to be used only for scalar return values in this pass, this is fine; otherwise, you may want to extendConvertForFFIto cover handle/object returns as well.Also applies to: 84-107, 109-128
300-371: Type-index checks and relaxed dtype compatibility are reasonable; document Tensor handle offset.The per-argument handling:
- Loads
*.type_indexviatvm_struct_get,- For handles, accepts
None/OpaquePtr/DLTensorPtr/static objects and special-caseskTVMFFITensorby addingsizeof(TVMFFIObject)byte offset,- Allows
Bool↔IntandInt↔Floatvia controlledTypeIndexchecks and casts,matches the stated goal of more permissive, but still checked, binding semantics.
Two small suggestions:
- The comment saying the DLTensor header “always equals 16 bytes” conflicts with
static_assert(sizeof(TVMFFIObject) == 24);— consider updating the comment to match the current layout.- It may be worth briefly documenting somewhere (even as a comment) that the Tensor case requires this
TVMFFIObjectoffset so future changes to the FFI object layout don’t silently break handle bindings.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/transform/arg_binder.cc (6 hunks)
- src/transform/make_packed_api.cc (13 hunks)
- testing/python/jit/test_tilelang_jit_nullptr.py (2 hunks)
- testing/python/jit/test_tilelang_jit_nvrtc.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/jit/test_tilelang_jit_nvrtc.pytesting/python/jit/test_tilelang_jit_nullptr.py
🧬 Code graph analysis (4)
testing/python/jit/test_tilelang_jit_nvrtc.py (3)
tilelang/jit/kernel.py (1)
get_host_source(436-443)tilelang/env.py (1)
disable_cache(275-276)tilelang/language/symbolics.py (1)
dynamic(10-21)
testing/python/jit/test_tilelang_jit_nullptr.py (2)
tilelang/env.py (1)
disable_cache(275-276)tilelang/utils/tensor.py (1)
map_torch_type(19-36)
src/transform/arg_binder.cc (1)
tilelang/language/ast/ir.py (3)
LetStmt(880-908)handle(1467-1497)decl_buffer(1137-1205)
src/transform/make_packed_api.cc (2)
src/target/codegen_cpp.cc (8)
VisitExpr_(166-179)VisitExpr_(166-167)VisitExpr_(375-414)VisitExpr_(375-376)VisitExpr_(452-455)VisitExpr_(452-453)VisitExpr_(457-460)VisitExpr_(457-458)tilelang/language/tir/op.py (1)
tvm_struct_get(414-436)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build SDist
🔇 Additional comments (6)
testing/python/jit/test_tilelang_jit_nullptr.py (3)
86-87: LGTM!The addition of
disable_cache()and the kernel object pattern align with the PR's transition to TVM FFI-based execution and ensure deterministic test behavior.
101-109: LGTM!The kernel invocation pattern is consistently applied across all test cases. The NULL pointer handling (passing
Nonewhenwith_bias=False) is correct for bothptr_null_testandtensor_null_testvariants.
112-118: The review comment is incorrect—test_nullptr()is not dead code.Pytest auto-discovers test functions matching the
test_*naming convention regardless of the__main__block. The function will be executed when running tests via pytest (standard test discovery), while the__main__block provides an alternative for direct script execution. This dual pattern is common and intentional in test files.However, there is a legitimate concern worth addressing: the file diverges from the project's established pattern. Other test files (e.g.,
test_tilelang_cache_matmul.py,test_metal_codegen.py) invoketilelang.testing.main()in their__main__blocks. The switch to directrun_test()invocation may bypass framework-level setup or test qualification decorators (like@tilelang.testing.requires_cuda). Verify whether this intentional deviation from the project's test invocation pattern is acceptable for this file.Likely an incorrect or invalid review comment.
src/transform/make_packed_api.cc (3)
136-138: Switch toffi::Optional/ffi::Map/ffi::Stringin SubroutineCallRewriter looks sound.Using
ffi::Optional<Stmt>,ffi::Map<GlobalVar, ffi::String>,ffi::GetRef, andffi::Array<PrimExpr>here is consistent with the rest of the FFI-centric changes, andApplyis still consumed inMakePackedAPI()the same way. I don’t see functional regressions in this rewriter.Also applies to: 150-152, 157-160, 175-177
181-188: Assert helpers and improved diagnostics look good.
MakeAssertEQandMakeAssertNotNullare straightforward and their usage for:
num_argsequality, and- guarding against
args == nullptrwhennum_args > 0provides clearer error messages without changing semantics.
Also applies to: 283-293
247-251: Overall FFI packed-API wiring looks correct; a few assumptions to confirm.The reworked packed API:
- Uses
(self_handle, packed_args, num_args, result)as params and setskCallingConv = kCPackedFunc,kTarget = target_host, andkGlobalSymbolwithffi::symbol::tvm_ffi_symbol_prefix + global_symbol.- Binds scalar args via
ArgBinder::Bind, buffers viaBindDLTensor, then builds the body withReturnRewriter(v_result), device attrs (usingffi::Any node = ffi::String("default")), atvm_call_packed(tvm_set_device, ...)prologue when needed, and a finalret(0)error code.- Merges init/binding/assert/buffer decl nests, resets
buffer_mapto empty, and setsret_typetoPrimType(Int(32)), matching the usual packed-func convention.- Checks
UndefinedVars(now usingffi::Array<Var>) to ensure all used vars are inparams.This all lines up with the new TVM FFI calling convention and looks structurally sound.
The main assumptions to keep in mind/verify are:
- External callers indeed pass a
TVMFFIAny*forresultthat matches the layout used inReturnRewriter::WriteToOut.- No PrimFunc relying on this pass expects a non‑
int32C return type anymore (sinceret_typeis fixed toInt(32)and the real value flows viaresult).Also applies to: 374-378, 390-402, 403-427, 431-439
Do those numbers include kernel launch overhead?
Yes, it includes all CPU overheads with kernel launches. The result is tested with the method:

```python
begin_time = time.time()
kernel(...)
# w/o sync
end_time = time.time()
```
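A hedged sketch of that measurement pattern, assuming a CUDA device and an asynchronous kernel launch; with the trailing synchronize commented out it captures mostly CPU-side launch overhead rather than device execution time:

```python
import time

import torch


def measure_launch(kernel, args, iters=100):
    """Average wall-clock time per call, in microseconds."""
    torch.cuda.synchronize()  # drain pending work so timing starts clean
    begin = time.time()
    for _ in range(iters):
        kernel(*args)
    # torch.cuda.synchronize()  # uncomment to also include device execution time
    end = time.time()
    return (end - begin) / iters * 1e6
```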
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/quickstart.py (1)
67-74: Fix critical bug: M=0 creates empty tensors that don't test the kernel. Setting `M = 0` creates empty tensors with zero rows:

- `a` has shape `(0, 1024)`
- `c` has shape `(0, 1024)`

This means no actual matrix multiplication occurs, and the correctness validation at line 81 doesn't test anything meaningful. For a quickstart example demonstrating dynamic shapes, you should still use a valid test size.

Apply this diff to fix:

```diff
-M = 0
+M = 1024
```

Alternatively, if demonstrating multiple dynamic sizes is desired:

```diff
-M = 0
+# Test with a valid dynamic size
+M = 1024

 # Create random input tensors on the GPU
 a = torch.randn(M, K, device="cuda", dtype=torch.float16)
```
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_alloc.py (1)
153-153: Apply consistent pattern to all `get_kernel_source()` calls for clarity. Verification confirms line 153 now explicitly passes `kernel_only=True`, while lines 36, 76, and 115 rely on the default parameter (which is also `True`). All four calls behave identically. For consistency and clarity, apply the explicit parameter pattern to the other similar test functions:

```diff
 # Line 36
-    code = kernel.get_kernel_source()
+    code = kernel.get_kernel_source(kernel_only=True)

 # Line 76
-    code = kernel.get_kernel_source()
+    code = kernel.get_kernel_source(kernel_only=True)

 # Line 115
-    code = kernel.get_kernel_source()
+    code = kernel.get_kernel_source(kernel_only=True)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- examples/blocksparse_gemm/example_blocksparse_gemm.py (0 hunks)
- examples/gdn/test_example_gdn_compilation.py (1 hunks)
- examples/quickstart.py (3 hunks)
- testing/python/debug/test_tilelang_debug_print.py (1 hunks)
- testing/python/jit/test_tilelang_jit_nullptr.py (1 hunks)
- testing/python/language/test_tilelang_language_alloc.py (1 hunks)
💤 Files with no reviewable changes (1)
- examples/blocksparse_gemm/example_blocksparse_gemm.py
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/gdn/test_example_gdn_compilation.py
- testing/python/debug/test_tilelang_debug_print.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
- testing/python/jit/test_tilelang_jit_nullptr.py
- examples/quickstart.py
🧬 Code graph analysis (3)
testing/python/jit/test_tilelang_jit_nullptr.py (1)
tilelang/utils/tensor.py (1)
map_torch_type(19-36)
examples/quickstart.py (1)
tilelang/jit/__init__.py (3)
jit (293-294), jit (298-310), jit (313-382)
testing/python/language/test_tilelang_language_alloc.py (5)
tilelang/jit/adapter/nvrtc/adapter.py (1)
get_kernel_source (172-183)
tilelang/jit/kernel.py (1)
get_kernel_source (423-434)
tilelang/jit/adapter/ctypes/adapter.py (1)
get_kernel_source (294-300)
tilelang/jit/adapter/tvm_ffi.py (1)
get_kernel_source (311-316)
tilelang/jit/adapter/cython/adapter.py (1)
get_kernel_source (384-391)
🔇 Additional comments (3)
testing/python/jit/test_tilelang_jit_nullptr.py (1)
86-107: LGTM! Kernel object pattern aligns with FFI backend changes. The variable naming changes (using `kernel` consistently) align with the broader PR pattern of using kernel object handles returned by the JIT. The test logic correctly validates null pointer handling for both pointer-based and tensor-based APIs by:
- Testing the None/null bias case first
- Testing with an actual bias tensor
- Verifying correctness against reference computations in both scenarios
This test appropriately validates the null pointer handling improvements mentioned in the PR objectives.
examples/quickstart.py (2)
54-54: Good: Dynamic shape support demonstrated. The change to `M = T.dynamic("m")` correctly demonstrates dynamic shape functionality, which aligns with the PR's backend improvements.

8-11: I need to examine the quickstart file more closely to check if the original review comment contained other concerns beyond the pass_configs question. Pass configs are consistent with the codebase pattern; no fixes needed for the decorator.

Based on codebase analysis, both `TL_DISABLE_TMA_LOWER` and `TL_DISABLE_WARP_SPECIALIZED` are used extensively across 50+ test files, examples, and maintenance scripts. Multiple files contain explicit comments indicating these are known workarounds:

- testing/python/dynamic/test_tilelang_dynamic_symbolic.py:423: "workaround for hopper tma lower pass"
- examples/flash_decoding/example_gqa_decode.py:49: "TODO(lei): fix warp specialized and tma lower pass"

The quickstart's use of these pass configs aligns with the established pattern throughout the codebase, making it appropriate for a simplified demonstration. The approach is consistent, not unique to this example.
Just a quick test against the current version gives: looks like tvm-ffi has a similar issue when dealing with raw data pointers from torch on Metal; there is an extra sync op that is not exposed publicly.
Actionable comments posted: 0
🧹 Nitpick comments (3)
examples/quickstart.py (1)
60-60: LGTM! Good addition for demonstrating kernel source generation. Printing the kernel source after compilation helps users verify the FFI backend is working correctly, which aligns well with this PR's migration to `tvm_ffi` as the default execution backend.

Optional: Consider removing redundant commented code. Lines 80-82 contain commented-out code that would also print the kernel source, which is now redundant with this addition. You may want to remove those lines to avoid confusion:

```diff
-# 4. Retrieve and inspect the generated CUDA source (optional)
-# cuda_source = jit_kernel.get_kernel_source()
-# print("Generated CUDA kernel:\n", cuda_source)
```

testing/python/jit/test_tilelang_jit_tvm_ffi.py (2)
407-412: Add exception handling for device property access. The function should guard against exceptions when accessing device properties, particularly if CUDA is available but no devices are present.

Apply this diff to add exception handling:

```diff
 def check_hopper():
-    if not torch.cuda.is_available():
+    try:
+        if not torch.cuda.is_available():
+            return False
+        props = torch.cuda.get_device_properties(0)
+        compute_capability = props.major, props.minor
+        return compute_capability == (9, 0)
+    except (RuntimeError, AssertionError):
         return False
-    props = torch.cuda.get_device_properties(0)
-    compute_capability = props.major, props.minor
-    return compute_capability == (9, 0)
```
532-581: Consider removing debug print statements. The test correctly validates L2 persistent cache annotation, but includes print statements at lines 569 and 581 that may clutter CI output.

Consider removing or gating these print statements:

```diff
     # Compile the kernel
     kernel = elementwise_add_with_l2_cache(M, N)
-    print(kernel.get_host_source())

     # Create test tensors
```

```diff
     # Verify correctness
     ref_c = a + b
     tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5)
-
-    print("L2 persistent map test passed!")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/quickstart.py (1 hunks)
- testing/python/jit/test_tilelang_jit_tvm_ffi.py (8 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/jit/test_tilelang_jit_tvm_ffi.py
🧬 Code graph analysis (2)
examples/quickstart.py (3)
tilelang/jit/adapter/tvm_ffi.py (1)
get_kernel_source (311-316)
tilelang/jit/kernel.py (1)
get_kernel_source (423-434)
tilelang/jit/adapter/base.py (1)
get_kernel_source (93-97)
testing/python/jit/test_tilelang_jit_tvm_ffi.py (7)
tilelang/jit/kernel.py (2)
out_idx (598-599), get_profiler (405-421)
tilelang/jit/adapter/tvm_ffi.py (2)
func (195-262), get_host_source (299-303)
tilelang/language/symbolics.py (1)
dynamic (10-21)
tilelang/language/annotations.py (1)
annotate_l2_hit_ratio (48-54)
tilelang/layout/swizzle.py (1)
make_swizzled_layout (65-76)
tilelang/language/copy.py (2)
c2d_im2col (105-135), copy (14-102)
tilelang/utils/tensor.py (1)
torch_assert_close (202-294)
🔇 Additional comments (7)
testing/python/jit/test_tilelang_jit_tvm_ffi.py (7)
93-93: LGTM: Backend migration to tvm_ffi is consistent. The migration from ctypes to the `tvm_ffi` execution backend is correctly applied across all test cases. The PR benchmarks show tvm_ffi delivers 1.2–2.1× speedups over ctypes, confirming this is a performance improvement.

Also applies to: 194-194, 265-265, 314-314, 367-367, 488-488, 539-539

236-282: LGTM: Function and variable renames reflect tvm_ffi backend. The renaming from `ctypes` to `tvm_ffi` improves code clarity and aligns with the backend migration.
285-335: LGTM: Multi-stream test provides good concurrency coverage. The test correctly validates tvm_ffi backend behavior with multiple CUDA streams, an important scenario for concurrent kernel execution.

338-404: LGTM: Dynamic shape testing is thorough. The function correctly handles symbolic dimensions and tests multiple dynamic shape scenarios (1D, 2D, 3D dynamic), providing strong coverage for TVM FFI's dynamic shape support.

415-467: LGTM: Convolution im2col implementation is correct. The function correctly implements im2col-based convolution with proper:

- Output dimension calculations
- TMA descriptor usage via `T.c2d_im2col`
- Swizzled layout annotations for shared memory optimization

470-529: LGTM: TMA descriptor test is well-structured. The test correctly validates im2col TMA descriptor functionality with:

- Appropriate gating for Hopper GPU requirements
- Correct layout transformations (NHWC ↔ NCHW)
- Proper reference implementation using `torch.conv2d`

585-585: LGTM: Test runner migration improves flexibility. Using `tilelang.testing.main()` enables proper test discovery and allows running the full test suite, which is better than hardcoding a specific test function.
Add CUDA stream access policy window helpers and integrate with L2 persistent cache management

- Introduced functions to set and reset the CUDA stream access policy window, allowing for better control over L2 cache usage.
- Updated runtime files to include new FFI packed functions for managing stream attributes.
- Modified lower_hopper_intrin to incorporate prologue and epilogue statements for L2 cache setup and teardown.
- Enhanced tests to verify the inclusion of new FFI calls in the generated kernel source.
Actionable comments posted: 2
♻️ Duplicate comments (1)
src/transform/lower_hopper_intrin.cc (1)
35-35: Critical allocation size error: Change 16 elements to 8 elements for TVMFFIAny migration.

This issue was previously identified: the migration from TVMValue (8 bytes) to TVMFFIAny (16 bytes) failed to account for the size difference. The TensorMap descriptor requires 128 bytes (1024 bits).

Current allocation:

- 16 elements × 16 bytes = 256 bytes ❌

Correct allocation:

- 8 elements × 16 bytes = 128 bytes ✓

Apply this fix:

```diff
- {StringImm("tvm_ffi_any"), 16});
+ {StringImm("tvm_ffi_any"), 8});
```

This double over-allocation will cause memory corruption and crashes during TensorMap descriptor initialization on Hopper GPUs.
🧹 Nitpick comments (2)
testing/python/jit/test_tilelang_jit_tvm_ffi.py (1)
571-572: Consider breaking long assertion messages. The assertion messages on lines 571-572 exceed typical line length limits and reduce readability.

Consider formatting as:

```diff
-    assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source"
-    assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source"
+    assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, (
+        "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source"
+    )
+    assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, (
+        "Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source"
+    )
```

src/runtime/runtime.cc (1)

16-20: Remove `#if 1` guard or clarify intent. The `#if 1` preprocessor directive on line 16 is unconditional and typically indicates temporary debugging code or a feature toggle. If the thread-local storage for L2 cache restoration is intended to be permanent, remove the guard. Otherwise, add a comment explaining the condition.

```diff
-#if 1
-// Thread-local storage for restoring the L2 persisting cache limit
 static thread_local size_t __tl_prev_persisting_l2_cache_size = 0;
 static thread_local bool __tl_prev_persisting_l2_cache_saved = false;
-#endif
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/runtime/runtime.cc (4 hunks)
- src/runtime/runtime.h (1 hunks)
- src/transform/lower_hopper_intrin.cc (2 hunks)
- testing/python/jit/test_tilelang_jit_tvm_ffi.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
src/transform/lower_hopper_intrin.cc
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
- src/runtime/runtime.cc
- testing/python/jit/test_tilelang_jit_tvm_ffi.py
🧬 Code graph analysis (2)
src/transform/lower_hopper_intrin.cc (2)
src/transform/lower_tile_op.cc (2)
f (197-224), f (197-197)
src/runtime/runtime.cc (4)
args (47-77), args (47-47), args (131-166), args (131-131)
testing/python/jit/test_tilelang_jit_tvm_ffi.py (4)
src/runtime/runtime.h (1)
tvm (10-26)
tilelang/utils/tensor.py (2)
map_torch_type (19-36), torch_assert_close (202-294)
tilelang/jit/adapter/tvm_ffi.py (4)
prim_func (319-321), get_kernel_source (311-316), func (195-262), get_host_source (299-303)
tilelang/language/annotations.py (2)
annotate_layout (26-37), annotate_l2_hit_ratio (48-54)
🔇 Additional comments (4)
testing/python/jit/test_tilelang_jit_tvm_ffi.py (1)
88-91: Callback registration persists across tests. The global function `tilelang_callback_cuda_postproc` is registered with `override=True` but never restored. This is consistent with the expected persistent behavior noted in similar test files.

Based on learnings.
src/transform/lower_hopper_intrin.cc (1)
56-96: L2 persistent cache window logic is well-structured. The new logic to materialize TVM FFI calls for L2 cache access policy window management is clear and correctly handles:
- Building a name-to-buffer lookup from the function's buffer map
- Extracting base pointers and arguments from the l2_persistent_map attribute
- Properly ordering prologue (set policy) and epilogue (reset policy) calls
src/runtime/runtime.cc (2)
216-346: Comprehensive L2 cache access policy window implementation. The new FFI functions for CUDA stream access policy window management are well-implemented with:
- Thorough argument validation
- Device capability querying and clamping
- Proper state preservation via thread-local storage
- Comprehensive error handling with informative messages
- Correct CUDA API usage for stream attributes and persisting L2 cache
The thread-local storage pattern ensures per-thread restoration state, which is appropriate for multi-threaded environments where each thread may manage different streams.
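For context, a minimal CUDA runtime sketch of the access-policy-window mechanism these FFI helpers wrap (the runtime-API analogue of the driver calls; the window size and hit ratio below are illustrative, not the values used by the PR):

```cpp
#include <cuda_runtime.h>

// Pin num_bytes of ptr as persisting in L2 for work on stream, then restore
// the default policy and flush the persisting lines.
void set_and_reset_l2_window(cudaStream_t stream, void* ptr, size_t num_bytes) {
  // Reserve a portion of L2 for persisting accesses (clamped by the device limit).
  cudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, num_bytes);

  cudaStreamAttrValue attr = {};
  attr.accessPolicyWindow.base_ptr = ptr;
  attr.accessPolicyWindow.num_bytes = num_bytes;
  attr.accessPolicyWindow.hitRatio = 0.6f;  // fraction of the window treated as persisting
  attr.accessPolicyWindow.hitProp = cudaAccessPropertyPersisting;
  attr.accessPolicyWindow.missProp = cudaAccessPropertyStreaming;
  cudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &attr);

  // ... launch kernels on stream ...

  // Teardown: clear the window and evict persisting lines from L2.
  attr.accessPolicyWindow.num_bytes = 0;
  cudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &attr);
  cudaCtxResetPersistingL2Cache();
}
```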
101-114: FFI registration updated to use canonical names. The migration from string-based registration (`"tvm_tensormap_create_tiled"`) to canonical constant-based registration (`tl::tvm_tensormap_create_tiled`) improves type safety and maintainability by centralizing the naming in `runtime.h`.
```cpp
// CUDA stream access policy window helpers
constexpr const char *tvm_cuda_stream_set_access_policy_window =
    "__tvm_cuda_stream_set_access_policy_window";
constexpr const char *tvm_cuda_stream_reset_access_policy_window =
    "__tvm_cuda_stream_reset_access_policy_window";
```
🧩 Analysis chain
Verify CUDA version requirements for stream access policy window APIs.
The new constants tvm_cuda_stream_set_access_policy_window and tvm_cuda_stream_reset_access_policy_window are declared outside the #if (CUDA_MAJOR_VERSION >= 12) guard (lines 13-18), unlike the TMA descriptor constants above them. Ensure this is intentional and that the CUDA stream access policy window APIs (cuStreamSetAttribute with CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW, cuCtxSetLimit, cuCtxResetPersistingL2Cache) are available in CUDA versions prior to 12.
🌐 Web query:
CUDA stream access policy window cuStreamSetAttribute minimum CUDA version
💡 Result:
The stream "access policy window" (cudaStreamAttributeAccessPolicyWindow / accessPolicyWindow used with cudaStreamSetAttribute / cuLaunchAttribute) was introduced in CUDA 11.0 (documented in the CUDA 11 features announcement and in the CUDA Programming Guide). [1][2]
Sources:
- NVIDIA blog "CUDA 11 Features Revealed". [1]
- NVIDIA CUDA Programming Guide (examples showing cudaStreamSetAttribute accessPolicyWindow; present in CUDA 11.x documentation). [2]
Add CUDA version guard for stream access policy window constants.
The stream access policy window APIs were introduced in CUDA 11.0, but the constants tvm_cuda_stream_set_access_policy_window and tvm_cuda_stream_reset_access_policy_window (lines 20–24) lack any version guard. They should be wrapped with #if (CUDA_MAJOR_VERSION >= 11) to prevent use in earlier CUDA versions where these APIs are unavailable.
🤖 Prompt for AI Agents
In src/runtime/runtime.h around lines 20 to 24, the CUDA stream access policy
window constant definitions are unguarded and may be referenced on CUDA versions
older than 11; wrap the two constants tvm_cuda_stream_set_access_policy_window
and tvm_cuda_stream_reset_access_policy_window inside a compile-time guard such
as #if (CUDA_MAJOR_VERSION >= 11) ... #endif so they are only defined when
building against CUDA 11.0+; ensure the guard exactly encloses the constexpr
declarations and retains existing formatting and indentation.
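A sketch of the requested guard, applied to the constants shown in the snippet above:

```cpp
#if (CUDA_MAJOR_VERSION >= 11)
// CUDA stream access policy window helpers (the underlying CUDA APIs were
// introduced in CUDA 11.0)
constexpr const char *tvm_cuda_stream_set_access_policy_window =
    "__tvm_cuda_stream_set_access_policy_window";
constexpr const char *tvm_cuda_stream_reset_access_policy_window =
    "__tvm_cuda_stream_reset_access_policy_window";
#endif
```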
```python
def matmu_jit_kernel(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```
🛠️ Refactor suggestion | 🟠 Major
Code duplication: matmu_jit_kernel is identical to matmul.
The function matmu_jit_kernel (lines 117-160) is a duplicate of the matmul function (lines 9-52) with only a typo in the name. This violates DRY principles and creates maintenance burden.
Apply this diff to remove the duplicate:
-def matmu_jit_kernel(
- M,
- N,
- K,
- block_M,
- block_N,
- block_K,
- trans_A,
- trans_B,
- in_dtype,
- out_dtype,
- accum_dtype,
- num_stages,
- threads,
-):
- A_shape = (K, M) if trans_A else (M, K)
- B_shape = (N, K) if trans_B else (K, N)
- A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
- B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
-
- @T.prim_func
- def main(
- A: T.Tensor(A_shape, in_dtype),
- B: T.Tensor(B_shape, in_dtype),
- C: T.Tensor((M, N), out_dtype),
- ):
- with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
- A_shared = T.alloc_shared(A_shared_shape, in_dtype)
- B_shared = T.alloc_shared(B_shared_shape, in_dtype)
- C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
- T.clear(C_local)
- for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
- if trans_A:
- T.copy(A[k * block_K, by * block_M], A_shared)
- else:
- T.copy(A[by * block_M, k * block_K], A_shared)
- if trans_B:
- T.copy(B[bx * block_N, k * block_K], B_shared)
- else:
- T.copy(B[k * block_K, bx * block_N], B_shared)
- T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
- T.copy(C_local, C[by * block_M, bx * block_N])
-
- return main
-
-
def run_gemm_jit_kernel(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
- program = matmu_jit_kernel(
+ program = matmul(
M,
N,
K,Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
testing/python/jit/test_tilelang_jit_tvm_ffi.py lines 117-160: the function
matmu_jit_kernel is a duplicate (typo) of the existing matmul function; remove
the duplicate definition and update any local references to call matmul instead
(or add a one-line alias like matmu_jit_kernel = matmul if backward-compatible
name must be preserved), ensuring there are no remaining imports/usages of the
removed function.
- Added `codegen_c_host.cc` to the list of source files in CMakeLists.txt for improved code generation support. - Updated the function call in `lower.py` to use `target.build.tilelang_c` for C target host code generation, enhancing compatibility. - Marked the TVM subproject as dirty to indicate local modifications.
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/transform/arg_binder.cc (1)
317-337: Reconsider accepting arbitrary 1-bit dtypes for bool compatibility. The bool branch currently allows any 1-bit dtype, regardless of type code:

```cpp
PrimExpr kdlbool1_ok =
    (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
// Also accept any dtype whose bitwidth=1, regardless of code, to be
// defensive.
PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
```

This is quite permissive and could mask type issues if a future/custom 1-bit type is introduced that doesn't have bool semantics. Restricting the accepted set to the explicitly listed codes (int8/uint8/kDLBool) keeps the behavior predictable while still handling common frontend quirks.

You can drop the `bit1_ok` term:

```diff
-  PrimExpr kdlbool1_ok =
-      (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
-  // Also accept any dtype whose bitwidth=1, regardless of code, to be
-  // defensive.
-  PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
-  cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
+  PrimExpr kdlbool1_ok =
+      (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
+  cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok;
```
🧹 Nitpick comments (4)
src/target/codegen_c_host.cc (1)
350-453: AssertStmt formatting is helpful; consider dtype-aware value printing. The enriched assert handling (collecting all `EQ` subexpressions, building a single message, escaping format strings, and using `TVMFFIErrorSetRaisedFromCStr`) is a solid usability improvement.

One thing to consider: the printed values always go through `(long long)` and `%lld`, regardless of whether the operands are signed, unsigned, or floating-point. For non-integer or wider types this can produce misleading values or implementation-defined casts.

If this becomes an issue in practice, a follow-up could:

- Inspect `eq->a.dtype()` / `eq->b.dtype()` and choose format specifiers (`%lld`, `%llu`, `%f`, etc.) accordingly, or
- Restrict the "got/expected" detail to clearly integral dtypes and fall back to the base message otherwise.
Not blocking, but worth keeping in mind if you start asserting on non‑integer expressions.
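A small sketch of the dtype-aware direction suggested above; the helper is hypothetical and not part of the existing codegen API:

```cpp
#include <dlpack/dlpack.h>

// Pick a printf specifier matching the operand's dtype instead of
// unconditionally casting through (long long) / "%lld".
inline const char *AssertValueFormat(DLDataType dtype) {
  if (dtype.code == kDLFloat || dtype.code == kDLBfloat) return "%f";
  if (dtype.code == kDLUInt) return "%llu";
  return "%lld";  // signed integral fallback
}
```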
src/target/codegen_c_host.h (1)
55-69: Tighten header API: unused `AddFunctionsOrdered` + FFI dependencies. Two small points about the header:

- `AddFunctionsOrdered` is declared but doesn't have a definition in `codegen_c_host.cc`. It's harmless today because nothing calls it, but it will become a linker landmine if used later. Either implement it or drop it from the public interface until you need it.
- The class interface uses `tvm::ffi::String` and `tvm::ffi::Array` in method signatures and the `function_names_` member, but the header relies on transitive includes for their declarations. To make this header self-contained (and friendlier to external users), consider explicitly including the relevant FFI forward/definition headers (or adding forward declarations if TVM provides them).

Both are non-blocking but would make the API surface safer and clearer.
Also applies to: 91-119
src/transform/arg_binder.h (1)
157-159: Clarify BindNullable API contract and naming. The new `BindNullable` API is fine functionally, but it's undocumented in the header and uses `with_lets` while `Bind` uses `with_let`. Adding a short comment explaining `nullable_guard` semantics (e.g., "guard expression that disables asserts/bindings when true") and aligning the naming with `Bind`/`Bind_` would make the public interface clearer.

src/transform/arg_binder.cc (1)
136-159: Tighten FP8 dtype compatibility by also checking lane count. The `dtype_compatible` helper relaxes FP8 compatibility:

```cpp
if (expected.is_float8_e4m3()) {
  return provided.is_float8_e4m3() || provided.is_float8_e4m3fn() ||
         provided.is_float8_e4m3fnuz();
}
if (expected.is_float8_e5m2()) {
  return provided.is_float8_e5m2() || provided.is_float8_e5m2fnuz();
}
```

Unlike the bool path, these branches don't enforce `expected.lanes() == provided.lanes()`, so vectorized FP8 buffers with mismatched lanes could bind successfully. To keep the relaxation limited to "closely related types" without silently accepting lane mismatches, consider requiring equal lanes here as well.

For example:

```diff
-  if (expected.is_float8_e4m3()) {
-    return provided.is_float8_e4m3() || provided.is_float8_e4m3fn() ||
-           provided.is_float8_e4m3fnuz();
-  }
+  if (expected.is_float8_e4m3()) {
+    if (expected.lanes() != provided.lanes())
+      return false;
+    return provided.is_float8_e4m3() || provided.is_float8_e4m3fn() ||
+           provided.is_float8_e4m3fnuz();
+  }
@@
-  if (expected.is_float8_e5m2()) {
-    return provided.is_float8_e5m2() || provided.is_float8_e5m2fnuz();
-  }
+  if (expected.is_float8_e5m2()) {
+    if (expected.lanes() != provided.lanes())
+      return false;
+    return provided.is_float8_e5m2() || provided.is_float8_e5m2fnuz();
+  }
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- CMakeLists.txt (1 hunks)
- src/target/codegen_c_host.cc (1 hunks)
- src/target/codegen_c_host.h (1 hunks)
- src/transform/arg_binder.cc (7 hunks)
- src/transform/arg_binder.h (1 hunks)
- tilelang/engine/lower.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
src/transform/arg_binder.cc (1)
tilelang/language/ast/ir.py (3)
LetStmt (880-908), handle (1467-1497), decl_buffer (1137-1205)
src/target/codegen_c_host.cc (1)
src/target/codegen_c_host.h (2)
CodeGenCHost (42-119), GenerateForwardFunctionDeclarations (83-89)
src/transform/arg_binder.h (1)
src/transform/arg_binder.cc (2)
BindNullable (54-89), BindNullable (54-56)
src/target/codegen_c_host.h (1)
src/target/codegen_c_host.cc (35)
CodeGenCHost(42-44)Init(46-63)Init(46-48)InitGlobalContext(65-68)InitGlobalContext(65-65)AddFunction(74-77)AddFunction(74-75)AddFunction(79-106)AddFunction(79-81)DefineModuleName(70-72)DefineModuleName(70-70)PrintType(137-211)PrintType(137-137)PrintFuncPrefix(131-135)PrintFuncPrefix(131-131)VisitExpr_(213-226)VisitExpr_(213-214)VisitExpr_(315-348)VisitExpr_(315-316)VisitExpr_(455-458)VisitExpr_(455-456)VisitExpr_(460-463)VisitExpr_(460-461)VisitStmt_(350-453)VisitStmt_(350-350)GenerateForwardFunctionDeclarations(108-129)GenerateForwardFunctionDeclarations(108-110)GetPackedName(296-313)GetPackedName(296-296)PrintGetFuncFromBackend(228-246)PrintGetFuncFromBackend(228-229)PrintCallPacked(248-294)PrintCallPacked(248-248)PrintTernaryCondExpr(466-477)PrintTernaryCondExpr(466-467)
🔇 Additional comments (5)
CMakeLists.txt (1)
135-146: Wire `codegen_c_host.cc` into build – looks correct. Adding `src/target/codegen_c_host.cc` to `TILE_LANG_SRCS` cleanly integrates the new C host codegen into the build with no side effects on existing sources.

tilelang/engine/lower.py (1)

137-152: Align host/device codegen with TVM-FFI build entries; verify symbol availability. Switching the C host path to `"target.build.tilelang_c"` and adding the Metal path via `"target.build.metal"` keeps the lowering logic consistent with the new TVM-FFI build APIs and looks structurally sound.

Please double-check at runtime that both FFI symbols are actually registered (e.g., `tvm.ffi.get_global_func("target.build.tilelang_c")` and `"target.build.metal"` return non-None) to avoid late `GlobalFunc not found` errors, especially on Metal where you've seen execution issues.

Also applies to: 170-188
src/target/codegen_c_host.cc (1)
42-72: Make `module_name_` consistent with the emitted library-context symbol.

`module_name_` is initialized via `name_supply_->FreshName(tvm::ffi::symbol::tvm_ffi_library_ctx)`, while `InitGlobalContext` hard-codes the global variable name using `tvm::ffi::symbol::tvm_ffi_library_ctx`. `PrintGetFuncFromBackend` then passes `module_name_` into `TVMBackendGetFuncFromEnv`.

This effectively relies on `FreshName` returning the bare symbol name for the first use; if its behavior changes (or if the name is consumed earlier), you could end up calling `TVMBackendGetFuncFromEnv` with an identifier that has no corresponding definition in the generated C (e.g., `tvm_ffi_library_ctx0`).

To make this robust and avoid surprising compile-time breakages, consider either:

- Setting `module_name_` directly from the symbol constant and using it consistently:

```diff
-  CodeGenCHost::CodeGenCHost() {
-    module_name_ = name_supply_->FreshName(tvm::ffi::symbol::tvm_ffi_library_ctx);
-  }
+  CodeGenCHost::CodeGenCHost() {
+    module_name_ = tvm::ffi::symbol::tvm_ffi_library_ctx;
+  }
```

  and optionally rewriting `InitGlobalContext` to emit `module_name_`, or

- Dropping `module_name_` entirely and using `tvm::ffi::symbol::tvm_ffi_library_ctx` directly in `PrintGetFuncFromBackend`.

Also, `DefineModuleName()` is currently never called; either wire it into initialization/`Finish()` or remove it to avoid dead API surface.

Also applies to: 228-246
src/transform/arg_binder.cc (2)
54-89: BindNullable implementation matches Bind_ semantics and uses guard correctly. The implementation mirrors `Bind_` for both Var and non-Var expressions, and the `nullable_guard` is only threaded into the equality checks via `Or(nullable_guard, ...)`, so first-bind behavior is unchanged while later binds gain the intended short-circuiting. This looks correct for the current `BindDLTensor` usage.

229-611: NULL-aware DLTensor binding and guards look consistent and robust. The new `is_null` handling in `BindDLTensor` (Let-bound `*_is_null`, guarded `TVMArrayGet` via `if_then_else`, and use of `BindNullable` for shape/strides/elem_offset/device_id/data) gives a clear and consistent story:

- Optional `DLTensor*` handles can be NULL without dereferences, while asserts are short-circuited with `Or(is_null, ...)`.
- Shapes/strides/byte_offset/device_type are only validated when non-NULL, and symbolic vars are still bound via guarded expressions.
- The data pointer check correctly skips size‑0 arrays and NULL handles.
Overall this is a solid improvement in safety and error reporting for optional inputs.
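A conceptual sketch of the guarded-check pattern described above (pseudocode over TIR expressions; `LoadDLTensorField` is a made-up stand-in for the actual field loads in `arg_binder.cc`, not a real helper):

```cpp
// Every consistency assert on an optional DLTensor* is short-circuited by the
// handle's is_null flag, and every field read is wrapped so a NULL handle is
// never dereferenced.
PrimExpr is_null = Call(DataType::Bool(), builtin::isnullptr(), {handle});
PrimExpr dim = if_then_else(is_null, make_zero(expected_dim.dtype()),
                            LoadDLTensorField(handle, "shape", i));
Stmt check = AssertStmt(Or(is_null, dim == expected_dim),
                        StringImm(arg_name + ".shape mismatch"), Evaluate(0));
```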
* [Refactor] Update FFI type handling and simplify argument management * Refactored FFI type definitions in runtime and code generation files to use `TVMFFIAny` instead of `TVMValue`, enhancing type clarity. * Updated function registration in `runtime.cc` to utilize canonical names for better consistency. * Simplified argument handling in the `simplify` transformation, ensuring unused buffer parameters are removed only when simplification is enabled. * Adjusted autotuner and profiler parameters to standardize the execution backend to `tvm_ffi`, improving clarity in backend selection. * Removed obsolete `adapt_torch2tvm` function from tensor utilities to streamline the codebase and reduce complexity. * [Update] Sync TVM submodule and enhance kernel source handling * Updated the TVM submodule to commit cdc2aced, ensuring compatibility with recent changes. * Added functionality to print kernel source in `example_blocksparse_gemm.py` for better debugging. * Commented out the main execution call in test files to prevent unintended execution during testing. * Introduced `tilelang.disable_cache()` in various test files to streamline testing and avoid cache-related issues. * Refactored kernel source retrieval methods to improve clarity and consistency across different execution backends. * [Refactor] Clean up imports and improve code formatting * Removed unused import of `tilelang.testing` in `test_example_blocksparse_gemm.py` to streamline the code. * Reformatted several lines in `arg_binder.cc`, `make_packed_api.cc`, `tvm_ffi.py`, and `adapter.py` for improved readability and consistency. * Updated comments and spacing in `tvm_ffi.py` to enhance clarity without altering functionality. * Update execution backend options and improve resolution logic - Changed default execution backend from "cython" to "auto" in multiple locations to allow automatic selection based on the target. - Expanded the list of supported execution backends to include "torch" and "nvrtc" across various classes and functions. - Enhanced backend resolution logic in `KernelCache` and `AutoTuner` to ensure appropriate backend selection based on the target. - Updated documentation to reflect changes in execution backend options and their defaults. * lint fix * fix * Enhance argument handling in CUDA and HIP runtime modules - Updated `ExtractFuncInfo` in `rt_mod_cuda.cc` and `rt_mod_hip.cc` to map boolean argument types to int32, ensuring compatibility with device runtime. - Refactored `BindDLTensor` in `arg_binder.cc` to improve null handling and validation checks for DLTensor parameters, utilizing expression-level guards to prevent dereferencing null pointers. - Enhanced error checking for buffer shape, strides, and data fields, ensuring robust handling of optional inputs and maintaining consistency across various checks. * lint fix * lint fix * lint fix * lint fix * minor fix * fix * recover check * Refactor argument binding and validation in `arg_binder.cc` - Improved null handling and validation checks in `BindDLTensor`, ensuring safe dereferencing of pointers. - Enhanced consistency checks for buffer shape, strides, and data fields, utilizing expression-level guards. - Updated `MakePackedAPI` to maintain code clarity and consistency in argument handling. - Minor adjustments in test files to streamline kernel execution and improve readability. 
* lint fix * stride fix * minor fix * fix * lint fix * lint fix * Add CUDA stream access policy window helpers and integrate with L2 persistent cache management - Introduced functions to set and reset the CUDA stream access policy window, allowing for better control over L2 cache usage. - Updated runtime files to include new FFI packed functions for managing stream attributes. - Modified lower_hopper_intrin to incorporate prologue and epilogue statements for L2 cache setup and teardown. - Enhanced tests to verify the inclusion of new FFI calls in the generated kernel source. * check with symbolic * support null ptr * Update CMakeLists and lower.py for code generation and subproject status - Added `codegen_c_host.cc` to the list of source files in CMakeLists.txt for improved code generation support. - Updated the function call in `lower.py` to use `target.build.tilelang_c` for C target host code generation, enhancing compatibility. - Marked the TVM subproject as dirty to indicate local modifications. * lint fix * Update comments for clarity in quickstart.py

This pull request introduces several improvements and bug fixes across the codebase, focusing on enhanced dtype compatibility in argument binding, improved error messages, and some code maintenance and cleanup. The most significant changes are grouped below:
Enhanced dtype compatibility and error messaging
- `ArgBinder::BindBuffer` and `BindDLTensor` now allow more flexible binding between closely related types (e.g., different float8 variants and bool/int8/uint8), improving support for mixed precision and defensive programming. Error messages were also made more informative and user-friendly, avoiding dumping raw TIR expressions. [1] [2] [3]

FFI and API consistency

- Function registration in `runtime.cc` now uses canonical names, with improved formatting for error messages when initializing TMA descriptors. [1] [2]
- FFI calls now use `TVMFFIAny` instead of `TVMValue`, aligning with updated FFI conventions. [1] [2]
- `LowerHopperIntrin` now uses the correct type name (`tvm_ffi_any`).

Maintenance and test improvements

Codebase organization

- `make_packed_api.cc` was cleaned up, and the `ReturnRewriter` logic was simplified by removing unused variables and improving type conversion handling. [1] [2] [3] [4]

These changes collectively improve robustness, flexibility, and maintainability, especially in how the code handles data types and error reporting.
Summary by CodeRabbit
New Features
Bug Fixes
Refactor
Chores