@LeiWang1999 LeiWang1999 commented Sep 22, 2025

  • Updated init_desc_arg_map to use Var as the key instead of String in lower_hopper_intrin.cc.
  • Enhanced func_call_args method in TLCUDASourceWrapper to accept additional parameters for better argument mapping.
  • Added assertions to ensure consistency between function parameters and arguments during kernel launches.
  • Modified generate_tma_descriptor_args to utilize a mapping of variable names for TMA descriptor initialization.

Summary by CodeRabbit

  • New Features
    • Exposes the primary host function via a public property in CUDA/NVRTC wrappers for easier access.
  • Bug Fixes
    • Ensures descriptor arguments are correctly mapped to device function parameters with validation to prevent mismatches at runtime.
  • Refactor
    • Streamlines propagation of function parameters through dispatch and kernel launch paths for more reliable code generation.
    • Improves descriptor handling with variable-based mappings, enabling safer and more consistent initialization across kernels.


coderabbitai bot commented Sep 22, 2025

Walkthrough

Refactors LowerHopperIntrin to key descriptor-arg maps by Var instead of String. Extends CUDA wrapper to propagate device function parameters, map descriptor variables, and validate descriptor initialization. Adds host_func property, extracts function_params via post-order traversal, updates TMA descriptor generation, and wires mappings through dispatch and launch paths.

Changes

Cohort / File(s) Summary of changes
Hopper intrinsic map key update
src/transform/lower_hopper_intrin.cc
Changed init_desc_arg_map type from Map<String, Array<PrimExpr>> to Map<Var, Array<PrimExpr>> and updated Set(...) usage; attaches map to tma_descriptor_args with Var keys.
CUDA wrapper: function params and descriptor mapping
tilelang/jit/adapter/wrapper.py
- Imported post_order_visit.
- Added extraction and propagation of function_params for device kernels (via host func call-site analysis).
- Extended func_call_args to accept function_params and desc_name_var_map; added consistency checks.
- Updated generate_tma_descriptor_args signature to include desc_name_var_map and validate descriptor variables against tma_descriptor_args.
- Plumbed mappings through dispatch/create_dispatch_func and init_tma_descriptor_args.
- Added host_func property on TLCUDASourceWrapper and TLNVRTCSourceWrapper.
- function_informations now stores "function_params".

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant HostWrapper as Host Wrapper
  participant HostFunc as host_func (PrimFunc)
  participant Analyzer as post_order_visit
  participant Dispatch as create_dispatch_func
  participant CUDA as CUDA Codegen
  participant Kernel as Device Kernel

  HostWrapper->>HostFunc: Locate primary host function
  HostWrapper->>Analyzer: Walk body to find tvm_call_packed
  Analyzer-->>HostWrapper: Extract function_params (device param list)
  HostWrapper->>Dispatch: Store function_informations{..., function_params}
  Dispatch->>CUDA: Generate launch code with function_params, desc_name_var_map
  CUDA->>CUDA: func_call_args(..., function_params, desc_name_var_map)
  CUDA->>CUDA: generate_tma_descriptor_args(desc_name_map, desc_name_var_map)
  note right of CUDA: Validate descriptor vars exist in tma_descriptor_args
  CUDA->>Kernel: Launch with mapped args and descriptors

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

In cables of code where kernels hop,
I twitch my whiskers—no need to stop.
Vars now guide the map’s keen sight,
Params align, descriptors right.
Host and device in tidy sync,
I thump approval—link by link.
Carrots compiled; time to blink. 🥕✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |
✅ Passed checks (2 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title Check | ✅ Passed | The title succinctly and accurately summarizes the primary change: a TMA bugfix addressing a shared buffer used with both tma store and tma load. It reflects the PR objectives and the key code changes (descriptor mapping and init_desc_arg_map key type change), so a reviewer can quickly understand the intent. |

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@gemini-code-assist
Contributor

Summary of Changes

Hello @LeiWang1999, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical bug concerning the initialization of TMA (Tensor Memory Accelerator) descriptors, particularly when a shared buffer is simultaneously used for both load and store operations. The core of the fix involves transitioning from string-based identification to using unique tvm.tir.Var objects for mapping TMA descriptors. This ensures that the correct descriptors are consistently identified and passed during kernel launches, thereby improving the reliability and correctness of Hopper intrinsic lowering.
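To make the collision concrete, here is a minimal sketch in plain Python. `FakeVar` is a hypothetical stand-in for `tvm.tir.Var` (it is not TVM's class), assumed only to carry a `name_hint` and to hash by object identity. The scenario of two descriptor variables sharing one name hint, and the placeholder argument lists, are assumptions made for illustration.

```python
class FakeVar:
    """Hypothetical stand-in for a TIR variable: a name hint plus object identity."""

    def __init__(self, name_hint):
        self.name_hint = name_hint
    # No __eq__/__hash__ overrides, so instances compare and hash by identity.


# Assumed scenario: two distinct descriptor variables share a name hint,
# for example one for a TMA load and one for a TMA store of the same buffer.
load_desc = FakeVar("A_desc")
store_desc = FakeVar("A_desc")

# Keyed by name: the second entry silently overwrites the first.
by_name = {}
by_name[load_desc.name_hint] = ["load-args"]
by_name[store_desc.name_hint] = ["store-args"]

# Keyed by the variable object: both entries survive.
by_var = {}
by_var[load_desc] = ["load-args"]
by_var[store_desc] = ["store-args"]

print(len(by_name), len(by_var))
```

With name keys only one entry remains, so one of the two descriptors would be initialized from the wrong argument list; with object keys each descriptor keeps its own.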

Highlights

  • TMA Descriptor Key Change: The key type for init_desc_arg_map in lower_hopper_intrin.cc has been updated from String to tvm.tir.Var. This change ensures unique identification of shared buffers, which is crucial for correct TMA descriptor initialization.
  • Enhanced Argument Mapping: The func_call_args method in TLCUDASourceWrapper now accepts additional parameters, function_params (a list of tvm.tir.Var objects) and desc_name_var_map, to provide more accurate and robust argument mapping during kernel launches.
  • Parameter-Argument Consistency Assertions: New assertions have been added to verify that the number of arguments passed during kernel launches consistently matches the number of function parameters, enhancing the overall robustness of the system.
  • Updated TMA Descriptor Initialization: The generate_tma_descriptor_args function has been modified to utilize the new mapping of tvm.tir.Var objects for TMA descriptor initialization, directly addressing the bug where shared buffers used for both load and store operations could lead to incorrect descriptor handling.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request provides a bugfix for TMA load/store operations on shared buffers by using Var objects instead of string names for descriptor mapping, which is a solid improvement. The changes also include adding assertions for better argument validation. My review focuses on improving code maintainability by addressing code duplication and fixing a recurring typo. I've suggested refactoring duplicated logic for kernel argument handling and for finding the primary host function into helper methods. These changes should make the code cleaner and easier to maintain.

Comment on lines 337 to 358
             if self.use_cooperative_groups[function_name]:
-                args_list = func_call_args(declaration, function_args, desc_name_map)
+                args_list = func_call_args(declaration, function_args, function_params,
+                                           desc_name_map, desc_name_var_map)
+                assert len(function_params) == len(
+                    args_list
+                ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
                 args_array = [f"(void*)&{arg}" for arg in args_list]
                 call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n"
                 kernel_launch_code += call_args
                 # Using cudaLaunchCooperativeKernel to launch the kernel
                 kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format(
                     function_name, grid_str, block_str, function_name + "_args", smem_str)
             else:
-                call_args = ", ".join(func_call_args(declaration, function_args, desc_name_map))
+                args_list = func_call_args(declaration, function_args, function_params,
+                                           desc_name_map, desc_name_var_map)
+                assert len(function_params) == len(
+                    args_list
+                ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
+                call_args = ", ".join(args_list)
                 kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
                     function_name, grid_str, block_str, smem_str, call_args)
                 kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)

medium

There's duplicated code for calling func_call_args and asserting the argument list length within the if/else block for cooperative groups. This can be refactored by moving the call and assertion before the if statement to improve maintainability and reduce redundancy.

            args_list = func_call_args(declaration, function_args, function_params,
                                       desc_name_map, desc_name_var_map)
            assert len(function_params) == len(
                args_list
            ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
            if self.use_cooperative_groups[function_name]:
                args_array = [f"(void*)&{arg}" for arg in args_list]
                call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n"
                kernel_launch_code += call_args
                # Using cudaLaunchCooperativeKernel to launch the kernel
                kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format(
                    function_name, grid_str, block_str, function_name + "_args", smem_str)
            else:
                call_args = ", ".join(args_list)
                kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
                    function_name, grid_str, block_str, smem_str, call_args)
                kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)

-    def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
+    def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str],
+                                     desc_name_var_map: Dict[str, tvm.tir.Var]) -> str:
         tma_descripter_init = ""

medium

There's a recurring typo "descripter" in the variable name tma_descripter_init. It should be "descriptor". Please correct it to tma_descriptor_init for consistency and correctness. This typo also appears on lines 393, 442, and 448 within this function.

        tma_descriptor_init = ""

Comment on lines +631 to +642
    @property
    def host_func(self):
        if len(self.host_mod.get_global_vars()) == 1:
            return self.host_mod[self.host_mod.get_global_vars()[0]]
        elif "main" in self.host_mod:
            return self.host_mod["main"]
        else:
            for _, function in self.host_mod.functions.items():
                attr = function.attrs
                if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
                    return function
            raise ValueError("Cannot find primary function in the module.")

medium

The logic in the new host_func property is duplicated from device_func and prim_func. To improve maintainability and avoid redundancy, you could extract this logic into a private helper method. This helper could then be used by all three properties.

    def _get_primary_func(self, mod: IRModule):
        """Helper to find the primary function in a module."""
        if len(mod.get_global_vars()) == 1:
            return mod[mod.get_global_vars()[0]]
        if "main" in mod:
            return mod["main"]
        for _, function in mod.functions.items():
            attr = function.attrs
            if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
                return function
        raise ValueError("Cannot find primary function in the module.")

    @property
    def host_func(self):
        return self._get_primary_func(self.host_mod)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
tilelang/jit/adapter/wrapper.py (2)

288-295: Verify assertion message accuracy.

The assertion correctly validates that arguments don't exceed parameters. However, the assertion message could be more precise since this check happens during argument collection, not at the final count.

Consider updating the assertion message for clarity:

-                        assert len(call_args) <= len(
-                            function_params
-                        ), f"Function {function_name} has {len(function_params)} parameters, but {len(call_args)} arguments"
+                        assert len(call_args) <= len(
+                            function_params
+                        ), f"Function has {len(function_params)} parameters, but attempting to add argument #{len(call_args) + 1}"

Note that function_name is not in scope here, so the current message would cause a NameError if triggered.
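The NameError risk can be reproduced in isolation. The sketch below is not the wrapper's code: `check_arg_count` and `outcome` are hypothetical helpers, and `function_name` is deliberately left undefined to mirror the out-of-scope name described above. It relies only on the fact that Python evaluates an assertion message when the assertion fails, not before.

```python
def check_arg_count(call_args, function_params):
    # `function_name` is intentionally not defined anywhere in this sketch.
    assert len(call_args) <= len(function_params), (
        f"Function {function_name} has {len(function_params)} parameters, "
        f"but {len(call_args)} arguments")


def outcome(call_args, function_params):
    """Return the name of the exception raised, or "ok" if the check passes."""
    try:
        check_arg_count(call_args, function_params)
    except Exception as exc:
        return type(exc).__name__
    return "ok"


print(outcome(["a"], ["a", "b"]))       # check passes, message never evaluated
print(outcome(["a", "b", "c"], ["a"]))  # check fails, building the message raises
```

The failing case reports a NameError rather than an AssertionError, so the diagnostic the message was meant to carry is lost exactly when it is needed.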


631-643: Good addition of host_func property accessor.

The new host_func property provides clean access to the host function, maintaining consistency with existing prim_func and device_func properties. The implementation follows the established pattern for function resolution.

The static analysis tool suggests using dict.get on line 640, but the current implementation with the in check followed by direct access is actually clearer in this context since we want to explicitly check for the "main" key first.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bd16865 and cd74c68.

📒 Files selected for processing (2)
  • src/transform/lower_hopper_intrin.cc (2 hunks)
  • tilelang/jit/adapter/wrapper.py (8 hunks)
🧰 Additional context used
🪛 Ruff (0.13.1)
tilelang/jit/adapter/wrapper.py

580-582: Avoid specifying long messages outside the exception class

(TRY003)


640-640: Unnecessary key check before dictionary access

Replace with dict.get

(RUF019)


642-642: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (8)
src/transform/lower_hopper_intrin.cc (2)

28-28: LGTM! Improved type safety for descriptor mapping.

Changing from Map<String, Array<PrimExpr>> to Map<Var, Array<PrimExpr>> provides better type safety and avoids potential naming conflicts. This ensures the descriptor arguments are directly associated with their corresponding variable objects rather than string-based name hints.


49-49: Consistent with the new Var-based mapping approach.

The change from Set(var->name_hint, init_desc_args) to Set(var, init_desc_args) properly aligns with the updated map type, ensuring the descriptor arguments are stored with the actual Var object as the key.

tilelang/jit/adapter/wrapper.py (6)

264-268: Good addition of parameter tracking for descriptor mapping.

The enhanced func_call_args signature now accepts function_params and desc_name_var_map to properly track and map descriptor variables to their corresponding function parameters. This is essential for the TMA descriptor handling improvements.


338-342: Consistent parameter validation for cooperative kernel launch.

Good addition of the assertion to ensure parameter-argument consistency for cooperative kernel launches, matching the validation done for regular kernel launches.


350-354: Proper validation for non-cooperative kernel launches.

The assertion correctly validates that the number of function parameters matches the arguments list before generating the kernel launch code.
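As a rough illustration of the two launch paths sharing one count check, here is a simplified, self-contained sketch. `build_launch_code` is a hypothetical helper, not a function in `wrapper.py`; the emitted strings follow the format shown in the diff earlier in this thread, and the grid, block, and shared-memory strings are placeholders.

```python
def build_launch_code(function_name, args_list, function_params,
                      grid_str, block_str, smem_str, cooperative):
    """Return CUDA launch source for one kernel (simplified sketch)."""
    # Validate once, before branching on the launch style.
    assert len(function_params) == len(args_list), (
        f"Function {function_name} has {len(function_params)} parameters, "
        f"but {len(args_list)} arguments")
    if cooperative:
        # Cooperative launches take an array of pointers to the arguments.
        args_array = ", ".join(f"(void*)&{arg}" for arg in args_list)
        code = f"\tvoid* {function_name}_args[] = {{{args_array}}};\n"
        code += (f"\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){function_name}, "
                 f"{grid_str}, {block_str}, {function_name}_args, {smem_str}, stream));\n")
        return code
    # Regular launches pass the arguments directly in the call.
    call_args = ", ".join(args_list)
    code = f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n"
    code += f'\tTILELANG_CHECK_LAST_ERROR("{function_name}");\n'
    return code


print(build_launch_code("main_kernel", ["A", "B"], ["A", "B"],
                        "grid", "block", "smem", cooperative=False))
```

Passing an argument list whose length differs from `function_params` fails the assertion before any launch code is generated.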


362-363: Descriptor mapping properly propagated to TMA initialization.

The desc_name_var_map is correctly passed to generate_tma_descriptor_args, enabling proper resolution of descriptor handles through mapped variables.


389-399: Robust descriptor validation with proper error handling.

Excellent validation logic that:

  1. Verifies handle names exist in the descriptor-to-variable map
  2. Maps descriptors to their corresponding Var objects
  3. Validates that Var objects exist in the TMA descriptor arguments

This multi-level validation ensures consistency throughout the descriptor handling pipeline.
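The three checks listed above can be sketched as follows. This is an illustration under assumptions, not the code in `wrapper.py`: `FakeVar` is a hypothetical stand-in for `tvm.tir.Var`, `resolve_descriptor_args` is a made-up helper name, and the key layout of the three maps (handle name to descriptor name, handle name to variable, variable to init arguments) is assumed from the descriptions in this thread.

```python
class FakeVar:
    """Hypothetical stand-in for a TIR variable, hashed by object identity."""

    def __init__(self, name_hint):
        self.name_hint = name_hint


def resolve_descriptor_args(desc_name_map, desc_name_var_map, tma_descriptor_args):
    """Map each kernel handle name to the init args of its descriptor variable."""
    resolved = {}
    for handle_name in desc_name_map:
        # 1. The handle name must appear in the name-to-variable map.
        if handle_name not in desc_name_var_map:
            raise ValueError(f"No variable recorded for descriptor handle {handle_name!r}")
        # 2. Look up the variable object behind the name.
        desc_var = desc_name_var_map[handle_name]
        # 3. The variable must have descriptor init args attached.
        if desc_var not in tma_descriptor_args:
            raise ValueError(f"No TMA descriptor args for variable {desc_var.name_hint!r}")
        resolved[handle_name] = tma_descriptor_args[desc_var]
    return resolved


a_desc = FakeVar("A_desc")
resolved = resolve_descriptor_args(
    {"A_desc": "A_desc"},           # placeholder handle-name map
    {"A_desc": a_desc},             # handle name -> variable
    {a_desc: ["dtype", "rank"]},    # variable -> placeholder init args
)
print(resolved)
```

A handle missing from either map raises immediately, which is the behavior the validation is meant to guarantee before any descriptor initialization code is emitted.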


565-586: Well-structured parameter extraction from host module.

The implementation correctly:

  1. Retrieves the device function from the module
  2. Uses post-order traversal to find tvm_call_packed calls
  3. Extracts function parameters from the call site
  4. Validates parameter counts match expectations

The visitor pattern is properly implemented with appropriate error handling.
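For readers unfamiliar with the traversal, the steps above can be sketched on a toy tree. Nothing here uses TVM: `Node`, `visit_post_order`, and `extract_function_params` are hypothetical names, and the convention that the first argument of a `tvm_call_packed` node names the callee while the rest are its parameters is an assumption made for this example.

```python
class Node:
    """Hypothetical IR node: an operator name, call arguments, and child nodes."""

    def __init__(self, op, args=(), children=()):
        self.op = op
        self.args = list(args)
        self.children = list(children)


def visit_post_order(node, fvisit):
    """Call fvisit on every child subtree first, then on the node itself."""
    for child in node.children:
        visit_post_order(child, fvisit)
    fvisit(node)


def extract_function_params(host_body, kernel_name):
    """Return the argument list of the single packed call that launches kernel_name."""
    matches = []

    def visitor(node):
        # Assumed convention: args[0] names the callee, the rest are its parameters.
        if node.op == "tvm_call_packed" and node.args and node.args[0] == kernel_name:
            matches.append(node.args[1:])

    visit_post_order(host_body, visitor)
    if len(matches) != 1:
        raise ValueError(
            f"Expected exactly one call to {kernel_name!r}, found {len(matches)}")
    return matches[0]


host_body = Node("seq", children=[
    Node("let", children=[
        Node("tvm_call_packed", args=["main_kernel", "A", "B", "A_desc"]),
    ]),
    Node("tvm_call_packed", args=["other_kernel", "C"]),
])
function_params = extract_function_params(host_body, "main_kernel")
print(function_params)
```

The extracted list is what a later count check against the generated call arguments would compare with.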

@LeiWang1999 LeiWang1999 merged commit b9a51c4 into tile-ai:main Sep 22, 2025
6 checks passed
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
…ma load (tile-ai#857)
