
Conversation

@LeiWang1999 (Member) commented Nov 14, 2025

This pull request introduces several improvements and bug fixes across the codebase, focusing on enhanced dtype compatibility in argument binding, improved error messages, and some code maintenance and cleanup. The most significant changes are grouped below:

Enhanced dtype compatibility and error messaging

  • Relaxed dtype checks in ArgBinder::BindBuffer and BindDLTensor to allow more flexible binding between closely related types (e.g., different float8 variants and bool/int8/uint8), improving support for mixed precision and defensive programming. Error messages were also made more informative and user-friendly, avoiding dumping raw TIR expressions. See the sketch after this list for an illustration of the relaxed rules. [1] [2] [3]
  • Improved runtime assertion messages for null pointers and stride mismatches to provide clearer debugging information. [1] [2]
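
The relaxed rules can be illustrated with a small Python sketch. This is illustrative only: the real checks are emitted as TIR conditions inside ArgBinder::BindDLTensor, and the function and dtype names below are stand-ins.

def dtype_compatible(expected: str, provided: str) -> bool:
    # Illustrative sketch, not the actual implementation.
    if expected == provided:
        return True
    # A base float8_e4m3 buffer may also bind fn/fnuz-tagged tensors at runtime.
    if expected == "float8_e4m3":
        return provided in ("float8_e4m3fn", "float8_e4m3fnuz")
    # A base float8_e5m2 buffer may also bind the fnuz variant.
    if expected == "float8_e5m2":
        return provided == "float8_e5m2fnuz"
    # Bool buffers accept int8/uint8 and kDLBool-tagged tensors.
    if expected == "bool":
        return provided in ("int8", "uint8", "kDLBool")
    return False

assert dtype_compatible("float8_e4m3", "float8_e4m3fn")
assert dtype_compatible("bool", "uint8")
# The relaxation is one-directional: a variant buffer does not silently accept the base type.
assert not dtype_compatible("float8_e4m3fn", "float8_e4m3")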

FFI and API consistency

  • Updated FFI registration in runtime.cc to use canonical names and improved formatting for error messages when initializing TMA descriptors. [1] [2]
  • Changed code generation to use TVMFFIAny instead of TVMValue for FFI calls, aligning with updated FFI conventions. [1] [2]
  • Updated stack allocation in LowerHopperIntrin to use the correct type name (tvm_ffi_any).

Maintenance and test improvements

  • Updated the TVM submodule to a newer commit.
  • Cleaned up imports and test invocation logic in several example test scripts, disabling unused test runners and adding cache disabling for compilation tests. [1] [2] [3] [4] [5] [6] [7]

Codebase organization

  • Added a missing include and performed a minor cleanup in make_packed_api.cc, and simplified the ReturnRewriter logic by removing unused variables and improving type conversion handling. [1] [2] [3] [4]

These changes collectively improve robustness, flexibility, and maintainability, especially in how the code handles data types and error reporting.

Summary by CodeRabbit

  • New Features

    • TVM‑FFI execution backend, automatic "auto" backend resolution, Metal target support, and a new C‑host codegen option.
  • Bug Fixes

    • NULL‑safe argument handling, relaxed dtype matching for bool/FP8, improved boolean handling for device calls, and CUDA stream access‑policy hooks.
  • Refactor

    • Kernel sources split into host/device and APIs migrated to an FFI‑centric model with unified backend resolution.
  • Chores

    • Removed legacy DLPack/Torch adapters, test updates, and submodule pointer update.

* Refactored FFI type definitions in runtime and code generation files to use `TVMFFIAny` instead of `TVMValue`, enhancing type clarity.
* Updated function registration in `runtime.cc` to utilize canonical names for better consistency.
* Simplified argument handling in the `simplify` transformation, ensuring unused buffer parameters are removed only when simplification is enabled.
* Adjusted autotuner and profiler parameters to standardize the execution backend to `tvm_ffi`, improving clarity in backend selection.
* Removed obsolete `adapt_torch2tvm` function from tensor utilities to streamline the codebase and reduce complexity.
* Updated the TVM submodule to commit cdc2aced, ensuring compatibility with recent changes.
* Added functionality to print kernel source in `example_blocksparse_gemm.py` for better debugging.
* Commented out the main execution call in test files to prevent unintended execution during testing.
* Introduced `tilelang.disable_cache()` in various test files to streamline testing and avoid cache-related issues.
* Refactored kernel source retrieval methods to improve clarity and consistency across different execution backends.
* Removed unused import of `tilelang.testing` in `test_example_blocksparse_gemm.py` to streamline the code.
* Reformatted several lines in `arg_binder.cc`, `make_packed_api.cc`, `tvm_ffi.py`, and `adapter.py` for improved readability and consistency.
* Updated comments and spacing in `tvm_ffi.py` to enhance clarity without altering functionality.
@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai coderabbitai bot (Contributor) commented Nov 14, 2025

Caution

Review failed

The pull request is closed.

Walkthrough

Pivot runtime and JIT to a TVM‑FFI–first model: add backend resolution, TVM‑FFI adapter and FFI‑packed APIs, NULL‑safe DLTensor binding, host/device source separation, a TileLang C host codegen, CUDA stream access‑policy FFI hooks, and extensive cache, adapter, and test updates. (50 words)

Changes

Cohort / File(s) and summary:

  • Submodule & Packaging (3rdparty/tvm, pyproject.toml): TVM submodule pointer update; add TVM include mapping for wheel packaging.
  • Examples & small edits (examples/..., testing/python/debug/test_tilelang_debug_print.py, examples/quickstart.py): minor formatting and debug-print tweaks; tests/quickstart updated to prefer tvm_ffi and kernel_only kernel source access.
  • Runtime FFI & headers (src/runtime/runtime.cc, src/runtime/runtime.h): replace string FFI names with canonical tl:: symbols; add FFI registrations to set/reset the CUDA stream access policy window and init helpers to manage the L2 persisting cache window.
  • FFI type/codegen tweaks (src/target/codegen_cpp.cc, src/target/rt_mod_cuda.cc, src/target/rt_mod_hip.cc): use TVMFFIAny for temporary FFI return slots; map bool PrimFunc params to int32 for device arg lists.
  • Packed API refactor (src/transform/make_packed_api.cc): large migration to tvm::ffi types and FFI-centric packing/arg paths; many public/type signature changes and new packing semantics.
  • Arg binding NULL safety (src/transform/arg_binder.cc, src/transform/arg_binder.h): add ArgBinder::BindNullable; make DLTensor binding NULL-guarded with conditional loads; relax some dtype compatibility checks.
  • Hopper lowering prologue/epilogue (src/transform/lower_hopper_intrin.cc): use a tvm_ffi_any stack arg, collect prologue/epilogue statements, insert FFI calls to set/reset the L2 access policy for annotated maps.
  • Simplify transform (src/transform/simplify.cc): optionally drop unused buffer parameters when simplify_arguments is true and rebuild params/buffer_map.
  • C host codegen (src/target/codegen_c_host.cc, src/target/codegen_c_host.h, CMakeLists.txt): add a TileLang C host CodeGenCHost (header + source), register target.build.tilelang_c, and add the file to the build.
  • Execution backend resolver (tilelang/jit/execution_backend.py): new canonicalization and resolution logic for execution backends (handles auto, availability, per-target defaults); a sketch follows this list.
  • JIT surface & defaults (tilelang/jit/__init__.py, tilelang/jit/kernel.py): expand execution_backend literals to include auto, tvm_ffi, torch, nvrtc; default → auto/tvm_ffi; resolve backend per target; get_kernel_source(kernel_only) support.
  • Kernel cache & disk layout (tilelang/cache/kernel_cache.py, tilelang/cache/__init__.py): separate device/host kernel paths, add EXECUTABLE_PATH, support saving/loading Executable, normalize targets/backends, update API signatures/defaults.
  • Adapter base & helpers (tilelang/jit/adapter/base.py): add get_current_stream_functor and get_current_device_functor; get_kernel_source(self, kernel_only: bool = True) behavior.
  • Adapter refactors (tilelang/jit/adapter/ctypes/adapter.py, tilelang/jit/adapter/cython/adapter.py, tilelang/jit/adapter/nvrtc/adapter.py): split kernel_global_source into host_kernel_source + device_kernel_source; update constructors, from_database, and get_kernel_source(kernel_only).
  • New TVM-FFI adapter (tilelang/jit/adapter/tvm_ffi.py, tilelang/jit/adapter/__init__.py): add TVMFFIKernelAdapter with dynamic-shape mapping, stream/device capture, lazy runtime/executable instantiation, validations, and host/device source accessors; export the adapter.
  • Remove DLPack/Torch adapters & helpers (tilelang/contrib/dlpack.py, tilelang/jit/adapter/dlpack.py, tilelang/utils/tensor.py, tilelang/profiler/__init__.py): remove to_pytorch_func, TorchDLPackKernelAdapter, and adapt_torch2tvm; the profiler no longer converts inputs via that helper.
  • Autotuner & tuner updates (tilelang/autotuner/param.py, tilelang/autotuner/tuner.py): expand execution_backend literals/defaults to include auto, tvm_ffi, torch; resolve backend in set_compile_args; add a metal target literal.
  • Tests: TVM-FFI additions & adjustments (testing/python/jit/test_tilelang_jit_tvm_ffi.py, various tests): add comprehensive TVM-FFI tests (GEMM, im2col/TMA, L2-persistent hooks, dynamic shapes); update tests to use kernel_only and the harness main; minor callsite edits.
  • Misc edits (examples/gdn/..., examples/blocksparse_gemm/..., other small test files): remove/adjust debug prints and kernel source print usage; small formatting changes.
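
A hypothetical sketch of what the "auto" resolution above could look like; the actual logic lives in tilelang/jit/execution_backend.py and may differ in detail:

def resolve_execution_backend(requested: str, target: str) -> str:
    # Hypothetical resolver sketch; the rules below are assumptions drawn from this PR.
    if requested != "auto":
        return requested          # explicit choices pass through unchanged
    if "metal" in target:
        return "torch"            # metal targets require the torch backend
    return "tvm_ffi"              # the new project-wide default

print(resolve_execution_backend("auto", "cuda"))   # -> "tvm_ffi"
print(resolve_execution_backend("auto", "metal"))  # -> "torch"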

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant User
    participant JIT as tilelang.jit
    participant Resolver as execution_backend
    participant Adapter as TVMFFIKernelAdapter
    participant TVM as tvm.runtime

    User->>JIT: compile(func, target="cuda", execution_backend="auto")
    JIT->>Resolver: resolve_execution_backend("auto", target)
    Resolver-->>JIT: "tvm_ffi"
    JIT->>Adapter: instantiate TVMFFIKernelAdapter(...)
    Adapter->>Adapter: _process_dynamic_symbolic()
    Adapter-->>JIT: return kernel object
    User->>Adapter: call kernel(torch_tensors)
    Adapter->>Adapter: get_current_device()/get_current_stream()
    Adapter->>TVM: invoke executable(device_tensors)
    TVM-->>Adapter: outputs
    Adapter-->>User: torch tensors
sequenceDiagram
    autonumber
    participant MakePackedAPI
    participant ArgBinder
    participant Runtime as runtime.cc

    MakePackedAPI->>ArgBinder: BindDLTensor(handle,...)
    Note right of ArgBinder: insert is_null guard\nuse conditional loads to avoid NULL deref
    ArgBinder-->>MakePackedAPI: bound, guarded tensor values
    MakePackedAPI->>Runtime: def_packed(tl::tvm_tensormap_create_..., ...)
    Runtime-->>MakePackedAPI: registration complete
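The NULL-guard idea in the diagram above, restated as a minimal runnable Python sketch (illustrative only; the real ArgBinder emits TIR, and FakeDLTensor is a stand-in type):

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeDLTensor:  # stand-in for a DLTensor handle
    ndim: int
    dtype: str

def bind_dltensor(handle: Optional[FakeDLTensor], expected_ndim: int, expected_dtype: str):
    if handle is None:
        # Nullable argument: no field loads happen on the NULL path.
        return None
    assert handle.ndim == expected_ndim, f"expected a {expected_ndim}-d tensor, got {handle.ndim}-d"
    assert handle.dtype == expected_dtype, f"expected dtype {expected_dtype}, got {handle.dtype}"
    return handle  # shape/strides/data would be read only on this non-NULL branch

print(bind_dltensor(None, 2, "float16"))                        # None, no dereference
print(bind_dltensor(FakeDLTensor(2, "float16"), 2, "float16"))  # bound tensor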

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Areas needing focused review:

  • src/transform/make_packed_api.cc — large FFI refactor and public-type migrations.
  • src/transform/arg_binder.{cc,h} — NULL‑guard correctness and conditional-load semantics.
  • tilelang/jit/adapter/tvm_ffi.py — dynamic‑shape handling, stream/device capture, and executable lifecycle.
  • tilelang/cache/kernel_cache.py — disk layout/backends, Executable atomic save/load, backend resolution.
  • src/runtime/runtime.{cc,h} — new FFI registrations and CUDA access‑policy behavior.

Possibly related PRs

Poem

🐰 Hop, a change in every nook,

TVM‑FFI packed into my book.
NULL guards snug, adapters cheer,
Hosts and devices split — hooray, my dear!
Carrots of kernels, now near.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): docstring coverage is 16.09%, below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (2 passed)

  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.
  • Title check (✅ Passed): the title "[FFI] Use tvm ffi as the default execution backend" clearly and specifically describes the main change: switching the default execution backend from cython to tvm_ffi across the codebase.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2e6f110 and 7c0b8cd.

📒 Files selected for processing (6)
  • examples/quickstart.py (1 hunks)
  • tilelang/autotuner/tuner.py (4 hunks)
  • tilelang/jit/adapter/ctypes/adapter.py (6 hunks)
  • tilelang/jit/adapter/cython/adapter.py (6 hunks)
  • tilelang/profiler/__init__.py (1 hunks)
  • tilelang/utils/tensor.py (0 hunks)


@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tilelang/jit/__init__.py (1)

46-65: execution_backend types and docs are out of sync with actual behavior

Across compile, par_compile, and jit:

  • The Literal[...] annotations still mention "dlpack" and omit "tvm_ffi" (the new default) and "torch" (required in the metal path), e.g.:

    • execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "tvm_ffi"
  • Docstrings still say “Defaults to \"cython\"”, while the code now defaults to "tvm_ffi".

This will confuse users and static type checkers and hides that "torch" is required for metal targets (see the is_metal_target assertion).

Consider:

  • Updating the Literals everywhere execution_backend appears to the actual supported set (e.g. including "tvm_ffi" and "torch" and dropping "dlpack" if it’s really gone), and
  • Fixing the docstrings to match the new default "tvm_ffi" (a possible corrected signature is sketched below).

Also applies to: 98-117, 169-173, 284-305, 298-327
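
One possible shape for the updated annotation and docstring, assuming the supported set described above (illustrative only; the exact literal set should follow the resolver):

from typing import Literal

ExecutionBackend = Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"]

def compile(func, target: str = "auto",
            execution_backend: ExecutionBackend = "tvm_ffi"):
    """Compile a TileLang function.

    execution_backend : defaults to "tvm_ffi"; "auto" resolves per target
    (metal targets require "torch").
    """
    ...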

tilelang/autotuner/tuner.py (1)

574-580: Minor issues in cache-saving condition and autotune docstring

  • The condition

    if self.compile_args.execution_backend in ("torch"):

    works but is unconventional: ("torch") is just the string "torch", so this is a substring check. For clarity (and to avoid surprises) it should be == "torch" or in ("torch",); a short demonstration follows below.

  • The warning message still says:

    "DLPack backend does not support cache saving to disk."

    but the condition now checks for "torch". That’s misleading; update the text to mention the torch backend instead (or whatever backend is actually being special-cased).

  • In the autotune docstring, the documented parameter

    execution_backend : Literal["tvm_ffi", "ctypes", "cython"], optional
    

    does not exist in the function signature. This should either be removed or the API extended to actually accept execution_backend.

Also applies to: 711-713
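
A quick demonstration of the membership-check pitfall called out above:

backend = "t"
print(backend in ("torch"))    # True  -- ("torch") is just a string, so this is a substring check
print(backend in ("torch",))   # False -- a proper one-element tuple membership test
print(backend == "torch")      # False -- the clearest form for a single value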

tilelang/jit/adapter/nvrtc/adapter.py (1)

93-104: Fix NVRTC get_kernel_source signature to match JITKernel usage

JITKernel.get_kernel_source() now calls self.adapter.get_kernel_source(kernel_only=kernel_only) for "nvrtc", but NVRTCKernelAdapter.get_kernel_source takes no kernel_only parameter. This will raise a TypeError whenever get_kernel_source() is used with the NVRTC backend (e.g., PTX/SASS generation).

You can make NVRTC conform to the common adapter API by adding a kernel_only flag (even if it’s ignored for now), keeping current behavior:

-    def get_kernel_source(self) -> str | None:
-        """Get the CUDA kernel source code.
+    def get_kernel_source(self, kernel_only: bool = False) -> str | None:
+        """Get the CUDA kernel source code.
+
+        Parameters
+        ----------
+        kernel_only : bool, optional
+            Ignored for NVRTC; only device source is available.

         Returns
         -------
         Optional[str]
             The kernel source code, or None if not available
         """
-        return self.device_kernel_source
+        return self.device_kernel_source

Also, in from_database, host_kernel_source and pass_configs are currently only stored and not used. That’s fine for correctness, but if they’re intentionally unused you might either wire them into reconstruction logic (for symmetry with other backends) or mark them with a leading underscore to silence linters.

Also applies to: 171-179

tilelang/jit/kernel.py (1)

602-603: Fix kernel_source property for from_database JITKernel instances

The kernel_source property still falls back to self.adapter.kernel_global_source when self.artifact is falsy:

def kernel_source(self) -> str:
    return self.artifact.kernel_source if self.artifact else self.adapter.kernel_global_source

All adapters in this PR have moved away from kernel_global_source in favor of host_kernel_source/device_kernel_source and get_kernel_source(...). For kernels created via JITKernel.from_database, self.artifact is None, so this property will raise AttributeError on self.adapter.kernel_global_source.

A minimal fix that matches the new adapter API would be:

     @property
     def kernel_source(self) -> str:
-        return self.artifact.kernel_source if self.artifact else self.adapter.kernel_global_source
+        if self.artifact:
+            return self.artifact.kernel_source
+        # Fall back to adapter, keeping the historical “kernel only” behavior.
+        return self.adapter.get_kernel_source(kernel_only=True)

This keeps existing behavior for freshly compiled kernels and makes from_database instances work with the updated adapters.

♻️ Duplicate comments (2)
examples/gdn/test_example_gdn_compilation.py (1)

200-201: Test harness inconsistency across the PR.

Same issue as in test_example_blocksparse_gemm.py—this test now bypasses the centralized test runner. Please see the earlier comment about ensuring consistency in test harness changes across the PR.

testing/python/dynamic/test_tilelang_dynamic_symbolic.py (1)

517-517: Test harness consistency note.

This test now uses tilelang.testing.main(), which is the opposite pattern from test_example_blocksparse_gemm.py and test_example_gdn_compilation.py. See the earlier comment about test harness consistency.

🧹 Nitpick comments (5)
examples/blocksparse_gemm/example_blocksparse_gemm.py (1)

169-169: Consider making kernel source printing optional.

The unconditional print of the entire kernel source may produce verbose console output. For better user experience, consider making this conditional on a --verbose or --debug flag.

Apply this diff to add a verbose flag:

 parser.add_argument(
     "--use_autotune", action="store_true", default=False, help="Whether to use autotune")
+parser.add_argument(
+    "--verbose", action="store_true", default=False, help="Print kernel source code")

Then make the print conditional:

         block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
         print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
-    print(kernel.get_kernel_source())
+    if args.verbose:
+        print(kernel.get_kernel_source())
testing/python/debug/test_tilelang_debug_print.py (1)

16-17: Global cache disabling may leak outside this test module

tilelang.disable_cache() is a global switch, so once this runs, caching stays disabled for the rest of the process. If that’s intentional (to avoid any cached kernels affecting debug-print behavior across the entire test run), this is fine; otherwise, consider scoping it (e.g., re‑enable after the test or using a context helper) so other tests keep exercising the caching path.
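
One way to scope it, sketched below; this assumes a re-enable counterpart such as tilelang.enable_cache(), which may not match the project's actual API:

import tilelang

def run_debug_print_checks():
    # Stand-in for the actual debug-print assertions in this test module.
    pass

def test_debug_print_without_cache():
    tilelang.disable_cache()
    try:
        run_debug_print_checks()
    finally:
        # Hypothetical re-enable switch; substitute whatever the project actually exposes.
        tilelang.enable_cache()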

tilelang/jit/adapter/base.py (1)

7-8: Stream/device helper thunks are reasonable; consider narrowing broad excepts or documenting them

The new get_current_stream_functor and get_current_device_functor helpers are a useful abstraction: they respect PyTorch’s current device/stream at call time and degrade to CPU/0 when CUDA isn’t available.

Both helpers swallow all Exceptions and silently fall back:

except Exception:
    ...

Given these are just helpers, that’s defensible, but static analysis (BLE001) will complain. If you want to keep the behavior but quiet the linter, consider either:

  • Catching the specific expected exceptions (e.g. RuntimeError, AttributeError), as sketched below, or
  • Adding an explicit inline # noqa: BLE001 with a short comment explaining why a broad catch is intentional.

Also applies to: 50-71, 72-88
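
A sketch of the narrowed-exception variant, assuming the helper shape described above (the real implementation in base.py may differ):

import torch

def get_current_stream_functor():
    def current_stream() -> int:
        try:
            if torch.cuda.is_available():
                return torch.cuda.current_stream().cuda_stream
        except (RuntimeError, AttributeError):
            # CUDA present but not usable, or the attribute is missing; fall back below.
            pass
        return 0  # default stream when CUDA is unavailable
    return current_stream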

tilelang/jit/kernel.py (1)

54-76: Align execution_backend docs and verify non–tvm_ffi codegen gating

  • The constructor now defaults execution_backend to "tvm_ffi" (Line 58), but the docstring still claims the default is "cython" (Line 76). That’s misleading for users and should be updated.
  • In _compile_and_create_adapter, both enable_host_codegen and enable_device_compile are true only when execution_backend == "tvm_ffi" (Lines 227‑230). The ctypes/cython/nvrtc branches still pass artifact.host_mod and artifact.device_mod into their adapters, so please confirm tilelang.lower(...) still produces suitable host_mod / device_mod for those backends under this gating.

If the non‑tvm_ffi backends do require host/device IR modules, you may want something like:

-        enable_host_codegen = execution_backend == "tvm_ffi"
-        enable_device_compile = execution_backend == "tvm_ffi"
+        enable_host_codegen = execution_backend in {"tvm_ffi", "ctypes", "cython", "nvrtc"}
+        enable_device_compile = execution_backend in {"tvm_ffi", "ctypes", "cython", "nvrtc"}

(or whichever subset actually needs them).

Also applies to: 227-236, 241-297

tilelang/cache/kernel_cache.py (1)

41-45: Tighten backend selection and exception logging in KernelCache

Functionally the new tvm_ffi caching path looks fine, but two small robustness issues stand out:

  1. Backend source selection in _load_kernel_from_disk

    _load_kernel_from_disk takes an execution_backend argument (Line 351) but chooses kernel_lib_path based on self.execution_backend (Lines 376‑381). Since self.execution_backend is mutable state set in _generate_key, this is fragile if a single KernelCache instance is used with different backends concurrently.

    Consider basing the path selection on the method argument instead:

    -    if self.execution_backend == "nvrtc":
    +    if execution_backend == "nvrtc":
             kernel_lib_path = KERNEL_CUBIN_PATH
    -    elif self.execution_backend == "tvm_ffi":
    +    elif execution_backend == "tvm_ffi":
             kernel_lib_path = EXECUTABLE_PATH
         else:
             kernel_lib_path = KERNEL_LIB_PATH

    
  2. Bare Exception catches with logger.error

    The various try/except blocks around disk I/O (_save_kernel_to_disk and _load_kernel_from_disk) catch bare Exception and log via self.logger.error(...) (e.g., Lines 278‑280, 294‑296, 397‑399, 405‑407). For debugging cache corruption or filesystem issues, it would be more useful to:

    • Narrow the exception type where reasonable (e.g., OSError/IOError), and/or
    • Use logger.exception(...) to preserve the traceback.

    For example:

    -    except Exception as e:
    -        self.logger.error(f"Error saving kernel source code to disk: {e}")
    +    except OSError:
    +        self.logger.exception("Error saving kernel source code to disk")

    Similar adjustments can be applied to the other I/O blocks.


These are non‑blocking but will make cache behavior easier to reason about and debug.

Also applies to: 69-79, 240-245, 270-299, 345-383, 392-407

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between eac96cd and dacab10.

📒 Files selected for processing (31)
  • 3rdparty/tvm (1 hunks)
  • examples/blocksparse_gemm/example_blocksparse_gemm.py (1 hunks)
  • examples/blocksparse_gemm/test_example_blocksparse_gemm.py (1 hunks)
  • examples/gdn/example_chunk_o_bwd.py (0 hunks)
  • examples/gdn/test_example_gdn_compilation.py (3 hunks)
  • src/runtime/runtime.cc (2 hunks)
  • src/target/codegen_cpp.cc (2 hunks)
  • src/transform/arg_binder.cc (5 hunks)
  • src/transform/lower_hopper_intrin.cc (1 hunks)
  • src/transform/make_packed_api.cc (13 hunks)
  • src/transform/simplify.cc (1 hunks)
  • testing/python/debug/test_tilelang_debug_print.py (1 hunks)
  • testing/python/dynamic/test_tilelang_dynamic_symbolic.py (1 hunks)
  • testing/python/jit/test_tilelang_jit_gemm_ctypes.py (0 hunks)
  • testing/python/language/test_tilelang_language_alloc.py (1 hunks)
  • tilelang/autotuner/param.py (2 hunks)
  • tilelang/autotuner/tuner.py (3 hunks)
  • tilelang/cache/__init__.py (1 hunks)
  • tilelang/cache/kernel_cache.py (11 hunks)
  • tilelang/contrib/dlpack.py (0 hunks)
  • tilelang/jit/__init__.py (4 hunks)
  • tilelang/jit/adapter/__init__.py (1 hunks)
  • tilelang/jit/adapter/base.py (2 hunks)
  • tilelang/jit/adapter/ctypes/adapter.py (6 hunks)
  • tilelang/jit/adapter/cython/adapter.py (6 hunks)
  • tilelang/jit/adapter/dlpack.py (0 hunks)
  • tilelang/jit/adapter/nvrtc/adapter.py (5 hunks)
  • tilelang/jit/adapter/tvm_ffi.py (1 hunks)
  • tilelang/jit/kernel.py (16 hunks)
  • tilelang/profiler/__init__.py (1 hunks)
  • tilelang/utils/tensor.py (0 hunks)
💤 Files with no reviewable changes (5)
  • tilelang/utils/tensor.py
  • testing/python/jit/test_tilelang_jit_gemm_ctypes.py
  • examples/gdn/example_chunk_o_bwd.py
  • tilelang/contrib/dlpack.py
  • tilelang/jit/adapter/dlpack.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.080Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.080Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • examples/blocksparse_gemm/test_example_blocksparse_gemm.py
  • testing/python/debug/test_tilelang_debug_print.py
🧬 Code graph analysis (13)
tilelang/jit/adapter/base.py (5)
tilelang/jit/adapter/tvm_ffi.py (2)
  • func (195-262)
  • get_kernel_source (311-316)
tilelang/jit/adapter/ctypes/adapter.py (1)
  • get_kernel_source (294-300)
tilelang/jit/adapter/cython/adapter.py (1)
  • get_kernel_source (384-391)
tilelang/jit/adapter/nvrtc/adapter.py (1)
  • get_kernel_source (171-179)
tilelang/jit/kernel.py (1)
  • get_kernel_source (421-432)
tilelang/jit/adapter/__init__.py (1)
tilelang/jit/adapter/tvm_ffi.py (1)
  • TVMFFIKernelAdapter (23-321)
examples/gdn/test_example_gdn_compilation.py (2)
tilelang/env.py (1)
  • disable_cache (275-276)
examples/gdn/example_chunk_o_bwd.py (2)
  • kernel (157-395)
  • tilelang_chunk_o_bwd_dqkwg (115-397)
testing/python/language/test_tilelang_language_alloc.py (6)
tilelang/jit/adapter/ctypes/adapter.py (1)
  • get_kernel_source (294-300)
tilelang/jit/adapter/tvm_ffi.py (1)
  • get_kernel_source (311-316)
tilelang/jit/adapter/base.py (1)
  • get_kernel_source (93-97)
tilelang/jit/adapter/cython/adapter.py (1)
  • get_kernel_source (384-391)
tilelang/jit/adapter/nvrtc/adapter.py (1)
  • get_kernel_source (171-179)
tilelang/jit/kernel.py (1)
  • get_kernel_source (421-432)
testing/python/debug/test_tilelang_debug_print.py (2)
tilelang/env.py (1)
  • disable_cache (275-276)
tilelang/jit/__init__.py (2)
  • compile (46-95)
  • compile (222-248)
tilelang/jit/adapter/nvrtc/adapter.py (1)
tilelang/jit/adapter/wrapper.py (8)
  • host_func (703-713)
  • wrap (202-203)
  • wrap (1203-1222)
  • wrap (1230-1243)
  • update_lib_code (622-671)
  • update_lib_code (897-917)
  • update_lib_code (1116-1135)
  • update_lib_code (1168-1170)
examples/blocksparse_gemm/example_blocksparse_gemm.py (7)
tilelang/jit/adapter/ctypes/adapter.py (1)
  • get_kernel_source (294-300)
tilelang/jit/adapter/tvm_ffi.py (1)
  • get_kernel_source (311-316)
tilelang/jit/adapter/base.py (1)
  • get_kernel_source (93-97)
tilelang/jit/adapter/cython/adapter.py (1)
  • get_kernel_source (384-391)
tilelang/jit/adapter/nvrtc/adapter.py (1)
  • get_kernel_source (171-179)
tilelang/jit/kernel.py (1)
  • get_kernel_source (421-432)
tilelang/jit/param.py (1)
  • get_kernel_source (27-28)
tilelang/jit/adapter/cython/adapter.py (5)
tilelang/jit/adapter/wrapper.py (7)
  • wrap (202-203)
  • wrap (1203-1222)
  • wrap (1230-1243)
  • update_lib_code (622-671)
  • update_lib_code (897-917)
  • update_lib_code (1116-1135)
  • update_lib_code (1168-1170)
tilelang/jit/adapter/ctypes/adapter.py (1)
  • get_kernel_source (294-300)
tilelang/jit/adapter/tvm_ffi.py (1)
  • get_kernel_source (311-316)
tilelang/jit/adapter/nvrtc/adapter.py (1)
  • get_kernel_source (171-179)
tilelang/jit/kernel.py (2)
  • get_kernel_source (421-432)
  • params (598-599)
tilelang/jit/adapter/ctypes/adapter.py (2)
tilelang/jit/kernel.py (1)
  • params (598-599)
tilelang/jit/adapter/base.py (1)
  • _legalize_result_idx (20-44)
tilelang/jit/kernel.py (5)
tilelang/jit/adapter/nvrtc/adapter.py (3)
  • NVRTCKernelAdapter (25-259)
  • from_database (94-147)
  • get_kernel_source (171-179)
tilelang/jit/adapter/tvm_ffi.py (4)
  • TVMFFIKernelAdapter (23-321)
  • from_database (267-297)
  • get_kernel_source (311-316)
  • get_host_source (299-303)
tilelang/jit/adapter/base.py (2)
  • BaseKernelAdapter (10-100)
  • get_kernel_source (93-97)
tilelang/jit/adapter/ctypes/adapter.py (3)
  • from_database (113-162)
  • CtypesKernelAdapter (19-300)
  • get_kernel_source (294-300)
tilelang/jit/adapter/cython/adapter.py (2)
  • from_database (149-208)
  • get_kernel_source (384-391)
src/transform/make_packed_api.cc (2)
src/target/codegen_cpp.cc (8)
  • VisitExpr_ (166-179)
  • VisitExpr_ (166-167)
  • VisitExpr_ (375-414)
  • VisitExpr_ (375-376)
  • VisitExpr_ (452-455)
  • VisitExpr_ (452-453)
  • VisitExpr_ (457-460)
  • VisitExpr_ (457-458)
tilelang/language/tir/op.py (1)
  • tvm_struct_get (414-436)
tilelang/jit/adapter/tvm_ffi.py (3)
tilelang/utils/target.py (1)
  • determine_target (62-123)
tilelang/jit/adapter/base.py (6)
  • BaseKernelAdapter (10-100)
  • _legalize_result_idx (20-44)
  • _post_init (99-100)
  • _convert_torch_func (47-48)
  • get_current_device_functor (73-88)
  • get_kernel_source (93-97)
tilelang/engine/param.py (1)
  • KernelParam (12-103)
tilelang/cache/kernel_cache.py (2)
tilelang/jit/kernel.py (5)
  • export_library (609-626)
  • kernel_source (602-603)
  • get_host_source (434-439)
  • get_kernel_source (421-432)
  • from_database (143-183)
tilelang/jit/adapter/tvm_ffi.py (4)
  • func (195-262)
  • get_host_source (299-303)
  • get_kernel_source (311-316)
  • from_database (267-297)
🪛 Ruff (0.14.4)
tilelang/jit/adapter/base.py

66-66: Do not catch blind exception: Exception

(BLE001)


85-85: Do not catch blind exception: Exception

(BLE001)

tilelang/jit/adapter/__init__.py

2-2: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/jit/adapter/nvrtc/adapter.py

103-103: Unused class method argument: pass_configs

(ARG003)

tilelang/jit/adapter/tvm_ffi.py

176-176: Do not catch blind exception: Exception

(BLE001)


199-201: Avoid specifying long messages outside the exception class

(TRY003)


221-221: Unpacked variable ref_id is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


233-235: Avoid specifying long messages outside the exception class

(TRY003)


251-253: Avoid specifying long messages outside the exception class

(TRY003)


277-277: Unused class method argument: compile_flags

(ARG003)

tilelang/cache/kernel_cache.py

278-278: Do not catch blind exception: Exception

(BLE001)


279-279: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


294-294: Do not catch blind exception: Exception

(BLE001)


295-295: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


397-397: Do not catch blind exception: Exception

(BLE001)


398-398: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


405-405: Do not catch blind exception: Exception

(BLE001)


406-406: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (20)
src/transform/arg_binder.cc (4)

120-122: LGTM: Clear error messaging for dtype mismatches.

The error message clearly indicates both expected and provided types, making debugging easier.


192-196: LGTM: Improved error messages avoid runtime expression dumping.

The error messages are now more user-friendly by avoiding TIR expression dumps and providing clear, actionable feedback about NULL pointers, ndim mismatches, and dtype incompatibility.

Also applies to: 220-223, 227-231


342-344: LGTM: More descriptive stride error message.

Explicitly mentioning "non-compact strides" makes the error message clearer for debugging.


99-119: Based on my comprehensive verification, I need to mark this review comment as incorrect.

The FP8 variant compatibility logic in the lambda is intentionally asymmetric and correct.

The concern about bidirectional compatibility misunderstands the purpose of the dtype_compatible lambda. E4M3FN and E4M3FNUZ use different exponent biases and differ in the special values they can represent—these are not interchangeable formats.

The current logic deliberately allows binding when the expected type is the base variant (e.g., float8_e4m3), but rejects binding when expected is a variant (e.g., float8_e4m3fn expected, float8_e4m3 provided). This asymmetry is intentional protection: conversion between variants like e4m3fn and e4m3fnuz requires explicit normalization, which cannot be silently handled during buffer binding.

The codebase treats each FP8 variant as a distinct type code with different conversion requirements. Bidirectional compatibility would mask necessary type conversions and introduce silent correctness issues.

Likely an incorrect or invalid review comment.

3rdparty/tvm (1)

1-1: Verify TVM submodule commit is compatible with FFI migration changes.

The TVM commit (cdc2aced0d87cc6e5e24811bd964efd2ce2d0729) is valid and exists in the apache/tvm repository. However, the commit message indicates it focuses on CUDA function attribute refactoring and error handling in C host code generation. Confirm this commit integrates properly with the TVMFFIKernelAdapter and FFI backend adoption across the codebase, particularly verifying there are no breaking changes to the TVM FFI wrapper layer.

testing/python/language/test_tilelang_language_alloc.py (1)

154-154: LGTM: Kernel-only source retrieval is appropriate here.

The change to kernel_only=True correctly retrieves only the device kernel source, which is sufficient for validating initializer constants. This aligns with the PR's separation of host and device kernel sources.

tilelang/profiler/__init__.py (1)

278-278: LGTM: Direct tensor passing aligns with TVM FFI backend.

The removal of adapt_torch2tvm conversion correctly reflects that the new TVM FFI backend can handle PyTorch tensors directly, simplifying the profiling path.

examples/gdn/test_example_gdn_compilation.py (2)

110-110: Cache disabling is appropriate for compilation tests.

Calling tilelang.disable_cache() ensures this test exercises the actual compilation path rather than retrieving cached results, which is correct for a compilation test.


122-127: Clarify the purpose of redundant kernel instantiations.

The same kernel is instantiated three times with identical parameters. Is this intentionally testing compilation idempotency or robustness, or is this leftover debug code?

Consider adding a comment explaining the intent, or removing the redundant instantiations if they're not needed:

-    kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
-                                        gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
-                                        block_DK, block_DV, threads, num_stages)
-    kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
-                                        gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
-                                        block_DK, block_DV, threads, num_stages)
-
+    # Test compilation idempotency by instantiating the kernel multiple times
+    for _ in range(3):
+        kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
+                                            gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
+                                            block_DK, block_DV, threads, num_stages)
tilelang/autotuner/param.py (2)

45-45: LGTM: Default backend updated to tvm_ffi.

The change from "cython" to "tvm_ffi" as the default execution backend aligns with the PR's objective of making TVM FFI the default execution backend.


211-211: LGTM: Consistent backend parameter update.

The signature change in _load_kernel_from_disk correctly mirrors the updated default backend throughout the codebase.

tilelang/jit/adapter/__init__.py (2)

2-2: LGTM: Adapter imports updated for TVM FFI backend.

The addition of TVMFFIKernelAdapter and removal of the legacy DLPack adapter correctly reflects the migration to the new TVM FFI execution backend.

Note: The Ruff hint about the unused noqa directive is a false positive—F401 suppressions in __init__.py files are intentional for re-exports.


6-6: LGTM: MetalKernelAdapter added.

The addition of MetalKernelAdapter expands backend support as part of the adapter refactoring.

tilelang/cache/__init__.py (1)

21-21: LGTM: Cache function updated to use tvm_ffi backend.

The default execution backend change from "cython" to "tvm_ffi" is consistent with the broader migration across the codebase and aligns with the PR's objectives.

examples/blocksparse_gemm/test_example_blocksparse_gemm.py (1)

9-10: All 6 files with commented tilelang.testing.main() follow an identical pattern:

Every single file (across both testing/ and examples/) has replaced the centralized runner with a direct test function call in __main__. This uniformity suggests the pattern is intentional, not accidental, but raises a critical question: Why are these 6 tests deliberately bypassing the centralized test harness while 100+ others use it?

The original review comment is valid. The inconsistency is real and systematic. However, the fact that all 6 deviate identically could indicate:

  1. Incomplete refactoring – This may be work-in-progress where some tests have been converted but others haven't
  2. Intentional exemption – These tests may require different handling for a documented reason
  3. Test discovery risk – If tilelang.testing.main() handles centralized registration/discovery for CI/CD, this bypass could cause tests to be skipped

The original review appropriately asks for clarification. The file under review is part of a systematic pattern, not an isolated anomaly.

Rewritten Review Comment:

Clarify the purpose of bypassing tilelang.testing.main() across these 6 tests.

This file is one of exactly 6 tests (out of 100+) that replace tilelang.testing.main() with direct function invocation. All 6 follow an identical pattern, suggesting intent rather than accident. Is this a deliberate refactoring in progress, or are these tests intentionally exempt from centralized test discovery? Without clarification, this risks CI/CD test discovery issues and maintenance confusion. If intentional, document the reason; if unintended, restore the centralized runner to match the project standard.

src/transform/simplify.cc (1)

243-276: Optional buffer-parameter pruning logic looks sound

The new simplify_arguments path correctly:

  • Recomputes used_buffers_ from the (already simplified) func.
  • Retains buffer params when they’re used in the body or referenced from other buffer definitions (used_in_buffer_def_).
  • Keeps scalar params unconditionally.
  • Only rebuilds the PrimFunc when a parameter was actually dropped.

This matches the intent of removing only truly unused buffer parameters without disturbing scalar arguments.
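
A rough Python restatement of that pruning rule (illustrative; the real pass operates on TIR PrimFunc objects in C++):

def prune_unused_buffer_params(params, buffer_map, used_buffers, used_in_buffer_def):
    kept = []
    for p in params:
        buf = buffer_map.get(p)
        if buf is None:
            kept.append(p)     # scalar params are kept unconditionally
        elif buf in used_buffers or buf in used_in_buffer_def:
            kept.append(p)     # buffer still referenced in the body or another buffer definition
    changed = len(kept) != len(params)
    return kept, changed       # rebuild the PrimFunc only if something was dropped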

src/target/codegen_cpp.cc (1)

206-216: TVMFFIAny migration in codegen is consistent

Using TVMFFIAny for ret_val, stack allocation, and the casts passed into TVMFuncCall is consistent with the FFI aliases, and the static_assert on alignof(TVMFFIAny) vs DLTensor should keep the tvm_stack_alloca("array", ...) reinterpret-cast safe. I don’t see correctness issues in these changes.

Please make sure the updated codegen paths are covered by an FFI-based end‑to‑end test (e.g., a small packed function call exercising tvm_stack_alloca and return values) to catch any ABI/alignment regressions early.

Also applies to: 231-243, 382-398

src/runtime/runtime.cc (1)

91-109: Canonical FFI symbol registration is good; double‑check logging macro

Registering the TMA helpers with the canonical names (tl::tvm_tensormap_create_tiled / tl::tvm_tensormap_create_im2col) instead of string literals is a nice improvement for API consistency.

The failure path now uses LOG_FATAL << .... Please verify that LOG_FATAL is a valid logging macro in this project; most TVM-style code uses LOG(FATAL) << .... If LOG_FATAL isn’t defined, this will be a compile error and should be changed back to LOG(FATAL).

Also applies to: 185-201

tilelang/autotuner/tuner.py (1)

140-167: Updated default execution_backend for autotuner compile args looks consistent

Switching set_compile_args to

execution_backend: Literal["tvm_ffi", "ctypes", "cython"] = "tvm_ffi"

matches the broader move to tvm_ffi as the default and aligns with the updated CompileArgs typing. No issues here.

tilelang/jit/adapter/cython/adapter.py (1)

52-54: Host/device source split in Cython adapter looks consistent

The new host_kernel_source / device_kernel_source fields, initialization via wrapper.wrap(self.get_kernel_source(kernel_only=True)), and get_kernel_source(kernel_only=...) behavior all line up with the updated JITKernel and KernelCache expectations. No issues from a correctness standpoint.

Also applies to: 81-97, 121-128, 148-165, 384-391

Comment on lines +242 to +289
PrimExpr cond = (v_type_code == expect_code && v_type_bits == expect_bits &&
                 v_type_lanes == expect_lanes);

// Allow float8_e4m3 to match float8_e4m3fn/float8_e4m3fnuz at runtime.
if (buffer->dtype.is_float8_e4m3()) {
  PrimExpr code_e4m3 = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3);
  PrimExpr code_e4m3fn = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fn);
  PrimExpr code_e4m3fnuz =
      IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fnuz);
  PrimExpr code_match =
      (v_type_code == code_e4m3 || v_type_code == code_e4m3fn ||
       v_type_code == code_e4m3fnuz);
  cond = cond || (code_match && v_type_bits == expect_bits &&
                  v_type_lanes == expect_lanes);
}
// Allow float8_e5m2 to match float8_e5m2fnuz at runtime.
if (buffer->dtype.is_float8_e5m2()) {
  PrimExpr code_e5m2 = IntImm(DataType::UInt(8), DataType::kFloat8_e5m2);
  PrimExpr code_e5m2fnuz =
      IntImm(DataType::UInt(8), DataType::kFloat8_e5m2fnuz);
  PrimExpr code_match =
      (v_type_code == code_e5m2 || v_type_code == code_e5m2fnuz);
  cond = cond || (code_match && v_type_bits == expect_bits &&
                  v_type_lanes == expect_lanes);
}
// Allow bool to match int8/uint8 at runtime, and also kDLBool(code=6).
if (buffer->dtype.is_bool()) {
  PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt);
  PrimExpr code_uint = IntImm(DataType::UInt(8), DataType::kUInt);
  PrimExpr code_kdlbool = IntImm(DataType::UInt(8), 6);
  PrimExpr bits8 = IntImm(DataType::UInt(8), 8);
  PrimExpr bits1 = IntImm(DataType::UInt(8), 1);
  PrimExpr lanes_ok = (v_type_lanes == expect_lanes);
  PrimExpr int8_ok =
      (v_type_code == code_int && v_type_bits == bits8 && lanes_ok);
  PrimExpr uint8_ok =
      (v_type_code == code_uint && v_type_bits == bits8 && lanes_ok);
  // Some frontends may tag bool tensors as kDLBool(code=6), commonly with
  // bits=8 or bits=1.
  PrimExpr kdlbool8_ok =
      (v_type_code == code_kdlbool && v_type_bits == bits8 && lanes_ok);
  PrimExpr kdlbool1_ok =
      (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
  // Also accept any dtype whose bitwidth=1, regardless of code, to be
  // defensive.
  PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
  cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
}

⚠️ Potential issue | 🟡 Minor

Review overly permissive bool compatibility logic.

The bool compatibility logic at lines 286-288 accepts any dtype with 1-bit width regardless of type code:

PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
cond = cond || ... || bit1_ok;

This defensive check is overly broad and could mask type errors by accepting exotic 1-bit types (e.g., hypothetical float1 or custom types) that may not have correct bool semantics. Consider restricting this to only the explicitly listed type codes (int8, uint8, kDLBool with 8/1 bits) rather than accepting arbitrary 1-bit types.

Apply this diff to remove the overly permissive check:

     PrimExpr kdlbool1_ok =
         (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
-    // Also accept any dtype whose bitwidth=1, regardless of code, to be
-    // defensive.
-    PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
-    cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
+    cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok;
🤖 Prompt for AI Agents
In src/transform/arg_binder.cc around lines 242 to 289, the current bool
compatibility logic defines and uses a permissive bit1_ok check that accepts any
1-bit dtype regardless of type code; remove that overly broad check by deleting
the PrimExpr bit1_ok definition and stop OR-ing bit1_ok into cond so only the
explicit allowed codes (int8, uint8, and kDLBool with bits=8 or bits=1) combined
with lanes_ok are accepted.

Comment on lines 31 to +32
 Call alloc_desc = Call(DataType::Handle(), builtin::tvm_stack_alloca(),
-                       {StringImm("arg_value"), 16});
+                       {StringImm("tvm_ffi_any"), 16});

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify consistency and allocation size across the codebase.

The type string change from "arg_value" to "tvm_ffi_any" aligns with the PR's FFI migration. However, please ensure that:

  1. All tvm_stack_alloca calls have been consistently updated across the codebase
  2. The size of TVMFFIAny is 8 bytes (same as TVMValue) to maintain the 128-byte allocation (16 elements × 8 bytes)

Run the following script to check for any remaining "arg_value" usages in stack allocations:

Additionally, please search for documentation on TVMFFIAny size to confirm the allocation correctness:


🏁 Script executed:

#!/bin/bash
# Search for any remaining tvm_stack_alloca calls that might still use "arg_value"
rg -nP 'tvm_stack_alloca\s*\(' -A2 -B1

Length of output: 1595


🌐 Web query:

What is the size of TVMFFIAny in TVM FFI?

💡 Result:

TVMFFIAny is 16 bytes (128 bits) in size. [1][2]


I need to verify the intended allocation size for the TensorMap descriptor before finalizing the assessment.

Also, let me search for more information about TVMFFIAny and TVMValue size comparison:


🏁 Script executed:

#!/bin/bash
# Search for TensorMap descriptor size references and documentation
rg -nP 'TensorMap|tvm_ffi_any' -C3 | head -100

Length of output: 6827


🏁 Script executed:

#!/bin/bash
# Check for any size constants or comments related to descriptor allocation
rg -nP '128|256' src/transform/lower_hopper_intrin.cc -C2

Length of output: 321


🌐 Web query:

TVMFFIAny vs TVMValue size bytes TVM

💡 Result:

  • TVMFFIAny: 16 bytes (128-bit). [1]
  • TVMValue: 8 bytes on typical 64-bit platforms (it's a union whose payload fits in 64 bits). [2]

Sources: TVM-FFI docs for TVMFFIAny and TVM C API TVMValue. [1] https://tvm.apache.org/ffi/reference/cpp/generated/structTVMFFIAny.html [2] https://tvm.apache.org/docs/reference/api/doxygen/unionTVMValue.html


🌐 Web query:

TensorMap descriptor size bytes TVM CUDA Hopper

💡 Result:

The TensorMap descriptor is 1024 bits = 128 bytes. PTX documents the tensor-map object size as 1024 bits (b1024). The CUDA docs also describe tensor-map objects and alignment requirements (CUDA APIs require the tensorMap address alignment; some docs state 64‑byte alignment). [1][2][3]

Sources:

  • PTX / Hopper ISA (tensormap .b1024 = 1024 bits). [1]
  • CUDA Driver API tensor-map documentation (descriptor and alignment notes). [2]
  • NVIDIA patent / implementation notes (describes 128‑byte opaque tensor map). [3]

Critical allocation size error: Change 16 elements to 8 elements for TVMFFIAny migration.

The migration from TVMValue (8 bytes) to TVMFFIAny (16 bytes) failed to account for the size difference. TVMFFIAny is 16 bytes while TVMValue is 8 bytes, and the TensorMap descriptor requires 128 bytes (1024 bits).

The current code allocates:

  • 16 elements × 16 bytes = 256 bytes

It should allocate:

  • 8 elements × 16 bytes = 128 bytes

Fix at src/transform/lower_hopper_intrin.cc:32:

{StringImm("tvm_ffi_any"), 8}  // Changed from 16 to 8

This double over-allocation will cause memory corruption and crashes during TensorMap descriptor initialization on Hopper GPUs.
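
A quick arithmetic check of the sizes quoted above:

TVM_FFI_ANY_BYTES = 16        # TVMFFIAny is 128 bits per the TVM-FFI docs
TVM_VALUE_BYTES = 8           # TVMValue payload on 64-bit platforms
TENSORMAP_BYTES = 128         # CUDA tensor-map descriptor is 1024 bits
assert 16 * TVM_VALUE_BYTES == TENSORMAP_BYTES   # old code: 16 TVMValue slots fit exactly
assert 8 * TVM_FFI_ANY_BYTES == TENSORMAP_BYTES  # new code needs only 8 TVMFFIAny slots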

🤖 Prompt for AI Agents
In src/transform/lower_hopper_intrin.cc around lines 31 to 32, the allocation
call wrongly requests 16 elements of TVMFFIAny (16 bytes each) producing 256
bytes instead of the required 128 bytes for the TensorMap descriptor; change the
element count from 16 to 8 so the Call allocates {StringImm("tvm_ffi_any"), 8}
(8 × 16 = 128 bytes) to match the expected descriptor size and avoid
over-allocation and potential memory corruption.

Comment on lines +93 to +97
def get_kernel_source(self, kernel_only: bool = True) -> str:
if kernel_only:
return self.mod.imports[0].inspect_source()
else:
return self.mod.inspect_source() + "\n\n" + self.mod.imports[0].inspect_source()

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify get_kernel_source semantics and adapter interface consistency

Changing the base adapter to:

def get_kernel_source(self, kernel_only: bool = True) -> str:
    if kernel_only:
        return self.mod.imports[0].inspect_source()
    else:
        return self.mod.inspect_source() + "\n\n" + self.mod.imports[0].inspect_source()

is consistent with the idea of “kernel-only vs host+device” sources, but two things are worth double‑checking:

  1. Module structure assumptions: this assumes every adapter’s mod always has at least one import and that imports[0] is the device module. If that’s not guaranteed by your codegen, this can raise or return the wrong source. If there are backends where mod.imports may be empty or ordered differently, they should override get_kernel_source explicitly.

  2. Adapter interface alignment: JITKernel.get_kernel_source (in tilelang/jit/kernel.py) calls self.adapter.get_kernel_source(kernel_only=kernel_only) for backends including "ctypes", "cython", "nvrtc", and "tvm_ffi". Ensure every adapter implements a compatible signature:

    • ctypes/cython/tvm_ffi already provide get_kernel_source(kernel_only: bool = False).
    • The nvrtc adapter snippet you provided currently has def get_kernel_source(self) -> str | None: with no kernel_only parameter. If that’s still the case in this PR, it will raise TypeError: get_kernel_source() got an unexpected keyword argument 'kernel_only' when execution_backend="nvrtc".

Either update the nvrtc adapter (and any other stragglers) to accept kernel_only: bool = False, or special‑case "nvrtc" in JITKernel.get_kernel_source so it calls the old signature.


I'll verify the critical concerns identified in the review comment by examining the adapter implementations and their signatures.



🏁 Script executed:

#!/bin/bash
# Find adapter implementations and inspect get_kernel_source signatures
find . -type f -name "*.py" -path "*/adapter/*" | head -20

Length of output: 524


🏁 Script executed:

#!/bin/bash
# Locate adapter files and get_kernel_source method implementations
fd -t f "\.py$" "adapter" | head -30

Length of output: 94


🏁 Script executed:

#!/bin/bash
# Look for base.py, kernel.py and various adapter implementations
fd -t f "base\.py" 
fd -t f "kernel\.py"
fd -t f "nvrtc.*\.py"
fd -t f "ctypes.*\.py"

Length of output: 281


🏁 Script executed:

#!/bin/bash
# Search for get_kernel_source method definitions
rg -n "def get_kernel_source" --type py

Length of output: 646


🏁 Script executed:

#!/bin/bash
# Search for JITKernel class and how it calls get_kernel_source
rg -n -A 10 "class JITKernel" --type py
rg -n -B 2 -A 5 "\.get_kernel_source\(" --type py

Length of output: 74247


🏁 Script executed:

#!/bin/bash
# Find all Python files in jit/adapter directory
fd -t f "\.py$" "jit/adapter"
# Also find the jit directory structure
find . -type d -name "jit" 2>/dev/null
find . -type d -name "adapter" 2>/dev/null

Length of output: 157


🏁 Script executed:

#!/bin/bash
# Examine the base adapter file
cat "tilelang/jit/adapter/base.py" 2>/dev/null || echo "File not found in standard path"
# Check directory structure
ls -la tilelang/jit/adapter/ 2>/dev/null || ls -la tilelang/ 2>/dev/null || find . -maxdepth 3 -type d -name jit

Length of output: 4699


Let me examine the nvrtc adapter and the JITKernel implementation in detail:



🏁 Script executed:

#!/bin/bash
# Get nvrtc adapter implementation
cat tilelang/jit/adapter/nvrtc/adapter.py | head -200

Length of output: 7979




🏁 Script executed:

#!/bin/bash
# Check torch adapter
cat tilelang/jit/adapter/torch/metal.py | grep -A 10 "get_kernel_source" || echo "No get_kernel_source in torch adapter"

# Check JITKernel.get_kernel_source implementation
sed -n '421,435p' tilelang/jit/kernel.py

Length of output: 546




🏁 Script executed:

#!/bin/bash
# Check ctypes adapter get_kernel_source
sed -n '294,310p' tilelang/jit/adapter/ctypes/adapter.py

# Check cython adapter get_kernel_source
sed -n '384,400p' tilelang/jit/adapter/cython/adapter.py

# Check tvm_ffi get_kernel_source
sed -n '311,327p' tilelang/jit/adapter/tvm_ffi.py

Length of output: 1147


Fix the nvrtc adapter signature to accept the kernel_only parameter

The review comment correctly identifies a critical interface mismatch. The nvrtc adapter has an incompatible signature that will cause TypeError at runtime.

Issues confirmed:

  1. nvrtc adapter signature mismatch: The method at tilelang/jit/adapter/nvrtc/adapter.py:171 is def get_kernel_source(self) -> str | None: without the kernel_only parameter. However, JITKernel.get_kernel_source at tilelang/jit/kernel.py:430-431 calls self.adapter.get_kernel_source(kernel_only=kernel_only) for the nvrtc backend. This will raise TypeError: get_kernel_source() got an unexpected keyword argument 'kernel_only' when using execution_backend="nvrtc".

  2. Base adapter module structure assumption: The base implementation accesses self.mod.imports[0] without guards. While child adapters (ctypes, cython, tvm_ffi) all override with their own implementations using stored sources, the nvrtc adapter relies directly on self.device_kernel_source with no self.mod structure. This inconsistency suggests the module import structure isn't consistently guaranteed.

  3. Default parameter inconsistency: The base adapter defaults to kernel_only=True, but all working child adapters (ctypes, cython, tvm_ffi) override with kernel_only=False, indicating the base implementation path is never actually exercised.

Required fix: Update tilelang/jit/adapter/nvrtc/adapter.py:171 to accept the kernel_only parameter consistent with other adapters (likely kernel_only: bool = False), and align its implementation logic with how it stores kernel sources.

🤖 Prompt for AI Agents
In tilelang/jit/adapter/nvrtc/adapter.py around line 171, the nvrtc adapter's
get_kernel_source signature lacks the kernel_only parameter and will raise a
TypeError when called with kernel_only by JITKernel; change the signature to
accept kernel_only: bool = False and implement branching: if kernel_only return
only the device/kernel source (self.device_kernel_source or equivalent), else
return the full combined source (host + device or module + imports) consistent
with other adapters; also defensively handle missing attributes (fall back to
empty string) so the method never raises due to absent self.mod or sources.
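
For illustration, here is a minimal, self-contained sketch of the suggested signature change; the class name and the fallback behavior are assumptions for this example, not the actual NVRTC adapter implementation:

from __future__ import annotations


class NVRTCAdapterSketch:
    # Hypothetical stand-in for the real adapter; only the attributes relevant
    # to get_kernel_source are modeled here.
    def __init__(self, host_kernel_source: str | None, device_kernel_source: str | None):
        self.host_kernel_source = host_kernel_source
        self.device_kernel_source = device_kernel_source

    def get_kernel_source(self, kernel_only: bool = False) -> str | None:
        device_src = self.device_kernel_source or ""
        if kernel_only:
            return device_src or None
        host_src = self.host_kernel_source or ""
        # Combine host + device source and fall back to None instead of raising
        # when either part is missing.
        combined = (host_src + "\n" + device_src).strip()
        return combined or None


# JITKernel-style call site that previously raised TypeError for nvrtc:
adapter = NVRTCAdapterSketch(None, "__global__ void k() {}")
print(adapter.get_kernel_source(kernel_only=True))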

Comment on lines +33 to 35
    host_kernel_source: str | None = None
    device_kernel_source: str | None = None
    lib: ctypes.CDLL | None = None  # Compiled library handle

⚠️ Potential issue | 🟠 Major

Populate host_kernel_source for ctypes backend to keep KernelCache working

In CtypesKernelAdapter.__init__, you now accept host_kernel_source / device_kernel_source, but for the normal compile path only device_kernel_source is provided from JITKernel. The code does:

self.host_kernel_source = host_kernel_source
self.device_kernel_source = device_kernel_source
...
self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
...
self.lib_generator.update_lib_code(self.wrapped_source)

get_kernel_source(kernel_only=True) returns self.device_kernel_source, so wrapped_source is correctly derived. However, host_kernel_source remains None in this path.

Later, KernelCache._save_kernel_to_disk persists the host source with:

KernelCache._safe_write_file(
    host_kernel_path, "w",
    lambda file: file.write(kernel.adapter.get_kernel_source()))

Since get_kernel_source() defaults to kernel_only=False, for ctypes it returns self.host_kernel_source, i.e. None for compiled kernels. Writing None to the file will raise a TypeError.

You can fix this by populating host_kernel_source from the wrapped source in __init__:

-        self.wrapper.assign_host_module(host_mod)
-        self.wrapper.assign_device_module(device_mod)
-        self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
+        self.wrapper.assign_host_module(host_mod)
+        self.wrapper.assign_device_module(device_mod)
+        self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
+        # Keep a copy of the wrapped host source for inspection/caching.
+        self.host_kernel_source = self.wrapped_source

This keeps existing behavior for the compiled path, while the from_database path still overrides host_kernel_source and wrapped_source as needed.

Also applies to: 52-70, 99-106, 112-130, 294-300
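
To make the failure mode concrete, here is a minimal runnable repro of the path described above; CtypesAdapterSketch is a simplified stand-in, not the real CtypesKernelAdapter:

import tempfile


class CtypesAdapterSketch:
    def __init__(self, device_kernel_source: str):
        self.host_kernel_source = None          # never populated on the compile path
        self.device_kernel_source = device_kernel_source

    def get_kernel_source(self, kernel_only: bool = False):
        return self.device_kernel_source if kernel_only else self.host_kernel_source


adapter = CtypesAdapterSketch("__global__ void k() {}")
with tempfile.NamedTemporaryFile("w", suffix=".c") as f:
    try:
        # kernel_only defaults to False, so this returns None for ctypes kernels.
        f.write(adapter.get_kernel_source())
    except TypeError as e:
        print("KernelCache._save_kernel_to_disk would fail here:", e)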

Comment on lines +109 to +126
        for i, param in enumerate(params):
            if param in buffer_map:
                buffer = buffer_map[param]
                for j, shape in enumerate(buffer.shape):
                    if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
                        (shape not in params)):
                        dynamic_symbolic_map[shape] = (0, i, j)
        for i, param in enumerate(params):
            if param in buffer_map:
                buffer = buffer_map[param]
                for j, stride in enumerate(buffer.strides):
                    if (isinstance(stride, tir.Var) and (stride not in dynamic_symbolic_map) and
                        (stride not in params)):
                        dynamic_symbolic_map[stride] = (1, i, j)
        return dynamic_symbolic_map

    def _convert_torch_func(self) -> Callable[..., Any]:
        # Capture thunks that reflect Torch's current stream and device.

⚠️ Potential issue | 🔴 Critical

Dynamic shape symbols coming from PrimFunc params make outputs zero-rank

When a PrimFunc carries dynamic dims via explicit scalar params (e.g. n or m), those tir.Var symbols appear in both func.params and the buffer shapes. The guard (shape not in params) prevents _process_dynamic_symbolic from recording them, so _convert_torch_func never finds a match. The subsequent loop silently omits the dimension, leaving shape empty and tripping the “0-dimensional tensors are not supported” error for perfectly valid kernels. This regresses any workload that models dynamic extents with scalar parameters—which is how most of our TIR schedules are emitted.

Please keep those symbols in the map (or otherwise resolve them) and fail loudly if a tir.Var cannot be resolved. One fix is to drop the shape not in params filter and tighten the call site to raise when a symbol is still missing:

-                for j, shape in enumerate(buffer.shape):
-                    if (isinstance(shape, tir.Var) and (shape not in dynamic_symbolic_map) and
-                        (shape not in params)):
-                        dynamic_symbolic_map[shape] = (0, i, j)
+                for j, shape in enumerate(buffer.shape):
+                    if isinstance(shape, tir.Var) and shape not in dynamic_symbolic_map:
+                        dynamic_symbolic_map[shape] = (0, i, j)

-                    for s in param_shapes[i]:
-                        if isinstance(s, tir.Var):
-                            for key in dynamic_symbolic_map:
-                                if (str(s) == str(key)):
-                                    ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[key]
-                                    shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
-                        else:
-                            shape.append(s)
+                    for s in param_shapes[i]:
+                        if isinstance(s, tir.Var):
+                            if s not in dynamic_symbolic_map:
+                                raise ValueError(
+                                    f"{global_symbol}: cannot resolve dynamic dim '{s}' at output index {i}; "
+                                    "bind it to an input buffer or scalar."
+                                )
+                            ref_id, ref_tensor_idx, ref_shape_idx = dynamic_symbolic_map[s]
+                            assert ref_id == 0, "stride symbols not yet supported here"
+                            shape.append(tensor_list[ref_tensor_idx].shape[ref_shape_idx])
+                        else:
+                            shape.append(s)

Without this, dynamic kernels misbehave immediately.

Also applies to: 217-236

- Changed default execution backend from "cython" to "auto" in multiple locations to allow automatic selection based on the target.
- Expanded the list of supported execution backends to include "torch" and "nvrtc" across various classes and functions.
- Enhanced backend resolution logic in `KernelCache` and `AutoTuner` to ensure appropriate backend selection based on the target.
- Updated documentation to reflect changes in execution backend options and their defaults.
@LeiWang1999 LeiWang1999 requested a review from Hzfengsy November 15, 2025 15:04
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (5)
tilelang/jit/adapter/nvrtc/adapter.py (1)

172-180: Consider clarifying get_kernel_source scope.

Now that kernel sources are split into host and device, get_kernel_source() returns only device_kernel_source. The method name no longer fully describes its behavior.

Options:

  1. Rename to get_device_kernel_source() for clarity
  2. Return both sources as a tuple
  3. Add a docstring note that it returns device source only
tilelang/cache/kernel_cache.py (4)

291-298: Consider using logging.exception for better diagnostics.

The error handling correctly catches exceptions, but as static analysis suggests, using self.logger.exception instead of error would automatically include the traceback, making debugging easier.

Apply this diff:

         except Exception as e:
-            self.logger.error(f"Error saving kernel source code to disk: {e}")
+            self.logger.exception("Error saving kernel source code to disk")

302-314: Consider using logging.exception for better diagnostics.

Same as the previous segment—using self.logger.exception would provide more diagnostic context.

Apply this diff:

         except Exception as e:
-            self.logger.error(f"Error saving host kernel source code to disk: {e}")
+            self.logger.exception("Error saving host kernel source code to disk")

411-425: Consider using logging.exception for better diagnostics.

As with the save methods, using self.logger.exception would include tracebacks, making it easier to diagnose file I/O issues.

Apply this diff:

         except Exception as e:
-            self.logger.error(f"Error loading kernel source code from disk: {e}")
+            self.logger.exception("Error loading kernel source code from disk")
         try:
             # ...
         except Exception as e:
-            self.logger.error(f"Error loading host kernel source code from disk: {e}")
+            self.logger.exception("Error loading host kernel source code from disk")

451-452: Optional: Address the TODO for better diagnostics.

The TODO comment suggests reporting why the kernel load failed. This would help users understand cache misses.

Would you like me to generate code that logs which of the three required components (host_kernel_source, device_kernel_source, kernel_params) failed to load?
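
If such a diagnostic is added, a minimal standalone sketch could look like the following; the component names come from the comment above, while the function name and logger setup are assumptions:

import logging

logger = logging.getLogger("tilelang.cache")


def report_missing_components(host_kernel_source, device_kernel_source, kernel_params) -> bool:
    """Log which cached components failed to load; return True if all are present."""
    missing = [
        name for name, value in (
            ("host_kernel_source", host_kernel_source),
            ("device_kernel_source", device_kernel_source),
            ("kernel_params", kernel_params),
        ) if value is None
    ]
    if missing:
        logger.warning("Kernel cache miss: could not load %s from disk", ", ".join(missing))
    return not missing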

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dacab10 and d09b267.

📒 Files selected for processing (8)
  • tilelang/autotuner/param.py (3 hunks)
  • tilelang/autotuner/tuner.py (4 hunks)
  • tilelang/cache/__init__.py (1 hunks)
  • tilelang/cache/kernel_cache.py (14 hunks)
  • tilelang/jit/__init__.py (9 hunks)
  • tilelang/jit/adapter/nvrtc/adapter.py (5 hunks)
  • tilelang/jit/execution_backend.py (1 hunks)
  • tilelang/jit/kernel.py (16 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tilelang/autotuner/param.py
🧰 Additional context used
🧬 Code graph analysis (6)
tilelang/jit/adapter/nvrtc/adapter.py (3)
tilelang/jit/adapter/nvrtc/wrapper.py (3)
  • host_func (258-262)
  • host_func (265-267)
  • update_lib_code (498-555)
tilelang/jit/adapter/wrapper.py (7)
  • host_func (525-535)
  • wrap (134-135)
  • wrap (818-837)
  • wrap (845-859)
  • update_lib_code (444-493)
  • update_lib_code (731-750)
  • update_lib_code (783-785)
tilelang/jit/adapter/nvrtc/libgen.py (1)
  • NVRTCLibraryGenerator (40-235)
tilelang/cache/kernel_cache.py (8)
tilelang/jit/kernel.py (7)
  • JITKernel (31-756)
  • out_idx (596-597)
  • export_library (611-628)
  • kernel_source (604-605)
  • get_host_source (436-441)
  • get_kernel_source (423-434)
  • from_database (143-183)
tilelang/utils/target.py (1)
  • determine_target (62-123)
tilelang/jit/execution_backend.py (2)
  • resolve_execution_backend (62-100)
  • allowed_backends_for_target (25-55)
tilelang/jit/adapter/tvm_ffi.py (4)
  • func (195-262)
  • get_host_source (299-303)
  • get_kernel_source (311-316)
  • from_database (267-297)
tilelang/autotuner/param.py (2)
  • _load_kernel_from_disk (205-273)
  • _save_kernel_to_disk (146-203)
tilelang/jit/adapter/nvrtc/adapter.py (2)
  • get_kernel_source (172-180)
  • from_database (95-148)
tilelang/jit/adapter/ctypes/adapter.py (3)
  • get_kernel_source (294-300)
  • libpath (280-282)
  • from_database (113-162)
tilelang/jit/adapter/base.py (1)
  • get_kernel_source (93-97)
tilelang/jit/kernel.py (7)
tilelang/jit/adapter/tvm_ffi.py (4)
  • TVMFFIKernelAdapter (23-321)
  • from_database (267-297)
  • get_kernel_source (311-316)
  • get_host_source (299-303)
tilelang/jit/adapter/torch/metal.py (1)
  • MetalKernelAdapter (14-75)
tilelang/engine/param.py (1)
  • KernelParam (12-103)
tilelang/jit/adapter/base.py (2)
  • BaseKernelAdapter (10-100)
  • get_kernel_source (93-97)
tilelang/jit/adapter/nvrtc/adapter.py (2)
  • from_database (95-148)
  • get_kernel_source (172-180)
tilelang/jit/adapter/ctypes/adapter.py (3)
  • from_database (113-162)
  • CtypesKernelAdapter (19-300)
  • get_kernel_source (294-300)
tilelang/jit/adapter/cython/adapter.py (2)
  • from_database (149-208)
  • get_kernel_source (384-391)
tilelang/jit/execution_backend.py (2)
tilelang/language/ast/ir.py (1)
  • target (1682-1713)
tilelang/jit/__init__.py (3)
  • jit (293-294)
  • jit (298-310)
  • jit (313-382)
tilelang/jit/__init__.py (2)
tilelang/jit/execution_backend.py (2)
  • resolve_execution_backend (62-100)
  • allowed_backends_for_target (25-55)
tilelang/language/ast/ir.py (1)
  • target (1682-1713)
tilelang/autotuner/tuner.py (5)
tilelang/utils/target.py (1)
  • determine_target (62-123)
tilelang/jit/__init__.py (3)
  • jit (293-294)
  • jit (298-310)
  • jit (313-382)
tilelang/jit/execution_backend.py (1)
  • resolve_execution_backend (62-100)
tilelang/autotuner/param.py (1)
  • CompileArgs (32-75)
tilelang/jit/kernel.py (1)
  • out_idx (596-597)
🪛 Ruff (0.14.4)
tilelang/jit/adapter/nvrtc/adapter.py

104-104: Unused class method argument: pass_configs

(ARG003)

tilelang/cache/kernel_cache.py

297-297: Do not catch blind exception: Exception

(BLE001)


298-298: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


313-313: Do not catch blind exception: Exception

(BLE001)


314-314: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


416-416: Do not catch blind exception: Exception

(BLE001)


417-417: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


424-424: Do not catch blind exception: Exception

(BLE001)


425-425: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

tilelang/jit/execution_backend.py

51-53: try-except-pass detected, consider logging the exception

(S110)


51-51: Do not catch blind exception: Exception

(BLE001)


90-92: Avoid specifying long messages outside the exception class

(TRY003)


96-98: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (4)
tilelang/jit/adapter/nvrtc/adapter.py (1)

30-46: LGTM: Clean refactoring to device_kernel_source.

The rename from kernel_global_source to device_kernel_source clarifies intent and aligns with the host/device kernel separation introduced in this PR.

Also applies to: 77-77, 80-80

tilelang/cache/kernel_cache.py (3)

16-16: LGTM: Import and constants for TVM FFI support.

The new Executable import and path constants (DEVICE_KERNEL_PATH, HOST_KERNEL_PATH, EXECUTABLE_PATH) properly support the TVM FFI backend's separation of device and host kernel sources.

Also applies to: 22-24


140-157: LGTM: Backend resolution and target normalization.

The new logic properly normalizes the target and resolves the execution backend before caching. The conditional logging avoids duplicates when the backend is explicitly specified and matches the resolved value.


259-263: LGTM: Atomic executable write.

The new _safe_write_executable method correctly implements atomic writes using a temporary file and os.replace, consistent with the existing _safe_write_file pattern.

     _lock = threading.Lock()  # For thread safety
     _memory_cache = {}  # In-memory cache dictionary
-    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython"
+    execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi"

⚠️ Potential issue | 🔴 Critical

Critical: Race condition on shared execution_backend class variable.

The class variable self.execution_backend is modified by concurrent threads (line 94) and read without synchronization in save/load methods (lines 305, 319, 329, 338, 395). Since the lock only protects cache access but not compilation, two threads compiling kernels with different backends can corrupt the cache.

Example scenario:

  1. Thread A compiles with backend="ctypes", sets self.execution_backend="ctypes"
  2. Thread B compiles with backend="tvm_ffi", overwrites self.execution_backend="tvm_ffi"
  3. Thread A saves its ctypes kernel using tvm_ffi paths → wrong library type saved

Apply this diff to use the local execution_backend parameter instead:

     def _save_kernel_to_disk(self,
                              key: str,
                              kernel: JITKernel,
+                             execution_backend: str,
                              func: Callable = None,
                              verbose: bool = False):
-            if self.execution_backend == "tvm_ffi":
+            if execution_backend == "tvm_ffi":
                 KernelCache._safe_write_file(
                     host_kernel_path, "w",
                     lambda file: file.write(kernel.adapter.get_host_source()))
-            if self.execution_backend == "nvrtc":
+            if execution_backend == "nvrtc":
                 kernel_lib_path = KERNEL_CUBIN_PATH
-            elif self.execution_backend == "tvm_ffi":
+            elif execution_backend == "tvm_ffi":
                 kernel_lib_path = EXECUTABLE_PATH
             else:
-            if self.execution_backend == "nvrtc":
+            if execution_backend == "nvrtc":
                 src_lib_path = kernel.adapter.libpath
                 # ... nvrtc handling ...
-            elif self.execution_backend == "tvm_ffi":
+            elif execution_backend == "tvm_ffi":
                 executable = kernel.adapter.executable
                 # ... tvm_ffi handling ...
-        if self.execution_backend == "nvrtc":
+        if execution_backend == "nvrtc":
             kernel_lib_path = KERNEL_CUBIN_PATH
-        elif self.execution_backend == "tvm_ffi":
+        elif execution_backend == "tvm_ffi":
             kernel_lib_path = EXECUTABLE_PATH

Update the call site at line 217:

-                self._save_kernel_to_disk(key, kernel, func, verbose)
+                self._save_kernel_to_disk(key, kernel, execution_backend, func, verbose)

Also remove the assignment at line 94 in _generate_key:

-        self.execution_backend = execution_backend

Also applies to: 94-94, 305-312, 319-324, 329-349, 395-400
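
As a simplified, hypothetical illustration of the hazard (class and method names are stand-ins, and the sleep stands in for compilation time), whichever thread writes the shared attribute last decides which save path both threads use:

import threading
import time


class CacheSketch:
    def __init__(self):
        self.execution_backend = "cython"   # shared mutable state (racy)

    def save(self, backend: str) -> str:
        self.execution_backend = backend    # step 1: overwrite shared state
        time.sleep(0.01)                    # stand-in for "compile the kernel"
        # step 2: read the shared state back, possibly after another thread wrote it
        return f"{backend} kernel saved via {self.execution_backend} paths"


cache = CacheSketch()
results = []
threads = [
    threading.Thread(target=lambda b=b: results.append(cache.save(b)))
    for b in ("ctypes", "tvm_ffi")
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)  # e.g. ['ctypes kernel saved via tvm_ffi paths', ...]

Passing execution_backend through the call chain, as in the diff above, removes the shared state entirely.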

Comment on lines +338 to +342
            elif self.execution_backend == "tvm_ffi":
                executable = kernel.adapter.executable
                if verbose:
                    self.logger.debug(f"Saving kernel executable to file: {executable}")
                KernelCache._safe_write_executable(executable, kernel_lib_path)

⚠️ Potential issue | 🟡 Minor

Fix: Log the file path, not the object.

Line 341 logs the executable object rather than the destination path, which will produce unhelpful output like <tvm.runtime.Module object>.

Apply this diff:

             elif self.execution_backend == "tvm_ffi":
                 executable = kernel.adapter.executable
                 if verbose:
-                    self.logger.debug(f"Saving kernel executable to file: {executable}")
+                    self.logger.debug(f"Saving kernel executable to file: {kernel_lib_path}")
                 KernelCache._safe_write_executable(executable, kernel_lib_path)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-            elif self.execution_backend == "tvm_ffi":
-                executable = kernel.adapter.executable
-                if verbose:
-                    self.logger.debug(f"Saving kernel executable to file: {executable}")
-                KernelCache._safe_write_executable(executable, kernel_lib_path)
+            elif self.execution_backend == "tvm_ffi":
+                executable = kernel.adapter.executable
+                if verbose:
+                    self.logger.debug(f"Saving kernel executable to file: {kernel_lib_path}")
+                KernelCache._safe_write_executable(executable, kernel_lib_path)
🤖 Prompt for AI Agents
In tilelang/cache/kernel_cache.py around lines 338 to 342, the debug log
currently interpolates and prints the `executable` object (e.g.
`<tvm.runtime.Module object>`) instead of the destination file path; change the
log to reference the `kernel_lib_path` (or include both `kernel_lib_path` and a
short description) so the debug output shows the file being written to, then
call KernelCache._safe_write_executable as before.

Comment on lines +100 to +101
host_kernel_source: str,
device_kernel_source: str,

⚠️ Potential issue | 🟡 Minor

Remove unused parameters from from_database.

Three parameters are accepted but never used:

  • host_kernel_source (line 100): stored at line 109 but never referenced afterward
  • pass_configs (line 104): never stored or used, confirmed by static analysis

In contrast, __init__ actively uses pass_configs (line 74: self.wrapper.assign_pass_configs(pass_configs)). Since from_database loads from a prebuilt library and doesn't instantiate a wrapper, these parameters serve no purpose.

If these are retained for API symmetry or future extension, consider adding a comment explaining why. Otherwise, remove them:

     @classmethod
     def from_database(cls,
                       params: list[KernelParam],
                       result_idx: list[int],
                       target: str,
                       func_or_mod: tir.PrimFunc | tvm.IRModule,
-                      host_kernel_source: str,
                       device_kernel_source: str,
                       kernel_lib_path: str,
                       verbose: bool = False,
-                      pass_configs: dict[str, Any] | None = None,
                       compile_flags: list[str] | None = None):

And remove the corresponding storage at line 109:

         adapter.params = params
         adapter.result_idx = adapter._legalize_result_idx(result_idx)
-        adapter.host_kernel_source = host_kernel_source
         adapter.device_kernel_source = device_kernel_source

Also applies to: 104-104

🤖 Prompt for AI Agents
In tilelang/jit/adapter/nvrtc/adapter.py around lines 100-101, from_database
currently accepts unused parameters host_kernel_source and pass_configs
(host_kernel_source is stored at line 109 but never used; pass_configs is
neither stored nor used). Remove these two parameters from the from_database
signature, delete the assignment that stores host_kernel_source at line 109, and
remove any no-op handling of pass_configs; if API symmetry or future use is
desired instead of removal, add a brief comment next to the signature explaining
they are intentionally unused. Also update any call sites to stop passing those
arguments.

Comment on lines +48 to +50
            from tilelang.jit.adapter.nvrtc import is_nvrtc_available  # lazy
            if not is_nvrtc_available and "nvrtc" in allowed:
                allowed = [b for b in allowed if b != "nvrtc"]

⚠️ Potential issue | 🟠 Major

Call is_nvrtc_available() when filtering NVRTC

is_nvrtc_available is a callable; without invoking it we never drop NVRTC from allowed, so resolve_execution_backend will happily return "nvrtc" even when that backend can't run and later compilation will crash. Please call the helper before negating it.

-            from tilelang.jit.adapter.nvrtc import is_nvrtc_available  # lazy
-            if not is_nvrtc_available and "nvrtc" in allowed:
+            from tilelang.jit.adapter.nvrtc import is_nvrtc_available  # lazy
+            if not is_nvrtc_available() and "nvrtc" in allowed:
🤖 Prompt for AI Agents
In tilelang/jit/execution_backend.py around lines 48 to 50, the code checks
is_nvrtc_available without calling it, so NVRTC is never filtered out; change
the condition to call the helper (i.e., use is_nvrtc_available()) and then
remove "nvrtc" from allowed when it returns False — keep the lazy import but
invoke the callable in the if-statement and filter the allowed list accordingly.
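
A tiny runnable illustration of why the unfixed check never filters (the helper here is a stand-in that always reports NVRTC as unavailable): a function object is always truthy, so not is_nvrtc_available is False no matter what the helper would return.

def is_nvrtc_available() -> bool:   # stand-in for the real helper
    return False


allowed = ["tvm_ffi", "nvrtc", "cython"]

if not is_nvrtc_available and "nvrtc" in allowed:       # bug: condition is never True
    allowed = [b for b in allowed if b != "nvrtc"]
print(allowed)                                          # ['tvm_ffi', 'nvrtc', 'cython']

if not is_nvrtc_available() and "nvrtc" in allowed:     # fixed: call the helper
    allowed = [b for b in allowed if b != "nvrtc"]
print(allowed)                                          # ['tvm_ffi', 'cython']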

@Hzfengsy

A few benchmark results:

GEMM

Shape (M×N×K)     tvm_ffi    ctypes     cython     Best Backend   Speedup vs 2nd
256×256×256       17.64 µs   24.80 µs   30.28 µs   tvm_ffi        1.41×
512×512×512       13.83 µs   24.32 µs   30.52 µs   tvm_ffi        1.76×
1024×1024×1024    14.78 µs   23.60 µs   30.04 µs   tvm_ffi        1.60×
2048×2048×2048    18.60 µs   22.41 µs   29.09 µs   tvm_ffi        1.20×
4096×4096×4096    15.50 µs   24.80 µs   29.33 µs   tvm_ffi        1.60×
1024×2048×512     15.50 µs   23.60 µs   31.71 µs   tvm_ffi        1.52×
2048×1024×512     13.83 µs   26.46 µs   29.09 µs   tvm_ffi        1.91×
512×4096×1024     15.02 µs   31.47 µs   37.67 µs   tvm_ffi        2.10×

Average results:
tvm_ffi: 15.59 µs
ctypes: 25.18 µs
cython: 30.97 µs

MLA

tvm_ffi: 20.63 µs
ctypes: 30.91 µs
cython: 56.26 µs

@Hzfengsy Hzfengsy left a comment

Many unexpected changes need to be fixed

- Updated `ExtractFuncInfo` in `rt_mod_cuda.cc` and `rt_mod_hip.cc` to map boolean argument types to int32, ensuring compatibility with device runtime.
- Refactored `BindDLTensor` in `arg_binder.cc` to improve null handling and validation checks for DLTensor parameters, utilizing expression-level guards to prevent dereferencing null pointers.
- Enhanced error checking for buffer shape, strides, and data fields, ensuring robust handling of optional inputs and maintaining consistency across various checks.
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/jit/kernel.py (1)

606-607: Fix AttributeError: change kernel_global_source to device_kernel_source

The kernel_source property at line 607 accesses adapter.kernel_global_source, which does not exist on any adapter currently used by this code (ctypes, cython, nvrtc, tvm_ffi). All these adapters have device_kernel_source instead. This will raise an AttributeError when artifact is None.

     @property
     def kernel_source(self) -> str:
-        return self.artifact.kernel_source if self.artifact else self.adapter.kernel_global_source
+        return self.artifact.kernel_source if self.artifact else self.adapter.device_kernel_source
♻️ Duplicate comments (2)
tilelang/jit/adapter/nvrtc/adapter.py (1)

94-148: Unused parameters remain in from_database.

As noted in the previous review, host_kernel_source (line 100) is stored at line 109 but never referenced afterward, and pass_configs (line 104) is neither stored nor used. This remains unresolved.

src/transform/arg_binder.cc (1)

280-301: Review overly permissive bool compatibility logic.

The bit1_ok check at lines 299-300 accepts any dtype with 1-bit width regardless of type code, which could mask type errors by accepting exotic 1-bit types without proper bool semantics. The explicitly allowed type codes (int8, uint8, kDLBool with bits 8 or 1) should be sufficient.

As noted in the previous review, consider removing the overly broad bit1_ok check:

     PrimExpr kdlbool1_ok =
         (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
-    // Also accept any dtype whose bitwidth=1, regardless of code, to be
-    // defensive.
-    PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
-    cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
+    cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok;
🧹 Nitpick comments (2)
src/transform/inject_assumes.cc (1)

67-71: Consider adding observability when skipping assumptions to aid debugging.

The guard correctly prevents injecting assumptions that reference undefined variables, which would create invalid TIR. However, silently skipping with continue makes it difficult to debug when assumptions are unexpectedly omitted during the compilation pipeline.

Adding a debug log when assumptions are skipped would help developers understand why certain assumptions aren't injected and whether this affects downstream optimizations:

if (!tvm::tir::UndefinedVars(simplified).empty()) {
  DLOG(INFO) << "Skipping assumption injection for expression " << e.expr 
             << " due to undefined variables in simplified form: " << simplified;
  continue;
}

Also verify that this behavior change (silently skipping assumptions with undefined variables) doesn't negatively impact optimization passes or correctness guarantees that depend on these assumptions being consistently injected.

src/transform/make_packed_api.cc (1)

431-448: Add logging for undefined variable fallback to improve observability during dynamic shape debugging.

The fallback behavior binding undefined variables to zero is intentional and well-documented. However, LOG(WARNING) is a standard pattern used throughout src/transform/ for similar fallback scenarios. Adding logging here would help detect and debug unexpected undefined variables when dynamic shapes or injected assumptions cause this code path to execute.

  if (ffi::Array<Var> undefined =
          UndefinedVars(func_ptr->body, func_ptr->params);
      undefined.size() != 0) {
+    LOG(WARNING) << "Function " << name_hint << " has " << undefined.size()
+                 << " undefined variables; binding to zero as fallback";
    for (const Var &v : undefined) {
      DataType t = v->dtype;
      PrimExpr zero = make_zero(t);
      seq_init.push_back(LetStmt(v, zero, nop));
    }
  }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4a6e99b and 4364065.

📒 Files selected for processing (8)
  • src/target/rt_mod_cuda.cc (1 hunks)
  • src/target/rt_mod_hip.cc (1 hunks)
  • src/transform/arg_binder.cc (9 hunks)
  • src/transform/inject_assumes.cc (2 hunks)
  • src/transform/make_packed_api.cc (13 hunks)
  • tilelang/engine/phase.py (1 hunks)
  • tilelang/jit/adapter/nvrtc/adapter.py (5 hunks)
  • tilelang/jit/kernel.py (16 hunks)
✅ Files skipped from review due to trivial changes (1)
  • tilelang/engine/phase.py
🧰 Additional context used
🧬 Code graph analysis (4)
src/transform/arg_binder.cc (1)
tilelang/language/ast/ir.py (3)
  • LetStmt (880-908)
  • handle (1467-1497)
  • decl_buffer (1137-1205)
tilelang/jit/adapter/nvrtc/adapter.py (6)
tilelang/jit/adapter/nvrtc/wrapper.py (3)
  • host_func (258-262)
  • host_func (265-267)
  • update_lib_code (498-555)
tilelang/jit/adapter/wrapper.py (7)
  • host_func (525-535)
  • wrap (134-135)
  • wrap (818-837)
  • wrap (845-859)
  • update_lib_code (444-493)
  • update_lib_code (731-750)
  • update_lib_code (783-785)
tilelang/jit/adapter/nvrtc/libgen.py (1)
  • NVRTCLibraryGenerator (40-235)
tilelang/jit/kernel.py (2)
  • params (602-603)
  • get_kernel_source (423-434)
tilelang/jit/adapter/base.py (1)
  • get_kernel_source (93-97)
tilelang/jit/adapter/ctypes/adapter.py (1)
  • get_kernel_source (294-300)
src/transform/make_packed_api.cc (2)
src/target/codegen_cpp.cc (8)
  • VisitExpr_ (166-179)
  • VisitExpr_ (166-167)
  • VisitExpr_ (375-414)
  • VisitExpr_ (375-376)
  • VisitExpr_ (452-455)
  • VisitExpr_ (452-453)
  • VisitExpr_ (457-460)
  • VisitExpr_ (457-458)
tilelang/language/tir/op.py (1)
  • tvm_struct_get (414-436)
tilelang/jit/kernel.py (5)
tilelang/jit/adapter/tvm_ffi.py (4)
  • TVMFFIKernelAdapter (23-321)
  • from_database (267-297)
  • get_kernel_source (311-316)
  • get_host_source (299-303)
tilelang/jit/adapter/base.py (2)
  • BaseKernelAdapter (10-100)
  • get_kernel_source (93-97)
tilelang/jit/adapter/nvrtc/adapter.py (2)
  • from_database (95-148)
  • get_kernel_source (172-183)
tilelang/jit/adapter/ctypes/adapter.py (3)
  • from_database (113-162)
  • CtypesKernelAdapter (19-300)
  • get_kernel_source (294-300)
tilelang/jit/adapter/cython/adapter.py (2)
  • from_database (149-208)
  • get_kernel_source (384-391)
🪛 Ruff (0.14.4)
tilelang/jit/adapter/nvrtc/adapter.py

104-104: Unused class method argument: pass_configs

(ARG003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
  • GitHub Check: Build SDist
🔇 Additional comments (20)
src/transform/inject_assumes.cc (1)

7-7: LGTM!

The include is necessary for the UndefinedVars function used in the new guard logic below.

tilelang/jit/adapter/nvrtc/adapter.py (2)

30-92: LGTM!

The rename from kernel_global_source to device_kernel_source is consistent with the broader refactoring to separate host and device kernel sources. All usages are correctly updated.


172-183: LGTM!

The kernel_only parameter correctly returns device_kernel_source when True and host_func when False, aligning with the refactoring pattern across other adapters.

tilelang/jit/kernel.py (5)

18-18: LGTM!

The addition of TVMFFIKernelAdapter and the default backend change to "tvm_ffi" align with the PR objectives. According to the benchmark results, tvm_ffi shows 1.41-2.10× speedup over other backends.

Also applies to: 54-110


142-183: LGTM!

The signature changes to accept host_kernel_source and device_kernel_source (replacing kernel_global_source) are consistent with the adapter refactoring across the codebase.


228-229: LGTM!

The logic correctly enables host codegen and device compile for the tvm_ffi backend, and all adapter instantiations properly use device_kernel_source. The TVMFFIKernelAdapter receives all required parameters including rt_mod, host_mod, device_mod, and device_kernel_source.

Note: Line 307 still uses kernel_global_source for the Metal adapter, which may be intentional if that adapter hasn't been refactored yet.

Also applies to: 243-316


318-384: LGTM!

The _create_adapter_from_database method correctly handles the tvm_ffi backend and consistently passes host_kernel_source and device_kernel_source to all adapter constructors.


423-443: LGTM!

The get_kernel_source and get_host_source methods correctly include tvm_ffi in the backend checks and delegate to the adapter methods. The assertion at line 442 is good defensive programming.

src/target/rt_mod_hip.cc (1)

38-42: LGTM!

The boolean-to-int32 mapping for device runtime arguments is correctly implemented and well-documented. This aligns with the identical change in rt_mod_cuda.cc and addresses device runtime limitations with bool types.

src/target/rt_mod_cuda.cc (1)

27-31: LGTM!

The boolean parameter handling is correctly implemented and consistent with the HIP version in rt_mod_hip.cc.

src/transform/arg_binder.cc (4)

99-122: LGTM!

The relaxed dtype compatibility logic in BindBuffer is well-structured and appropriately handles:

  • FP8 E4M3 variants binding together for mixed precision support
  • FP8 E5M2 variants compatibility
  • Bool binding from int8/uint8 with matching lanes

The error message at lines 121-122 is also improved to clearly show both expected and actual dtypes.


192-230: LGTM!

The NULL DLTensor* handling is correctly implemented:

  • The is_null check is cached in a Let-bound variable (lines 196-200) to avoid repeated evaluations
  • All dereferences are guarded with if_then_else to return safe defaults when handle is NULL
  • Runtime assertions use Or(is_null, ...) to short-circuit validation when the handle is NULL

This enables optional tensor inputs while maintaining safety and clear error messages.


238-255: LGTM!

The dtype field validation correctly guards all DLTensor field accesses with if_then_else(Not(is_null), ...) to avoid dereferencing NULL handles. Default values are returned when the handle is NULL, ensuring safe evaluation.


310-523: LGTM!

The NULL-guarded field handling throughout the rest of BindDLTensor is consistently and correctly implemented:

  • All DLTensor field accesses use if_then_else(Not(is_null), ..., safe_default) to avoid dereferencing NULL
  • All assertions use Or(is_null, ...) to short-circuit validation when handle is NULL
  • The data pointer NULL check (lines 509-514) correctly exempts size-0 arrays and NULL handles

This comprehensive NULL safety enables optional tensor inputs while maintaining validation for non-NULL cases.

src/transform/make_packed_api.cc (6)

52-128: LGTM!

The ReturnRewriter refactor correctly transitions from the legacy dual-buffer approach to FFI-centric struct packing:

  • Constructor simplified to single ret_var parameter (line 52)
  • ConvertForFFI properly handles bool by mapping to kTVMFFIBool and casting to Int(64) (lines 89-92)
  • WriteToOut correctly packs three fields (type_index, zero_padding, union_value) using tvm_struct_set (lines 111-125)

The implementation aligns with the TVM FFI conventions and handles all relevant primitive types.


136-177: LGTM!

The SubroutineCallRewriter correctly migrates to FFI-centric types (ffi::Map, ffi::String, ffi::Array, ffi::GetRef, ffi::Optional) without altering the core logic. The changes are consistent and maintain the existing functionality.


247-277: LGTM!

The new packed API signature correctly defines the FFI-style parameters:

  • v_self_handle for module context
  • v_packed_args for packed argument array
  • v_num_packed_args for argument count
  • v_result for return value storage

The f_load_arg_value lambda correctly loads arguments from the kTVMFFIAnyUnionValue field and handles type conversion via APIType.


290-293: LGTM!

The NULL pointer check for v_packed_args is appropriately guarded by num_args > 0, avoiding unnecessary assertions when the function takes no arguments.


300-371: LGTM!

The argument type checking and loading logic is correctly implemented:

  • Type index extraction from kTVMFFIAnyTypeIndex field (lines 304-310)
  • Handle type properly adjusts Tensor offsets by 24 bytes to account for TVMFFIObject header (lines 324-331)
  • Bool/int/float type compatibility is appropriately relaxed (bool accepts int, int accepts bool, float accepts int/bool with runtime conversion via Select)
  • Deferred binding pattern (lines 366-371) correctly avoids use-before-def issues

The implementation aligns with FFI conventions and maintains backward compatibility.


374-420: LGTM!

The final function construction is correct:

  • Signature with 4 FFI-style parameters (lines 374-377)
  • Proper binding order: variables before buffers (lines 386-394)
  • Global symbol prefixed with ffi::symbol::tvm_ffi_symbol_prefix (lines 395-401)
  • ReturnRewriter correctly invoked with v_result (line 403)
  • Device context handling preserved (lines 407-420)

Comment on lines +241 to +242
        if execution_backend == "tvm_ffi":
            # Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack.

⚠️ Potential issue | 🟡 Minor

Update outdated comment.

The comment mentions "DLPack" but this PR introduces TVM FFI as the execution backend. Update the comment to reflect the current implementation.

Apply this diff:

-        if execution_backend == "tvm_ffi":
-            # Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack.
+        if execution_backend == "tvm_ffi":
+            # Use TVMFFIKernelAdapter with TVM FFI for efficient kernel execution.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        if execution_backend == "tvm_ffi":
-            # Use TVMFFIKernelAdapter for interoperability with PyTorch via DLPack.
+        if execution_backend == "tvm_ffi":
+            # Use TVMFFIKernelAdapter with TVM FFI for efficient kernel execution.
🤖 Prompt for AI Agents
In tilelang/jit/kernel.py around lines 241 to 242, the inline comment
incorrectly references "DLPack" while the code path is for the "tvm_ffi"
execution backend; update the comment to describe that this branch uses the TVM
FFI adapter for interoperability (e.g., with frameworks like PyTorch) via TVM's
FFI interface rather than DLPack, and replace the outdated wording with a
concise comment reflecting the current TVM FFI implementation.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4364065 and a555165.

📒 Files selected for processing (1)
  • src/transform/make_packed_api.cc (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/make_packed_api.cc (2)
src/target/codegen_cpp.cc (8)
  • VisitExpr_ (166-179)
  • VisitExpr_ (166-167)
  • VisitExpr_ (375-414)
  • VisitExpr_ (375-376)
  • VisitExpr_ (452-455)
  • VisitExpr_ (452-453)
  • VisitExpr_ (457-460)
  • VisitExpr_ (457-458)
tilelang/language/tir/op.py (1)
  • tvm_struct_get (414-436)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build SDist
🔇 Additional comments (13)
src/transform/make_packed_api.cc (13)

23-23: LGTM!

The new includes support the FFI migration. The ffi/extra/module.h header provides FFI infrastructure, and unordered_set is a standard addition for variable tracking.

Also applies to: 36-36


52-52: LGTM!

The simplified constructor signature is consistent with the FFI migration, where a single result variable points to the FFI return structure.


89-92: LGTM!

The boolean type handling follows the established pattern for other primitive types, correctly mapping to kTVMFFIBool and converting to Int(64) for FFI representation.


109-128: LGTM!

The migration from buffer-based to FFI struct-based return handling is implemented correctly. The method properly initializes all FFI structure fields (type_index, zero_padding, value) in the correct sequence.


136-175: LGTM!

The systematic migration to ffi:: namespace types (Map, Optional, String, Array, GetRef) is consistent and maintains the existing logic while aligning with the FFI infrastructure.


181-188: LGTM!

The simplified helper functions provide clean, type-safe assertion creation with string messages. The null-pointer check correctly uses the builtin::isnullptr() intrinsic.


247-250: LGTM!

The new variables align with FFI calling conventions. v_self_handle provides the self-reference handle, and v_result is correctly typed as a void pointer for writing the FFI return structure.


290-293: LGTM!

The defensive null check for v_packed_args prevents potential null pointer dereferences when processing arguments. The conditional check (only when num_args > 0) is an appropriate optimization.


332-365: Verify the type coercion policy aligns with requirements.

The implementation allows flexible type conversions:

  • Boolean parameters accept both kTVMFFIBool and kTVMFFIInt
  • Integer parameters accept both kTVMFFIInt and kTVMFFIBool
  • Float parameters accept kTVMFFIFloat, kTVMFFIInt, and kTVMFFIBool

While this provides compatibility for dynamic language interop, it reduces type safety and could mask errors where incorrect types are passed. Ensure this flexibility is intentional and documented as part of the FFI calling convention.


374-377: LGTM!

The function signature correctly implements the FFI packed calling convention with the standard four parameters: self handle, packed arguments, argument count, and result pointer. The comment clearly documents the signature.


390-401: LGTM!

The buffer binding correctly handles DLTensor parameters with type checking, and the function attributes are properly updated for FFI conventions. The symbol prefix ensures correct FFI naming.


447-448: LGTM!

The buffer map is correctly cleared since the FFI calling convention uses packed arguments instead of buffer parameters. Setting ret_type to Int(32) aligns with the packed function error code convention (returning 0 on success).


324-326: Code is correct; TVMFFIObject size assertion is properly designed.

The static_assert at line 325 is the correct approach for cross-platform safety. TVMFFIObject uses a stable, documented C ABI header with a 24-byte intrusive header on common 64-bit platforms, and the assertion will fail at compile-time if the assumption is incorrect on any target platform. No changes needed.

Comment on lines +264 to 276
   auto f_load_arg_value = [&](DataType arg_type, int i) {
     ffi::Array<PrimExpr> call_args{
         v_packed_args, IntImm(DataType::Int(32), i),
-        IntImm(DataType::Int(32), builtin::kTVMValueContent)};
+        IntImm(DataType::Int(32), builtin::kTVMFFIAnyUnionValue)};
     // load 64 bit version
-    DataType api_type = APIType(t);
+    DataType api_type = APIType(arg_type);
     PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
     // cast to the target version.
-    if (api_type != t) {
-      res = Cast(t, res);
+    if (api_type != arg_type) {
+      res = Cast(arg_type, res);
     }
     return res;
   };

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify APIType is available.

The FFI argument loading logic is correct. However, ensure that the APIType function (line 269) is properly defined in the included headers.

Run the following script to verify APIType is defined:


🏁 Script executed:

#!/bin/bash
# Description: Verify APIType function is defined in the codebase

# Search for APIType definition
rg -nP '\bAPIType\s*\(' --type=cpp -C3

Length of output: 588


🏁 Script executed:

#!/bin/bash
# Search for APIType definition (not just calls)
rg -nP 'DataType\s+APIType\s*\(' --type=cpp --type=h

echo "---"

# Check includes at top of make_packed_api.cc
head -50 src/transform/make_packed_api.cc | grep -E '#include|namespace'

Length of output: 620


🏁 Script executed:

#!/bin/bash
# Search in arg_binder.h specifically
echo "=== Searching arg_binder.h ===" 
rg -nP 'APIType' src/transform/arg_binder.h

echo ""
echo "=== Broader search for APIType definition ===" 
# Look for APIType in any context
rg -nP 'APIType' --type=cpp --type=h | grep -E 'inline|DataType|^[^:]*:.*APIType\s*='

Length of output: 157


🏁 Script executed:

#!/bin/bash
# Check entire make_packed_api.cc for APIType definition or macro
echo "=== All APIType occurrences in file ===" 
rg -n 'APIType' src/transform/make_packed_api.cc

echo ""
echo "=== Search for #define APIType ===" 
rg -n '#define.*APIType' --type=cpp --type=h

echo ""
echo "=== Check full context of make_packed_api.cc around line 269 ===" 
sed -n '1,100p' src/transform/make_packed_api.cc | tail -40

Length of output: 1493


🏁 Script executed:

#!/bin/bash
# Search for APIType in all header files (especially tvm includes)
echo "=== Search all headers for APIType ===" 
find . -name "*.h" -type f | xargs grep -l 'APIType' 2>/dev/null

echo ""
echo "=== Direct search in builtin.h ===" 
rg -n 'APIType' ./include/tvm/tir/builtin.h 2>/dev/null || echo "File not found, searching differently..."

echo ""
echo "=== All APIType matches across entire repo ===" 
rg -r 'APIType' --type=cpp --type=h . 2>/dev/null | head -20

Length of output: 6300


🏁 Script executed:

#!/bin/bash
# Get broader context around the lambda function
echo "=== Lines 250-290 of make_packed_api.cc ===" 
sed -n '250,290p' src/transform/make_packed_api.cc

echo ""
echo "=== Search for similar type conversion patterns ===" 
rg -n 'Cast\(.*DataType' src/transform/make_packed_api.cc | head -10

Length of output: 1996


🏁 Script executed:

#!/bin/bash
# Final confirmation - check if APIType might be in global namespace or as extern
echo "=== Check for extern or forward declaration of APIType ===" 
rg -n 'extern.*APIType|forward.*APIType|APIType.*extern' --type=cpp --type=h

echo ""
echo "=== Search in entire repo for any APIType definition pattern ===" 
rg -n 'APIType\s*=' --type=cpp --type=h

echo ""
echo "=== Check namespace around make_packed_api.cc ===" 
sed -n '1,50p' src/transform/make_packed_api.cc | tail -20

Length of output: 578


APIType function is undefined and will cause a compilation error.

The function APIType(arg_type) called at line 269 is not defined anywhere in the codebase. A comprehensive search found no definition in headers, included files, or via macro. This will prevent the code from compiling.

You need to implement the APIType function or import it from an external library. Based on the usage pattern (converting a target DataType to its FFI/API representation), it likely belongs in arg_binder.h or a related utilities header.

🤖 Prompt for AI Agents
In src/transform/make_packed_api.cc around lines 264 to 276, the call to
APIType(arg_type) is undefined and will cause a compile error; implement APIType
to map a target DataType to its FFI/API representation (e.g., promote
integer/float bitwidths to the 64-bit FFI-backed type and preserve
pointer/handle types), add its declaration to an appropriate header (suggest
arg_binder.h or a new util header) and include that header in
make_packed_api.cc, or alternatively import the correct existing header if the
function already exists elsewhere; ensure the function signature accepts
(DataType) and returns DataType, and handle all DataType kinds used here
(int/uint/float of various widths and any union/handle types) so Caller can
safely cast back to arg_type after Call.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
src/transform/arg_binder.cc (1)

286-301: Overly permissive 1-bit dtype check remains from previous review.

The bit1_ok check at lines 297-300 accepts any dtype with 1-bit width regardless of type code, which could mask type errors by allowing exotic 1-bit types without proper bool semantics. This concern was raised in the previous review.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a555165 and 2f0bf8d.

📒 Files selected for processing (1)
  • src/transform/arg_binder.cc (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/arg_binder.cc (1)
tilelang/language/ast/ir.py (3)
  • LetStmt (880-908)
  • handle (1467-1497)
  • decl_buffer (1137-1205)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build SDist
🔇 Additional comments (2)
src/transform/arg_binder.cc (2)

99-122: LGTM—dtype compatibility relaxation is well-structured.

The dtype_compatible lambda appropriately relaxes dtype checks to support:

  • FP8 E4M3 variants binding together (e4m3, e4m3fn, e4m3fnuz)
  • FP8 E5M2 variants (e5m2, e5m2fnuz)
  • Bool binding from int8/uint8 with matching lanes

This defensive binding improves interoperability for mixed-precision and cross-framework usage. Error messages clearly state expected vs. actual dtype.


192-200: Efficient NULL guard pattern—caches check in a Let-bound variable.

Binding the NULL check to is_null_var prevents redundant isnullptr(handle) calls throughout the validation logic. This improves generated code quality.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/transform/make_packed_api.cc (1)

312-331: Update comment: actual offset is 24 bytes, not 16 bytes.

The comment on line 322 incorrectly states the offset "always equals 16 bytes," but the code applies a 24-byte offset (line 325: static_assert(object_cell_offset == 24)). Update the comment to match the actual offset value.

Suggested fix:

// if type_index is Tensor, we need to add the offset of the TVMFFIObject
// header which always equals 24 bytes, this ensures that T.handle always
// shows up as a DLTensor*
♻️ Duplicate comments (2)
src/transform/make_packed_api.cc (2)

264-276: Critical: APIType function is undefined.

This issue was previously flagged. The APIType(arg_type) call at line 269 will cause a compilation error as the function is not defined anywhere in the codebase.


431-433: Reconsider commenting out undefined variable checks.

This issue was previously flagged. Disabling the undefined variable check could mask genuine bugs and reduce type safety. Consider adding a diagnostic warning or making this behavior configurable rather than silently allowing undefined variables.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8f59fa1 and 0df9ba4.

📒 Files selected for processing (1)
  • src/transform/make_packed_api.cc (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/make_packed_api.cc (4)
src/transform/storage_rewrite.cc (9)
  • dtype (712-718)
  • dtype (712-712)
  • func_ptr (1873-1910)
  • func_ptr (1873-1873)
  • i (720-860)
  • var (1153-1176)
  • var (1153-1153)
  • buffer (1425-1458)
  • buffer (1426-1427)
src/target/codegen_cpp.cc (8)
  • VisitExpr_ (166-179)
  • VisitExpr_ (166-167)
  • VisitExpr_ (375-414)
  • VisitExpr_ (375-376)
  • VisitExpr_ (452-455)
  • VisitExpr_ (452-453)
  • VisitExpr_ (457-460)
  • VisitExpr_ (457-458)
src/target/codegen_cuda.cc (4)
  • VisitExpr_ (902-1149)
  • VisitExpr_ (902-902)
  • VisitExpr_ (1151-1172)
  • VisitExpr_ (1151-1151)
tilelang/language/tir/op.py (1)
  • tvm_struct_get (414-436)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build SDist
🔇 Additional comments (13)
src/transform/make_packed_api.cc (13)

52-52: LGTM: Simplified constructor signature.

The removal of the ret_tcode parameter aligns with the FFI-centric approach where type information is now embedded in the FFI structure.


89-92: LGTM: Boolean support added correctly.

The boolean type is properly mapped to kTVMFFIBool and cast to Int(64) for FFI storage, consistent with the handling of other primitive types.


109-128: LGTM: FFI structure packing implementation.

The WriteToOut method correctly implements the FFI return value protocol by setting the type index, zero padding, and union value fields using tvm_struct_set calls.


136-177: LGTM: Consistent FFI type migration.

The SubroutineCallRewriter class has been properly updated to use FFI-namespaced types (ffi::Map, ffi::String, ffi::Optional, ffi::Array, ffi::GetRef), maintaining consistency with the overall FFI migration.


181-188: LGTM: Helper functions properly defined.

Both MakeAssertEQ and MakeAssertNotNull are correctly implemented with appropriate inline specifiers and proper use of TIR primitives.


247-250: LGTM: FFI signature variables added.

The addition of v_self_handle and the updated type for v_result correctly reflect the new FFI function signature requirements.


290-293: LGTM: Defensive null pointer check added.

The null check for v_packed_args is appropriate and only performed when num_args > 0, improving error messages for null pointer issues.


332-365: LGTM: Comprehensive type checking and coercion logic.

The argument processing correctly handles boolean, integer, and float types with appropriate type index checks and conversions. The float handling is particularly flexible, allowing implicit conversion from integer or boolean when needed.


374-377: LGTM: Function signature updated for FFI.

The new signature correctly reflects the FFI calling convention with self handle, packed arguments array, argument count, and result pointer.


390-394: LGTM: Buffer binding logic preserved.

The buffer binding correctly uses BindDLTensor and declares buffers with the updated FFI-based argument handling.


395-402: LGTM: Global symbol prefixing for FFI registration.

The function attributes are correctly updated with the FFI symbol prefix, appropriate calling convention, and target host, aligning with the FFI registration requirements.


403-403: LGTM: ReturnRewriter instantiation updated.

The call correctly uses the simplified constructor signature with only the result variable.


435-436: LGTM: Function metadata correctly finalized.

The buffer map is properly cleared and the return type is correctly set to Int(32) for error code returns, consistent with the packed API convention.

#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_set>
Contributor

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

Verify that <unordered_set> is required.

The include was added but it's not immediately clear from the visible changes whether std::unordered_set is actually used in this file.

Run the following script to confirm usage:


🏁 Script executed:

#!/bin/bash
# Description: Check if unordered_set is used in make_packed_api.cc

rg -n '\bunordered_set\b' src/transform/make_packed_api.cc

Length of output: 27


Remove unused <unordered_set> include at line 36.

The header is included but std::unordered_set is not used anywhere in the file. Removing this unused include will clean up unnecessary dependencies.

🤖 Prompt for AI Agents
In src/transform/make_packed_api.cc at line 36, the include statement for
<unordered_set> is unused; remove the #include <unordered_set> line to eliminate
the unnecessary dependency and rebuild to ensure no missing symbols.

- Improved null handling and validation checks in `BindDLTensor`, ensuring safe dereferencing of pointers.
- Enhanced consistency checks for buffer shape, strides, and data fields, utilizing expression-level guards.
- Updated `MakePackedAPI` to maintain code clarity and consistency in argument handling.
- Minor adjustments in test files to streamline kernel execution and improve readability.
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
src/transform/make_packed_api.cc (1)

264-276: Re‑check that APIType is defined and included for FFI loads.

f_load_arg_value depends on APIType(arg_type) (line 269) to map parameter dtypes to the 64‑bit FFI API dtype before calling tvm_struct_get. In a previous revision, APIType was not defined anywhere and caused a compile‑time issue; the call site is still present here.

Please verify that:

  • APIType(DataType) is now implemented (or imported) with the expected signature and behavior, and
  • The appropriate header is included so this call compiles.

If it isn’t, you’ll need to either define APIType (likely alongside other API‑level dtype utilities) or inline the mapping logic used during loading.

#!/bin/bash
# Verify that APIType(DataType) is defined and visible to this TU
rg -nP '\bAPIType\s*\(' --type=cpp --type=h -C3
echo "----"
rg -nP 'DataType\s+APIType\s*\(' --type=cpp --type=h -C3
src/transform/arg_binder.cc (1)

203-205: Null-handle path still dereferences the DLTensor

The new NULL-handling still calls TVMArrayGet/BufferLoad before checking is_null, e.g. PrimExpr v_ndim = TVMArrayGet(...) (Line 203), raw_shape_val = BufferLoad(...) (Lines 336-341), and the stride loaders (Lines 412-475). In TIR these calls execute eagerly; wrapping the result in if_then_else/Or does not short-circuit, so a NULL handle will segfault exactly as before. Please move every struct/Buffer load into statement-level guards (IfThenElse(Not(is_null), …)) or skip the loops entirely when the handle is NULL. This is the same issue coderabbitai[bot] highlighted earlier, but it remains unfixed.

Also applies to: 336-352, 412-475
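A small Python illustration of the eager-evaluation problem (plain Python stand-ins, not TIR): the operands of a call such as if_then_else(cond, a, b) are evaluated before the call runs, so the "guarded" load still executes on a NULL handle; only a statement-level branch actually skips it.

def if_then_else(cond, then_value, else_value):
    # both operands were already evaluated by the caller
    return then_value if cond else else_value

def load_ndim(handle):
    return handle["ndim"]  # stands in for TVMArrayGet(...) / BufferLoad(...)

handle = None

try:
    # expression-level guard: load_ndim(handle) runs eagerly and fails
    ndim = if_then_else(handle is not None, load_ndim(handle), 0)
except TypeError:
    ndim = -1

# statement-level guard: the load only runs when the handle is non-NULL
if handle is not None:
    ndim = load_ndim(handle)
else:
    ndim = 0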

🧹 Nitpick comments (7)
testing/python/jit/test_tilelang_jit_nvrtc.py (3)

367-367: Consider renaming to reflect the tvm_ffi backend.

The function name run_nvrtc_dynamic_shape and its test wrapper test_nvrtc_dynamic_shape (line 395) suggest NVRTC backend testing, but the implementation now uses "tvm_ffi". This naming inconsistency may confuse developers reviewing or debugging tests.

Consider renaming the function to better reflect its purpose, e.g., run_tvm_ffi_dynamic_shape, or add a comment explaining that this test validates dynamic shape support using the new default backend.


368-368: Consider removing or conditionalizing the print statement.

The host source print appears to be diagnostic output for debugging the tvm_ffi backend migration. In production tests, this clutters output and should either be removed or made conditional (e.g., gated by an environment variable or debug flag).


586-589: Clean up commented code or document the transition.

The commented-out tilelang.testing.main() suggests a transitional state. If the shift to direct test invocation with explicit cache disabling is intentional and permanent, remove the commented line. If this is temporary for debugging the tvm_ffi migration, add a TODO comment explaining the plan to restore the automatic test runner.

Note: The direct invocation only exercises one test case (test_nvrtc_dynamic_shape) rather than all tests in the module, which may reduce test coverage during local development runs.

testing/python/jit/test_tilelang_jit_nullptr.py (1)

93-94: Consider removing or conditionalizing the debug print statement.

The kernel invocation pattern is correct. However, the unconditional print(kernel.get_host_source()) statement appears to be debug code that will add noise to test output.

Consider removing this line or making it conditional:

-    print(kernel.get_host_source())

Or make it conditional on an environment variable or debug flag if source inspection is needed during development.

src/transform/make_packed_api.cc (3)

23-36: Clarify STL includes: <unordered_set> looks unused while std::unordered_map is used.

This file uses std::unordered_map (line 258) but only includes <unordered_set>. It likely compiles via transitive includes, but for correctness and clarity you probably want:

  • Replace <unordered_set> with <unordered_map>, and
  • Drop <unordered_set> entirely unless you actually use std::unordered_set elsewhere.

52-52: ReturnRewriter FFI packing looks consistent but only supports scalar (non-handle) values.

The new ConvertForFFI/WriteToOut path correctly maps bool/int/uint/float/void into TVMFFIAny using kTVMFFI* type indices and tvm_struct_set writes, and the ret(0) convention is preserved. However:

  • ConvertForFFI currently only handles is_bool/is_int/is_uint/is_float/is_void; any handle or other dtype passed to tir.ret will trigger LOG(FATAL).
  • If tir.ret is guaranteed to be used only for scalar return values in this pass, this is fine; otherwise, you may want to extend ConvertForFFI to cover handle/object returns as well.

Also applies to: 84-107, 109-128


300-371: Type-index checks and relaxed dtype compatibility are reasonable; document Tensor handle offset.

The per-argument handling:

  • Loads *.type_index via tvm_struct_get,
  • For handles, accepts None/OpaquePtr/DLTensorPtr/static objects and special-cases kTVMFFITensor by adding sizeof(TVMFFIObject) byte offset,
  • Allows BoolInt and IntFloat via controlled TypeIndex checks and casts,

matches the stated goal of more permissive, but still checked, binding semantics.

Two small suggestions:

  • The comment saying the DLTensor header “always equals 16 bytes” conflicts with static_assert(sizeof(TVMFFIObject) == 24); consider updating the comment to match the current layout.
  • It may be worth briefly documenting somewhere (even as a comment) that the Tensor case requires this TVMFFIObject offset so future changes to the FFI object layout don’t silently break handle bindings; a small sketch of the pointer arithmetic follows this list.
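A tiny sketch of that pointer adjustment, using plain integer arithmetic and assuming the 24-byte header stated by the static_assert above:

TVM_FFI_OBJECT_HEADER_BYTES = 24  # sizeof(TVMFFIObject) per the static_assert

def dltensor_address(ffi_object_address):
    # the DLTensor payload of a kTVMFFITensor box starts right after the header,
    # so the host adds the header size before handing the pointer to T.handle
    return ffi_object_address + TVM_FFI_OBJECT_HEADER_BYTES

assert dltensor_address(0x1000) == 0x1018  # 0x1000 + 24 bytes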
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e202d38 and f85bb8f.

📒 Files selected for processing (4)
  • src/transform/arg_binder.cc (6 hunks)
  • src/transform/make_packed_api.cc (13 hunks)
  • testing/python/jit/test_tilelang_jit_nullptr.py (2 hunks)
  • testing/python/jit/test_tilelang_jit_nvrtc.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/jit/test_tilelang_jit_nvrtc.py
  • testing/python/jit/test_tilelang_jit_nullptr.py
🧬 Code graph analysis (4)
testing/python/jit/test_tilelang_jit_nvrtc.py (3)
tilelang/jit/kernel.py (1)
  • get_host_source (436-443)
tilelang/env.py (1)
  • disable_cache (275-276)
tilelang/language/symbolics.py (1)
  • dynamic (10-21)
testing/python/jit/test_tilelang_jit_nullptr.py (2)
tilelang/env.py (1)
  • disable_cache (275-276)
tilelang/utils/tensor.py (1)
  • map_torch_type (19-36)
src/transform/arg_binder.cc (1)
tilelang/language/ast/ir.py (3)
  • LetStmt (880-908)
  • handle (1467-1497)
  • decl_buffer (1137-1205)
src/transform/make_packed_api.cc (2)
src/target/codegen_cpp.cc (8)
  • VisitExpr_ (166-179)
  • VisitExpr_ (166-167)
  • VisitExpr_ (375-414)
  • VisitExpr_ (375-376)
  • VisitExpr_ (452-455)
  • VisitExpr_ (452-453)
  • VisitExpr_ (457-460)
  • VisitExpr_ (457-458)
tilelang/language/tir/op.py (1)
  • tvm_struct_get (414-436)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
  • GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
  • GitHub Check: Build SDist
🔇 Additional comments (6)
testing/python/jit/test_tilelang_jit_nullptr.py (3)

86-87: LGTM!

The addition of disable_cache() and the kernel object pattern align with the PR's transition to TVM FFI-based execution and ensure deterministic test behavior.


101-109: LGTM!

The kernel invocation pattern is consistently applied across all test cases. The NULL pointer handling (passing None when with_bias=False) is correct for both ptr_null_test and tensor_null_test variants.


112-118: The review comment is incorrect—test_nullptr() is not dead code.

Pytest auto-discovers test functions matching the test_* naming convention regardless of the __main__ block. The function will be executed when running tests via pytest (standard test discovery), while the __main__ block provides an alternative for direct script execution. This dual pattern is common and intentional in test files.

However, there is a legitimate concern worth addressing: the file diverges from the project's established pattern. Other test files (e.g., test_tilelang_cache_matmul.py, test_metal_codegen.py) invoke tilelang.testing.main() in their __main__ blocks. The switch to direct run_test() invocation may bypass framework-level setup or test qualification decorators (like @tilelang.testing.requires_cuda). Verify whether this intentional deviation from the project's test invocation pattern is acceptable for this file.

Likely an incorrect or invalid review comment.

src/transform/make_packed_api.cc (3)

136-138: Switch to ffi::Optional / ffi::Map / ffi::String in SubroutineCallRewriter looks sound.

Using ffi::Optional<Stmt>, ffi::Map<GlobalVar, ffi::String>, ffi::GetRef, and ffi::Array<PrimExpr> here is consistent with the rest of the FFI-centric changes, and Apply is still consumed in MakePackedAPI() the same way. I don’t see functional regressions in this rewriter.

Also applies to: 150-152, 157-160, 175-177


181-188: Assert helpers and improved diagnostics look good.

MakeAssertEQ and MakeAssertNotNull are straightforward and their usage for:

  • num_args equality, and
  • guarding against args == nullptr when num_args > 0

provides clearer error messages without changing semantics.

Also applies to: 283-293


247-251: Overall FFI packed-API wiring looks correct; a few assumptions to confirm.

The reworked packed API:

  • Uses (self_handle, packed_args, num_args, result) as params and sets kCallingConv = kCPackedFunc, kTarget = target_host, and kGlobalSymbol with ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.
  • Binds scalar args via ArgBinder::Bind, buffers via BindDLTensor, then builds the body with ReturnRewriter(v_result), device attrs (using ffi::Any node = ffi::String("default")), a tvm_call_packed(tvm_set_device, ...) prologue when needed, and a final ret(0) error code.
  • Merges init/binding/assert/buffer decl nests, resets buffer_map to empty, and sets ret_type to PrimType(Int(32)), matching the usual packed-func convention.
  • Checks UndefinedVars (now using ffi::Array<Var>) to ensure all used vars are in params.

This all lines up with the new TVM FFI calling convention and looks structurally sound.

The main assumptions to keep in mind/verify are:

  • External callers indeed pass a TVMFFIAny* for result that matches the layout used in ReturnRewriter::WriteToOut.
  • No PrimFunc relying on this pass expects a non‑int32 C return type anymore (since ret_type is fixed to Int(32) and the real value flows via result).

Also applies to: 374-378, 390-402, 403-427, 431-439

@LeiWang1999
Member Author

A few benchmark results:

GEMM

| Shape (M×N×K) | tvm_ffi | ctypes | cython | Best Backend | Speedup vs 2nd |
| --- | --- | --- | --- | --- | --- |
| 256×256×256 | 17.64 µs | 24.80 µs | 30.28 µs | tvm_ffi | 1.41× |
| 512×512×512 | 13.83 µs | 24.32 µs | 30.52 µs | tvm_ffi | 1.76× |
| 1024×1024×1024 | 14.78 µs | 23.60 µs | 30.04 µs | tvm_ffi | 1.60× |
| 2048×2048×2048 | 18.60 µs | 22.41 µs | 29.09 µs | tvm_ffi | 1.20× |
| 4096×4096×4096 | 15.50 µs | 24.80 µs | 29.33 µs | tvm_ffi | 1.60× |
| 1024×2048×512 | 15.50 µs | 23.60 µs | 31.71 µs | tvm_ffi | 1.52× |
| 2048×1024×512 | 13.83 µs | 26.46 µs | 29.09 µs | tvm_ffi | 1.91× |
| 512×4096×1024 | 15.02 µs | 31.47 µs | 37.67 µs | tvm_ffi | 2.10× |

Average results: tvm_ffi 15.59 µs, ctypes 25.18 µs, cython 30.97 µs

MLA

tvm_ffi: 20.63 µs, ctypes: 30.91 µs, cython: 56.26 µs

Do those numbers include kernel launch overhead?

@Hzfengsy
Collaborator

Do those numbers include kernel launch overhead?

Yes, it includes all CPU overheads together with the kernel launch.

The result was measured with the following method:

begin_time = time.time()
kernel(...)
# w/o sync
end_time = time.time()
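A slightly more complete sketch of that measurement (a hypothetical helper, assuming a compiled kernel and pre-allocated device tensors): averaging many un-synchronized launches gives the per-call CPU-side cost.

import time

def measure_launch_overhead_us(kernel, args, iters=1000, warmup=10):
    for _ in range(warmup):
        kernel(*args)          # exclude first-launch / compilation costs
    begin = time.time()
    for _ in range(iters):
        kernel(*args)          # no device synchronization: CPU overhead only
    end = time.time()
    return (end - begin) / iters * 1e6  # microseconds per call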

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/quickstart.py (1)

67-74: Fix critical bug: M=0 creates empty tensors that don't test the kernel.

Setting M = 0 creates empty tensors with zero rows:

  • a has shape (0, 1024)
  • c has shape (0, 1024)

This means no actual matrix multiplication occurs, and the correctness validation at line 81 doesn't test anything meaningful. For a quickstart example demonstrating dynamic shapes, you should still use a valid test size.

Apply this diff to fix:

-M = 0
+M = 1024

Alternatively, if demonstrating multiple dynamic sizes is desired:

-M = 0
+# Test with a valid dynamic size
+M = 1024
 # Create random input tensors on the GPU
 a = torch.randn(M, K, device="cuda", dtype=torch.float16)
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_alloc.py (1)

153-153: Apply consistent pattern to all get_kernel_source() calls for clarity.

Verification confirms line 153 now explicitly passes kernel_only=True, while lines 36, 76, and 115 rely on the default parameter (which is also True). All four calls behave identically. For consistency and clarity, apply the explicit parameter pattern to the other similar test functions:

# Line 36
-    code = kernel.get_kernel_source()
+    code = kernel.get_kernel_source(kernel_only=True)

# Line 76  
-    code = kernel.get_kernel_source()
+    code = kernel.get_kernel_source(kernel_only=True)

# Line 115
-    code = kernel.get_kernel_source()
+    code = kernel.get_kernel_source(kernel_only=True)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 65448b5 and f2f8a28.

📒 Files selected for processing (6)
  • examples/blocksparse_gemm/example_blocksparse_gemm.py (0 hunks)
  • examples/gdn/test_example_gdn_compilation.py (1 hunks)
  • examples/quickstart.py (3 hunks)
  • testing/python/debug/test_tilelang_debug_print.py (1 hunks)
  • testing/python/jit/test_tilelang_jit_nullptr.py (1 hunks)
  • testing/python/language/test_tilelang_language_alloc.py (1 hunks)
💤 Files with no reviewable changes (1)
  • examples/blocksparse_gemm/example_blocksparse_gemm.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/gdn/test_example_gdn_compilation.py
  • testing/python/debug/test_tilelang_debug_print.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/jit/test_tilelang_jit_nullptr.py
  • examples/quickstart.py
🧬 Code graph analysis (3)
testing/python/jit/test_tilelang_jit_nullptr.py (1)
tilelang/utils/tensor.py (1)
  • map_torch_type (19-36)
examples/quickstart.py (1)
tilelang/jit/__init__.py (3)
  • jit (293-294)
  • jit (298-310)
  • jit (313-382)
testing/python/language/test_tilelang_language_alloc.py (5)
tilelang/jit/adapter/nvrtc/adapter.py (1)
  • get_kernel_source (172-183)
tilelang/jit/kernel.py (1)
  • get_kernel_source (423-434)
tilelang/jit/adapter/ctypes/adapter.py (1)
  • get_kernel_source (294-300)
tilelang/jit/adapter/tvm_ffi.py (1)
  • get_kernel_source (311-316)
tilelang/jit/adapter/cython/adapter.py (1)
  • get_kernel_source (384-391)
🔇 Additional comments (3)
testing/python/jit/test_tilelang_jit_nullptr.py (1)

86-107: LGTM! Kernel object pattern aligns with FFI backend changes.

The variable naming changes (using kernel consistently) align with the broader PR pattern of using kernel object handles returned by the JIT. The test logic correctly validates null pointer handling for both pointer-based and tensor-based APIs by:

  • Testing the None/null bias case first
  • Testing with an actual bias tensor
  • Verifying correctness against reference computations in both scenarios

This test appropriately validates the null pointer handling improvements mentioned in the PR objectives.

examples/quickstart.py (2)

54-54: Good: Dynamic shape support demonstrated.

The change to M = T.dynamic("m") correctly demonstrates dynamic shape functionality, which aligns with the PR's backend improvements.


8-11: I need to examine the quickstart file more closely to check if the original review comment contained other concerns beyond the pass_configs question.

Pass configs are consistent with codebase pattern; no fixes needed for decorator.

Based on codebase analysis, both TL_DISABLE_TMA_LOWER and TL_DISABLE_WARP_SPECIALIZED are used extensively across 50+ test files, examples, and maintenance scripts. Multiple files contain explicit comments indicating these are known workarounds:

  • testing/python/dynamic/test_tilelang_dynamic_symbolic.py:423: "workaround for hopper tma lower pass"
  • examples/flash_decoding/example_gqa_decode.py:49: "TODO(lei): fix warp specialized and tma lower pass"

The quickstart's use of these pass configs aligns with the established pattern throughout the codebase, making it appropriate for a simplified demonstration. The approach is consistent, not unique to this example.

@oraluben
Contributor

oraluben commented Nov 18, 2025

Just a quick test against the current version: the Metal backend doesn't work with tvm-ffi for now. With the patch below:

diff --git a/testing/python/metal/test_metal_codegen.py b/testing/python/metal/test_metal_codegen.py
index 22f4beb8..30aee355 100644
--- a/testing/python/metal/test_metal_codegen.py
+++ b/testing/python/metal/test_metal_codegen.py
@@ -5,7 +5,7 @@ import tilelang.language as T
 import torch
 
 
-@tilelang.jit(execution_backend='torch')
+@tilelang.jit
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"):
 
     @T.prim_func
diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py
index d0c27b4c..28519695 100644
--- a/tilelang/engine/lower.py
+++ b/tilelang/engine/lower.py
@@ -161,6 +161,8 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
         device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
     elif target.kind.name == "hip":
         device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip")(device_mod, target)
+    elif target.kind.name == "metal":
+        device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target)
     else:
         raise ValueError(f"Target {target.kind.name} is not supported")
 
diff --git a/tilelang/jit/execution_backend.py b/tilelang/jit/execution_backend.py
index fe600002..e7bd6b0f 100644
--- a/tilelang/jit/execution_backend.py
+++ b/tilelang/jit/execution_backend.py
@@ -35,7 +35,7 @@ def allowed_backends_for_target(target: Target, *, include_unavailable: bool = T
     elif kind == "hip":
         allowed = ["tvm_ffi", "cython", "ctypes"]
     elif kind == "metal":
-        allowed = ["torch"]
+        allowed = ["tvm_ffi"]
     elif kind == "c":  # CPU C backend
         allowed = ["cython", "ctypes", "tvm_ffi"]
     else:

gives:

$ python test_metal_codegen.py
/Users/yyc/repo/tilelang/.venv/lib/python3.12/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:625: UserWarning: Failed to load torch c dlpack extension: Ninja is required to load C++ extensions (pip install ninja to get it),EnvTensorAllocator will not be enabled.
  warnings.warn(
================================== test session starts ==================================
platform darwin -- Python 3.12.10, pytest-9.0.1, pluggy-1.6.0
rootdir: /Users/yyc/repo/tilelang
configfile: pyproject.toml
plugins: typeguard-4.3.0
collected 3 items                                                                       

test_metal_codegen.py FFF                                                         [100%]

======================================= FAILURES ========================================
___________________________________ test_gemm_float32 ___________________________________

    @tilelang.testing.requires_metal
    def test_gemm_float32():
>       assert_gemm(1024, 1024, 1024, 16, 16, 16)

test_metal_codegen.py:69: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

M = 1024, N = 1024, K = 1024, block_M = 16, block_N = 16, block_K = 16, dtype = 'float32'
accum_dtype = 'float', atol = 1e-08

    def assert_gemm(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        dtype="float32",
        accum_dtype="float",
        atol=1e-8,
    ):
        jit_kernel = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype)
    
        torch_dtype = getattr(torch, dtype)
        a, b = None, None
        if 'int' in dtype:
            a = torch.randint(100, (M, K), dtype=torch_dtype, device='mps')
            b = torch.randint(100, (K, N), dtype=torch_dtype, device='mps')
        else:
            a = torch.randn(M, K, dtype=torch_dtype, device='mps')
            b = torch.randn(K, N, dtype=torch_dtype, device='mps')
        c = torch.zeros(M, N, dtype=torch_dtype, device='mps')
    
        jit_kernel(a, b, c)
    
>       assert torch.allclose(a @ b, c, atol=atol)
E       AssertionError: assert False
E        +  where False = <built-in method allclose of type object at 0x10bd73778>((tensor([[-0.3413, -1.4359,  0.7669,  ...,  1.0984, -0.3785, -0.7579],\n        [ 1.2392,  0.5752,  0.9057,  ..., -1.6042, -0.5458, -0.8807],\n        [-0.2500,  0.6145, -0.6426,  ...,  1.3714, -1.1920, -0.3133],\n        ...,\n        [-0.1443,  0.8539, -1.6691,  ...,  1.5610, -0.8778, -0.8718],\n        [ 0.0952, -0.6398, -0.2781,  ...,  0.4997,  1.6163, -1.4133],\n        [ 0.5427,  2.6028, -0.3472,  ..., -0.2566,  0.5272, -0.6811]],\n       device='mps:0') @ tensor([[-1.0708e+00, -6.4586e-01, -3.1943e-04,  ..., -9.7524e-01,\n          3.2471e-01, -7.9247e-01],\n        [ 1.3516e+00, -1.6345e+00,  1.7113e-01,  ...,  3.3956e-01,\n         -1.9336e-01,  4.4414e-02],\n        [-2.3392e-01,  1.9414e+00, -8.5446e-01,  ..., -2.4371e-01,\n         -2.0256e+00,  2.2055e+00],\n        ...,\n        [-2.9429e-01, -7.7305e-01, -1.2008e+00,  ...,  1.3850e+00,\n          6.5544e-01, -1.7805e-01],\n        [-6.2572e-02, -1.2158e+00, -1.1049e+00,  ..., -1.8824e+00,\n          3.3176e-01,  7.2061e-01],\n        [ 3.6768e-01,  1.0966e-02,  6.9203e-01,  ..., -1.0592e+00,\n         -2.9840e-01,  1.1227e+00]], device='mps:0')), tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n        [0., 0., 0.,  ..., 0., 0., 0.],\n        [0., 0., 0.,  ..., 0., 0., 0.],\n        ...,\n        [0., 0., 0.,  ..., 0., 0., 0.],\n        [0., 0., 0.,  ..., 0., 0., 0.],\n        [0., 0., 0.,  ..., 0., 0., 0.]], device='mps:0'), atol=1e-08)
E        +    where <built-in method allclose of type object at 0x10bd73778> = torch.allclose

test_metal_codegen.py:61: AssertionError

Looks like tvm-ffi has a similar issue when dealing with raw data pointers from torch on Metal: there is an extra sync op that is not exposed publicly.
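One quick way to test the missing-sync hypothesis, assuming the PyTorch build exposes torch.mps.synchronize() (available in recent releases): drain the MPS queue before validating the output.

import torch

jit_kernel(a, b, c)        # kernel and tensors as in assert_gemm above
torch.mps.synchronize()    # force queued MPS work to finish before reading c
assert torch.allclose(a @ b, c, atol=atol)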

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
examples/quickstart.py (1)

60-60: LGTM! Good addition for demonstrating kernel source generation.

Printing the kernel source after compilation helps users verify the FFI backend is working correctly, which aligns well with this PR's migration to tvm_ffi as the default execution backend.

Optional: Consider removing redundant commented code.

Lines 80-82 contain commented-out code that would also print the kernel source, which is now redundant with this addition. You may want to remove those lines to avoid confusion:

-# 4. Retrieve and inspect the generated CUDA source (optional)
-# cuda_source = jit_kernel.get_kernel_source()
-# print("Generated CUDA kernel:\n", cuda_source)
testing/python/jit/test_tilelang_jit_tvm_ffi.py (2)

407-412: Add exception handling for device property access.

The function should guard against exceptions when accessing device properties, particularly if CUDA is available but no devices are present.

Apply this diff to add exception handling:

 def check_hopper():
-    if not torch.cuda.is_available():
+    try:
+        if not torch.cuda.is_available():
+            return False
+        props = torch.cuda.get_device_properties(0)
+        compute_capability = props.major, props.minor
+        return compute_capability == (9, 0)
+    except (RuntimeError, AssertionError):
         return False
-    props = torch.cuda.get_device_properties(0)
-    compute_capability = props.major, props.minor
-    return compute_capability == (9, 0)

532-581: Consider removing debug print statements.

The test correctly validates L2 persistent cache annotation, but includes print statements at lines 569 and 581 that may clutter CI output.

Consider removing or gating these print statements:

     # Compile the kernel
     kernel = elementwise_add_with_l2_cache(M, N)
-    print(kernel.get_host_source())
     # Create test tensors
     # Verify correctness
     ref_c = a + b
     tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5)
-
-    print("L2 persistent map test passed!")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f2f8a28 and 7fdfa12.

📒 Files selected for processing (2)
  • examples/quickstart.py (1 hunks)
  • testing/python/jit/test_tilelang_jit_tvm_ffi.py (8 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • testing/python/jit/test_tilelang_jit_tvm_ffi.py
🧬 Code graph analysis (2)
examples/quickstart.py (3)
tilelang/jit/adapter/tvm_ffi.py (1)
  • get_kernel_source (311-316)
tilelang/jit/kernel.py (1)
  • get_kernel_source (423-434)
tilelang/jit/adapter/base.py (1)
  • get_kernel_source (93-97)
testing/python/jit/test_tilelang_jit_tvm_ffi.py (7)
tilelang/jit/kernel.py (2)
  • out_idx (598-599)
  • get_profiler (405-421)
tilelang/jit/adapter/tvm_ffi.py (2)
  • func (195-262)
  • get_host_source (299-303)
tilelang/language/symbolics.py (1)
  • dynamic (10-21)
tilelang/language/annotations.py (1)
  • annotate_l2_hit_ratio (48-54)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (65-76)
tilelang/language/copy.py (2)
  • c2d_im2col (105-135)
  • copy (14-102)
tilelang/utils/tensor.py (1)
  • torch_assert_close (202-294)
🔇 Additional comments (7)
testing/python/jit/test_tilelang_jit_tvm_ffi.py (7)

93-93: LGTM: Backend migration to tvm_ffi is consistent.

The migration from ctypes to tvm_ffi execution backend is correctly applied across all test cases. The PR benchmarks show tvm_ffi delivers 1.2–2.1× speedups over ctypes, confirming this is a performance improvement.

Also applies to: 194-194, 265-265, 314-314, 367-367, 488-488, 539-539


236-282: LGTM: Function and variable renames reflect tvm_ffi backend.

The renaming from ctypes to tvm_ffi improves code clarity and aligns with the backend migration.


285-335: LGTM: Multi-stream test provides good concurrency coverage.

The test correctly validates tvm_ffi backend behavior with multiple CUDA streams, an important scenario for concurrent kernel execution.


338-404: LGTM: Dynamic shape testing is thorough.

The function correctly handles symbolic dimensions and tests multiple dynamic shape scenarios (1D, 2D, 3D dynamic), providing strong coverage for TVM FFI's dynamic shape support.


415-467: LGTM: Convolution im2col implementation is correct.

The function correctly implements im2col-based convolution with proper:

  • Output dimension calculations
  • TMA descriptor usage via T.c2d_im2col
  • Swizzled layout annotations for shared memory optimization

470-529: LGTM: TMA descriptor test is well-structured.

The test correctly validates im2col TMA descriptor functionality with:

  • Appropriate gating for Hopper GPU requirements
  • Correct layout transformations (NHWC ↔ NCHW)
  • Proper reference implementation using torch.conv2d

585-585: LGTM: Test runner migration improves flexibility.

Using tilelang.testing.main() enables proper test discovery and allows running the full test suite, which is better than hardcoding a specific test function.

…rsistent cache management

- Introduced functions to set and reset the CUDA stream access policy window, allowing for better control over L2 cache usage.
- Updated runtime files to include new FFI packed functions for managing stream attributes.
- Modified lower_hopper_intrin to incorporate prologue and epilogue statements for L2 cache setup and teardown.
- Enhanced tests to verify the inclusion of new FFI calls in the generated kernel source.
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
src/transform/lower_hopper_intrin.cc (1)

35-35: Critical allocation size error: Change 16 elements to 8 elements for TVMFFIAny migration.

This issue was previously identified: the migration from TVMValue (8 bytes) to TVMFFIAny (16 bytes) failed to account for the size difference. The TensorMap descriptor requires 128 bytes (1024 bits).

Current allocation:

  • 16 elements × 16 bytes = 256 bytes

Correct allocation:

  • 8 elements × 16 bytes = 128 bytes

Apply this fix:

-                             {StringImm("tvm_ffi_any"), 16});
+                             {StringImm("tvm_ffi_any"), 8});

This double over-allocation will cause memory corruption and crashes during TensorMap descriptor initialization on Hopper GPUs.
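For reference, a quick arithmetic check of the two allocations, assuming the 16-byte TVMFFIAny and 128-byte TensorMap sizes stated above:

TVM_FFI_ANY_BYTES = 16
TENSORMAP_BYTES = 128

assert 16 * TVM_FFI_ANY_BYTES == 256             # current allocation: twice what is needed
assert 8 * TVM_FFI_ANY_BYTES == TENSORMAP_BYTES  # corrected allocation: exactly 128 bytes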

🧹 Nitpick comments (2)
testing/python/jit/test_tilelang_jit_tvm_ffi.py (1)

571-572: Consider breaking long assertion messages.

The assertion messages on lines 571-572 exceed typical line length limits and reduce readability.

Consider formatting as:

-    assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source"
-    assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, "Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source"
+    assert "__tvm_cuda_stream_set_access_policy_window_packed" in source, (
+        "Expected __tvm_cuda_stream_set_access_policy_window_packed in the kernel source"
+    )
+    assert "__tvm_cuda_stream_reset_access_policy_window_packed" in source, (
+        "Expected __tvm_cuda_stream_reset_access_policy_window_packed in the kernel source"
+    )
src/runtime/runtime.cc (1)

16-20: Remove #if 1 guard or clarify intent.

The #if 1 preprocessor directive on line 16 is unconditional and typically indicates temporary debugging code or a feature toggle. If the thread-local storage for L2 cache restoration is intended to be permanent, remove the guard. Otherwise, add a comment explaining the condition.

-#if 1
-// Thread-local storage for restoring the L2 persisting cache limit
 static thread_local size_t __tl_prev_persisting_l2_cache_size = 0;
 static thread_local bool __tl_prev_persisting_l2_cache_saved = false;
-#endif
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7fdfa12 and f177025.

📒 Files selected for processing (4)
  • src/runtime/runtime.cc (4 hunks)
  • src/runtime/runtime.h (1 hunks)
  • src/transform/lower_hopper_intrin.cc (2 hunks)
  • testing/python/jit/test_tilelang_jit_tvm_ffi.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/transform/lower_hopper_intrin.cc
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.

Applied to files:

  • src/runtime/runtime.cc
  • testing/python/jit/test_tilelang_jit_tvm_ffi.py
🧬 Code graph analysis (2)
src/transform/lower_hopper_intrin.cc (2)
src/transform/lower_tile_op.cc (2)
  • f (197-224)
  • f (197-197)
src/runtime/runtime.cc (4)
  • args (47-77)
  • args (47-47)
  • args (131-166)
  • args (131-131)
testing/python/jit/test_tilelang_jit_tvm_ffi.py (4)
src/runtime/runtime.h (1)
  • tvm (10-26)
tilelang/utils/tensor.py (2)
  • map_torch_type (19-36)
  • torch_assert_close (202-294)
tilelang/jit/adapter/tvm_ffi.py (4)
  • prim_func (319-321)
  • get_kernel_source (311-316)
  • func (195-262)
  • get_host_source (299-303)
tilelang/language/annotations.py (2)
  • annotate_layout (26-37)
  • annotate_l2_hit_ratio (48-54)
🔇 Additional comments (4)
testing/python/jit/test_tilelang_jit_tvm_ffi.py (1)

88-91: Callback registration persists across tests.

The global function tilelang_callback_cuda_postproc is registered with override=True but never restored. This is consistent with the expected persistent behavior noted in similar test files.

Based on learnings.

src/transform/lower_hopper_intrin.cc (1)

56-96: L2 persistent cache window logic is well-structured.

The new logic to materialize TVM FFI calls for L2 cache access policy window management is clear and correctly handles:

  • Building a name-to-buffer lookup from the function's buffer map
  • Extracting base pointers and arguments from the l2_persistent_map attribute
  • Properly ordering prologue (set policy) and epilogue (reset policy) calls
src/runtime/runtime.cc (2)

216-346: Comprehensive L2 cache access policy window implementation.

The new FFI functions for CUDA stream access policy window management are well-implemented with:

  • Thorough argument validation
  • Device capability querying and clamping
  • Proper state preservation via thread-local storage
  • Comprehensive error handling with informative messages
  • Correct CUDA API usage for stream attributes and persisting L2 cache

The thread-local storage pattern ensures per-thread restoration state, which is appropriate for multi-threaded environments where each thread may manage different streams.
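A language-agnostic sketch of that save/set/reset pattern (written in Python for brevity; the apply/query callables stand in for the CUDA driver calls and are not real APIs):

import threading

_tls = threading.local()

def set_access_policy_window(stream, new_limit, apply_fn, query_fn):
    _tls.prev_limit = query_fn()   # remember the previous persisting-L2 limit
    _tls.saved = True
    apply_fn(stream, new_limit)    # prologue: apply the requested window

def reset_access_policy_window(stream, apply_fn):
    if getattr(_tls, "saved", False):
        apply_fn(stream, _tls.prev_limit)  # epilogue: restore what was saved
        _tls.saved = False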


101-114: FFI registration updated to use canonical names.

The migration from string-based registration ("tvm_tensormap_create_tiled") to canonical constant-based registration (tl::tvm_tensormap_create_tiled) improves type safety and maintainability by centralizing the naming in runtime.h.

Comment on lines +20 to +24
// CUDA stream access policy window helpers
constexpr const char *tvm_cuda_stream_set_access_policy_window =
    "__tvm_cuda_stream_set_access_policy_window";
constexpr const char *tvm_cuda_stream_reset_access_policy_window =
    "__tvm_cuda_stream_reset_access_policy_window";
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

Verify CUDA version requirements for stream access policy window APIs.

The new constants tvm_cuda_stream_set_access_policy_window and tvm_cuda_stream_reset_access_policy_window are declared outside the #if (CUDA_MAJOR_VERSION >= 12) guard (lines 13-18), unlike the TMA descriptor constants above them. Ensure this is intentional and that the CUDA stream access policy window APIs (cuStreamSetAttribute with CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW, cuCtxSetLimit, cuCtxResetPersistingL2Cache) are available in CUDA versions prior to 12.


🌐 Web query:

CUDA stream access policy window cuStreamSetAttribute minimum CUDA version

💡 Result:

The stream "access policy window" (cudaStreamAttributeAccessPolicyWindow / accessPolicyWindow used with cudaStreamSetAttribute / cuLaunchAttribute) was introduced in CUDA 11.0 (documented in the CUDA 11 features announcement and in the CUDA Programming Guide). [1][2]

Sources:

  • NVIDIA blog "CUDA 11 Features Revealed". [1]
  • NVIDIA CUDA Programming Guide (examples showing cudaStreamSetAttribute accessPolicyWindow; present in CUDA 11.x documentation). [2]

Add CUDA version guard for stream access policy window constants.

The stream access policy window APIs were introduced in CUDA 11.0, but the constants tvm_cuda_stream_set_access_policy_window and tvm_cuda_stream_reset_access_policy_window (lines 20–24) lack any version guard. They should be wrapped with #if (CUDA_MAJOR_VERSION >= 11) to prevent use in earlier CUDA versions where these APIs are unavailable.

🤖 Prompt for AI Agents
In src/runtime/runtime.h around lines 20 to 24, the CUDA stream access policy
window constant definitions are unguarded and may be referenced on CUDA versions
older than 11; wrap the two constants tvm_cuda_stream_set_access_policy_window
and tvm_cuda_stream_reset_access_policy_window inside a compile-time guard such
as #if (CUDA_MAJOR_VERSION >= 11) ... #endif so they are only defined when
building against CUDA 11.0+; ensure the guard exactly encloses the constexpr
declarations and retains existing formatting and indentation.

Comment on lines +117 to +160
def matmu_jit_kernel(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
Contributor

🛠️ Refactor suggestion | 🟠 Major

Code duplication: matmu_jit_kernel is identical to matmul.

The function matmu_jit_kernel (lines 117-160) is a duplicate of the matmul function (lines 9-52) with only a typo in the name. This violates DRY principles and creates maintenance burden.

Apply this diff to remove the duplicate:

-def matmu_jit_kernel(
-    M,
-    N,
-    K,
-    block_M,
-    block_N,
-    block_K,
-    trans_A,
-    trans_B,
-    in_dtype,
-    out_dtype,
-    accum_dtype,
-    num_stages,
-    threads,
-):
-    A_shape = (K, M) if trans_A else (M, K)
-    B_shape = (N, K) if trans_B else (K, N)
-    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
-    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
-
-    @T.prim_func
-    def main(
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, in_dtype),
-            C: T.Tensor((M, N), out_dtype),
-    ):
-        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
-            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
-            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-            T.clear(C_local)
-            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
-                if trans_A:
-                    T.copy(A[k * block_K, by * block_M], A_shared)
-                else:
-                    T.copy(A[by * block_M, k * block_K], A_shared)
-                if trans_B:
-                    T.copy(B[bx * block_N, k * block_K], B_shared)
-                else:
-                    T.copy(B[k * block_K, bx * block_N], B_shared)
-                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
-            T.copy(C_local, C[by * block_M, bx * block_N])
-
-    return main
-
-
 def run_gemm_jit_kernel(
     M,
     N,
     K,
     trans_A,
     trans_B,
     in_dtype,
     out_dtype,
     dtypeAccum,
     block_M,
     block_N,
     block_K,
     num_stages=3,
     num_threads=128,
 ):
-    program = matmu_jit_kernel(
+    program = matmul(
         M,
         N,
         K,

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
testing/python/jit/test_tilelang_jit_tvm_ffi.py lines 117-160: the function
matmu_jit_kernel is a duplicate (typo) of the existing matmul function; remove
the duplicate definition and update any local references to call matmul instead
(or add a one-line alias like matmu_jit_kernel = matmul if backward-compatible
name must be preserved), ensuring there are no remaining imports/usages of the
removed function.

- Added `codegen_c_host.cc` to the list of source files in CMakeLists.txt for improved code generation support.
- Updated the function call in `lower.py` to use `target.build.tilelang_c` for C target host code generation, enhancing compatibility.
- Marked the TVM subproject as dirty to indicate local modifications.
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
src/transform/arg_binder.cc (1)

317-337: Reconsider accepting arbitrary 1‑bit dtypes for bool compatibility.

The bool branch currently allows any 1‑bit dtype, regardless of type code:

PrimExpr kdlbool1_ok =
    (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
// Also accept any dtype whose bitwidth=1, regardless of code, to be
// defensive.
PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;

This is quite permissive and could mask type issues if a future/custom 1‑bit type is introduced that doesn’t have bool semantics. Restricting the accepted set to the explicitly listed codes (int8/uint8/kDLBool) keeps the behavior predictable while still handling common frontend quirks.

You can drop the bit1_ok term:

-    PrimExpr kdlbool1_ok =
-        (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
-    // Also accept any dtype whose bitwidth=1, regardless of code, to be
-    // defensive.
-    PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
-    cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
+    PrimExpr kdlbool1_ok =
+        (v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
+    cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok;
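A Python sketch of the stricter predicate suggested above, with dtypes modeled as (code, bits, lanes) tuples rather than real DLDataType values:

def bool_compatible(expected, provided):
    exp_code, exp_bits, exp_lanes = expected
    code, bits, lanes = provided
    if (exp_code, exp_bits) != ("bool", 1) or lanes != exp_lanes:
        return False
    # only the explicitly listed encodings are accepted; arbitrary 1-bit codes are not
    return (code, bits) in {("int", 8), ("uint", 8), ("bool", 8), ("bool", 1)}

assert bool_compatible(("bool", 1, 1), ("uint", 8, 1))        # uint8 still binds
assert not bool_compatible(("bool", 1, 1), ("custom", 1, 1))  # unknown 1-bit code rejected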
🧹 Nitpick comments (4)
src/target/codegen_c_host.cc (1)

350-453: AssertStmt formatting is helpful; consider dtype‑aware value printing

The enriched assert handling (collecting all EQ subexpressions, building a single message, escaping format strings, and using TVMFFIErrorSetRaisedFromCStr) is a solid usability improvement.

One thing to consider: the printed values always go through (long long) and %lld, regardless of whether the operands are signed, unsigned, or floating‑point. For non‑integer or wider types this can produce misleading values or implementation‑defined casts.

If this becomes an issue in practice, a follow‑up could:

  • Inspect eq->a.dtype()/eq->b.dtype() and choose format specifiers (%lld, %llu, %f, etc.) accordingly, or
  • Restrict the “got/expected” detail to clearly integral dtypes and fall back to the base message otherwise.

Not blocking, but worth keeping in mind if you start asserting on non‑integer expressions.
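A small sketch of the dtype-aware formatting idea (illustrative Python, not the codegen itself): pick the specifier from the operand's dtype code and fall back to the base message for anything else.

def format_specifier(dtype_code):
    if dtype_code == "int":
        return "%lld"   # values widened to long long before printing
    if dtype_code == "uint":
        return "%llu"
    if dtype_code == "float":
        return "%f"
    return None         # no got/expected detail for exotic dtypes

assert format_specifier("uint") == "%llu"
assert format_specifier("float8_e4m3") is None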

src/target/codegen_c_host.h (1)

55-69: Tighten header API: unused AddFunctionsOrdered + FFI dependencies

Two small points about the header:

  1. AddFunctionsOrdered is declared but doesn’t have a definition in codegen_c_host.cc. It’s harmless today because nothing calls it, but it will become a linker landmine if used later. Either implement it or drop it from the public interface until you need it.

  2. The class interface uses tvm::ffi::String and tvm::ffi::Array in method signatures and the function_names_ member but the header relies on transitive includes for their declarations. To make this header self‑contained (and friendlier to external users), consider explicitly including the relevant FFI forward/definition headers (or adding forward declarations if TVM provides them).

Both are non‑blocking but would make the API surface safer and clearer.

Also applies to: 91-119

src/transform/arg_binder.h (1)

157-159: Clarify BindNullable API contract and naming.

The new BindNullable API is fine functionally, but it’s undocumented in the header and uses with_lets while Bind uses with_let. Adding a short comment explaining nullable_guard semantics (e.g., “guard expression that disables asserts/bindings when true”) and aligning the naming with Bind/Bind_ would make the public interface clearer.

src/transform/arg_binder.cc (1)

136-159: Tighten FP8 dtype compatibility by also checking lane count.

The dtype_compatible helper relaxes FP8 compatibility:

if (expected.is_float8_e4m3()) {
  return provided.is_float8_e4m3() || provided.is_float8_e4m3fn() ||
         provided.is_float8_e4m3fnuz();
}
if (expected.is_float8_e5m2()) {
  return provided.is_float8_e5m2() || provided.is_float8_e5m2fnuz();
}

Unlike the bool path, these branches don’t enforce expected.lanes() == provided.lanes(), so vectorized FP8 buffers with mismatched lanes could bind successfully. To keep the relaxation limited to “closely related types” without silently accepting lane mismatches, consider requiring equal lanes here as well.

For example:

-    if (expected.is_float8_e4m3()) {
-      return provided.is_float8_e4m3() || provided.is_float8_e4m3fn() ||
-             provided.is_float8_e4m3fnuz();
-    }
+    if (expected.is_float8_e4m3()) {
+      if (expected.lanes() != provided.lanes())
+        return false;
+      return provided.is_float8_e4m3() || provided.is_float8_e4m3fn() ||
+             provided.is_float8_e4m3fnuz();
+    }
@@
-    if (expected.is_float8_e5m2()) {
-      return provided.is_float8_e5m2() || provided.is_float8_e5m2fnuz();
-    }
+    if (expected.is_float8_e5m2()) {
+      if (expected.lanes() != provided.lanes())
+        return false;
+      return provided.is_float8_e5m2() || provided.is_float8_e5m2fnuz();
+    }
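And the analogous lane-aware check for the FP8 relaxation, again with (code, bits, lanes) tuples standing in for real dtypes:

E4M3_FAMILY = {"float8_e4m3", "float8_e4m3fn", "float8_e4m3fnuz"}

def fp8_e4m3_compatible(expected, provided):
    exp_code, _, exp_lanes = expected
    code, _, lanes = provided
    return exp_code == "float8_e4m3" and code in E4M3_FAMILY and lanes == exp_lanes

assert fp8_e4m3_compatible(("float8_e4m3", 8, 1), ("float8_e4m3fn", 8, 1))      # same lanes: accepted
assert not fp8_e4m3_compatible(("float8_e4m3", 8, 1), ("float8_e4m3fn", 8, 4))  # lane mismatch: rejected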
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f177025 and 2e6f110.

📒 Files selected for processing (6)
  • CMakeLists.txt (1 hunks)
  • src/target/codegen_c_host.cc (1 hunks)
  • src/target/codegen_c_host.h (1 hunks)
  • src/transform/arg_binder.cc (7 hunks)
  • src/transform/arg_binder.h (1 hunks)
  • tilelang/engine/lower.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
src/transform/arg_binder.cc (1)
tilelang/language/ast/ir.py (3)
  • LetStmt (880-908)
  • handle (1467-1497)
  • decl_buffer (1137-1205)
src/target/codegen_c_host.cc (1)
src/target/codegen_c_host.h (2)
  • CodeGenCHost (42-119)
  • GenerateForwardFunctionDeclarations (83-89)
src/transform/arg_binder.h (1)
src/transform/arg_binder.cc (2)
  • BindNullable (54-89)
  • BindNullable (54-56)
src/target/codegen_c_host.h (1)
src/target/codegen_c_host.cc (35)
  • CodeGenCHost (42-44)
  • Init (46-63)
  • Init (46-48)
  • InitGlobalContext (65-68)
  • InitGlobalContext (65-65)
  • AddFunction (74-77)
  • AddFunction (74-75)
  • AddFunction (79-106)
  • AddFunction (79-81)
  • DefineModuleName (70-72)
  • DefineModuleName (70-70)
  • PrintType (137-211)
  • PrintType (137-137)
  • PrintFuncPrefix (131-135)
  • PrintFuncPrefix (131-131)
  • VisitExpr_ (213-226)
  • VisitExpr_ (213-214)
  • VisitExpr_ (315-348)
  • VisitExpr_ (315-316)
  • VisitExpr_ (455-458)
  • VisitExpr_ (455-456)
  • VisitExpr_ (460-463)
  • VisitExpr_ (460-461)
  • VisitStmt_ (350-453)
  • VisitStmt_ (350-350)
  • GenerateForwardFunctionDeclarations (108-129)
  • GenerateForwardFunctionDeclarations (108-110)
  • GetPackedName (296-313)
  • GetPackedName (296-296)
  • PrintGetFuncFromBackend (228-246)
  • PrintGetFuncFromBackend (228-229)
  • PrintCallPacked (248-294)
  • PrintCallPacked (248-248)
  • PrintTernaryCondExpr (466-477)
  • PrintTernaryCondExpr (466-467)
🔇 Additional comments (5)
CMakeLists.txt (1)

135-146: Wire codegen_c_host.cc into build – looks correct

Adding src/target/codegen_c_host.cc to TILE_LANG_SRCS cleanly integrates the new C host codegen into the build with no side effects on existing sources.

tilelang/engine/lower.py (1)

137-152: Align host/device codegen with TVM‑FFI build entries; verify symbol availability

Switching the C host path to "target.build.tilelang_c" and adding the Metal path via "target.build.metal" keeps the lowering logic consistent with the new TVM‑FFI build APIs and looks structurally sound.

Please double‑check at runtime that both FFI symbols are actually registered (e.g., tvm.ffi.get_global_func("target.build.tilelang_c") and "target.build.metal" return non‑None) to avoid late GlobalFunc not found errors, especially on Metal where you’ve seen execution issues.

Also applies to: 170-188

src/target/codegen_c_host.cc (1)

42-72: Make module_name_ consistent with the emitted library‑context symbol

module_name_ is initialized via name_supply_->FreshName(tvm::ffi::symbol::tvm_ffi_library_ctx), while InitGlobalContext hard‑codes the global variable name using tvm::ffi::symbol::tvm_ffi_library_ctx. PrintGetFuncFromBackend then passes module_name_ into TVMBackendGetFuncFromEnv.

This effectively relies on FreshName returning the bare symbol name for the first use; if its behavior changes (or if the name is consumed earlier), you could end up calling TVMBackendGetFuncFromEnv with an identifier that has no corresponding definition in the generated C (e.g., tvm_ffi_library_ctx0).

To make this robust and avoid surprising compile‑time breakages, consider either:

  • Setting module_name_ directly from the symbol constant and using it consistently:
- CodeGenCHost::CodeGenCHost() {
-   module_name_ = name_supply_->FreshName(tvm::ffi::symbol::tvm_ffi_library_ctx);
- }
+ CodeGenCHost::CodeGenCHost() {
+   module_name_ = tvm::ffi::symbol::tvm_ffi_library_ctx;
+ }

and optionally rewriting InitGlobalContext to emit module_name_, or

  • Dropping module_name_ entirely and using tvm::ffi::symbol::tvm_ffi_library_ctx directly in PrintGetFuncFromBackend.

Also, DefineModuleName() is currently never called; either wire it into initialization/Finish() or remove it to avoid dead API surface.

Also applies to: 228-246

src/transform/arg_binder.cc (2)

54-89: BindNullable implementation matches Bind_ semantics and uses guard correctly.

The implementation mirrors Bind_ for both Var and non‑Var expressions, and the nullable_guard is only threaded into the equality checks via Or(nullable_guard, ...), so first-bind behavior is unchanged while later binds gain the intended short‑circuiting. This looks correct for the current BindDLTensor usage.
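
As a rough behavioural analogue only (the actual logic is the C++ in arg_binder.cc; the names below are placeholders), the first-bind/later-bind split can be pictured as:

def bind_nullable(def_map, asserts, var, value, nullable_guard):
    # First bind: define the variable, mirroring what Bind_ does.
    if var not in def_map:
        def_map[var] = value
        return
    # Later binds: emit an equality assert that is short-circuited when the
    # nullable guard (e.g. "the handle is NULL") already holds.
    asserts.append(f"assert ({nullable_guard}) || ({def_map[var]} == {value})")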


229-611: NULL‑aware DLTensor binding and guards look consistent and robust.

The new is_null handling in BindDLTensor (Let‑bound *_is_null, guarded TVMArrayGet via if_then_else, and use of BindNullable for shape/strides/elem_offset/device_id/data) gives a clear and consistent story:

  • Optional DLTensor* handles can be NULL without being dereferenced, while asserts are short‑circuited with Or(is_null, ...).
  • Shapes/strides/byte_offset/device_type are only validated when non‑NULL, and symbolic vars are still bound via guarded expressions.
  • The data pointer check correctly skips size‑0 arrays and NULL handles.

Overall this is a solid improvement in safety and error reporting for optional inputs.
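
For illustration only, the guard pattern described above can be sketched with tvm.tir expressions in Python; the struct-field index, names, and the expected ndim are placeholders, not the values arg_binder.cc actually uses:

from tvm import tir

handle = tir.Var("tensor_handle", "handle")
is_null = tir.Var("tensor_is_null", "bool")

# Read a DLTensor field only when the handle is non-NULL; otherwise use 0.
ndim_field = tir.call_intrin("int32", "tir.tvm_struct_get", handle, 0, 4)
guarded_ndim = tir.if_then_else(is_null, tir.IntImm("int32", 0), ndim_field)

# Assertions are short-circuited the same way: a NULL handle skips the check.
assert_cond = tir.Or(is_null, guarded_ndim == tir.IntImm("int32", 2))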

@LeiWang1999
Member Author

(image attachment: overheads_plot_ablation)

@LeiWang1999 LeiWang1999 merged commit 74da369 into tile-ai:main Nov 18, 2025
5 of 7 checks passed
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
* [Refactor] Update FFI type handling and simplify argument management

* Refactored FFI type definitions in runtime and code generation files to use `TVMFFIAny` instead of `TVMValue`, enhancing type clarity.
* Updated function registration in `runtime.cc` to utilize canonical names for better consistency.
* Simplified argument handling in the `simplify` transformation, ensuring unused buffer parameters are removed only when simplification is enabled.
* Adjusted autotuner and profiler parameters to standardize the execution backend to `tvm_ffi`, improving clarity in backend selection.
* Removed obsolete `adapt_torch2tvm` function from tensor utilities to streamline the codebase and reduce complexity.

* [Update] Sync TVM submodule and enhance kernel source handling

* Updated the TVM submodule to commit cdc2aced, ensuring compatibility with recent changes.
* Added functionality to print kernel source in `example_blocksparse_gemm.py` for better debugging.
* Commented out the main execution call in test files to prevent unintended execution during testing.
* Introduced `tilelang.disable_cache()` in various test files to streamline testing and avoid cache-related issues.
* Refactored kernel source retrieval methods to improve clarity and consistency across different execution backends.

* [Refactor] Clean up imports and improve code formatting

* Removed unused import of `tilelang.testing` in `test_example_blocksparse_gemm.py` to streamline the code.
* Reformatted several lines in `arg_binder.cc`, `make_packed_api.cc`, `tvm_ffi.py`, and `adapter.py` for improved readability and consistency.
* Updated comments and spacing in `tvm_ffi.py` to enhance clarity without altering functionality.

* Update execution backend options and improve resolution logic

- Changed default execution backend from "cython" to "auto" in multiple locations to allow automatic selection based on the target.
- Expanded the list of supported execution backends to include "torch" and "nvrtc" across various classes and functions.
- Enhanced backend resolution logic in `KernelCache` and `AutoTuner` to ensure appropriate backend selection based on the target.
- Updated documentation to reflect changes in execution backend options and their defaults.
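
The "auto" resolution described above might look roughly like the following sketch; the helper name and the exact target-to-backend mapping are assumptions, not the project's actual logic:

def resolve_execution_backend(backend: str, target_kind: str) -> str:
    if backend != "auto":
        return backend  # an explicit choice always wins
    # Hypothetical mapping: device targets go through tvm_ffi, everything
    # else falls back to the cython wrapper.
    if target_kind in ("cuda", "hip", "metal"):
        return "tvm_ffi"
    return "cython"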

* lint fix

* fix

* Enhance argument handling in CUDA and HIP runtime modules

- Updated `ExtractFuncInfo` in `rt_mod_cuda.cc` and `rt_mod_hip.cc` to map boolean argument types to int32, ensuring compatibility with device runtime.
- Refactored `BindDLTensor` in `arg_binder.cc` to improve null handling and validation checks for DLTensor parameters, utilizing expression-level guards to prevent dereferencing null pointers.
- Enhanced error checking for buffer shape, strides, and data fields, ensuring robust handling of optional inputs and maintaining consistency across various checks.
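
A minimal sketch of the bool-to-int32 mapping mentioned above (the function name is illustrative; the real change lives in ExtractFuncInfo in the C++ runtime modules):

def normalize_device_arg_dtype(dtype: str) -> str:
    # Booleans are recorded as int32 in the function's argument info,
    # matching what the device runtime expects for launch arguments.
    return "int32" if dtype == "bool" else dtype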

* lint fix

* lint fix

* lint fix

* lint fix

* minor fix

* fix

* recover check

* Refactor argument binding and validation in `arg_binder.cc`

- Improved null handling and validation checks in `BindDLTensor`, ensuring safe dereferencing of pointers.
- Enhanced consistency checks for buffer shape, strides, and data fields, utilizing expression-level guards.
- Updated `MakePackedAPI` to maintain code clarity and consistency in argument handling.
- Minor adjustments in test files to streamline kernel execution and improve readability.

* lint fix

* stride fix

* minor fix

* fix

* lint fix

* lint fix

* Add CUDA stream access policy window helpers and integrate with L2 persistent cache management

- Introduced functions to set and reset the CUDA stream access policy window, allowing for better control over L2 cache usage.
- Updated runtime files to include new FFI packed functions for managing stream attributes.
- Modified lower_hopper_intrin to incorporate prologue and epilogue statements for L2 cache setup and teardown.
- Enhanced tests to verify the inclusion of new FFI calls in the generated kernel source.

* check with symbolic

* support null ptr

* Update CMakeLists and lower.py for code generation and subproject status

- Added `codegen_c_host.cc` to the list of source files in CMakeLists.txt for improved code generation support.
- Updated the function call in `lower.py` to use `target.build.tilelang_c` for C target host code generation, enhancing compatibility.
- Marked the TVM subproject as dirty to indicate local modifications.

* lint fix

* Update comments for clarity in quickstart.py