[TMA] Bugfix when a shared buffer is both issued with tma store and tma load #857
Conversation
- Updated `init_desc_arg_map` to use `Var` as the key instead of `String` in `lower_hopper_intrin.cc`.
- Enhanced the `func_call_args` method in `TLCUDASourceWrapper` to accept additional parameters for better argument mapping.
- Added assertions to ensure consistency between function parameters and arguments during kernel launches.
- Modified `generate_tma_descriptor_args` to utilize a mapping of variable names for TMA descriptor initialization.
Walkthrough

Refactors LowerHopperIntrin to key descriptor-arg maps by Var instead of String. Extends the CUDA wrapper to propagate device function parameters, map descriptor variables, and validate descriptor initialization. Adds a host_func property, extracts function_params via post-order traversal, updates TMA descriptor generation, and wires the mappings through the dispatch and launch paths.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
participant HostWrapper as Host Wrapper
participant HostFunc as host_func (PrimFunc)
participant Analyzer as post_order_visit
participant Dispatch as create_dispatch_func
participant CUDA as CUDA Codegen
participant Kernel as Device Kernel
HostWrapper->>HostFunc: Locate primary host function
HostWrapper->>Analyzer: Walk body to find tvm_call_packed
Analyzer-->>HostWrapper: Extract function_params (device param list)
HostWrapper->>Dispatch: Store function_informations{..., function_params}
Dispatch->>CUDA: Generate launch code with function_params, desc_name_var_map
CUDA->>CUDA: func_call_args(..., function_params, desc_name_var_map)
CUDA->>CUDA: generate_tma_descriptor_args(desc_name_map, desc_name_var_map)
note right of CUDA: Validate descriptor vars exist in tma_descriptor_args
CUDA->>Kernel: Launch with mapped args and descriptors
```

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning), ✅ Passed checks (2 passed)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's formatting checks before review. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes

Hello @LeiWang1999, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request resolves a critical bug concerning the initialization of TMA (Tensor Memory Accelerator) descriptors, particularly when a shared buffer is simultaneously used for both load and store operations. The core of the fix involves transitioning from string-based identification to using unique `Var` objects as keys for the TMA descriptor mapping.
Code Review
This pull request provides a bugfix for TMA load/store operations on shared buffers by using Var objects instead of string names for descriptor mapping, which is a solid improvement. The changes also include adding assertions for better argument validation. My review focuses on improving code maintainability by addressing code duplication and fixing a recurring typo. I've suggested refactoring duplicated logic for kernel argument handling and for finding the primary host function into helper methods. These changes should make the code cleaner and easier to maintain.
```diff
 if self.use_cooperative_groups[function_name]:
-    args_list = func_call_args(declaration, function_args, desc_name_map)
+    args_list = func_call_args(declaration, function_args, function_params,
+                               desc_name_map, desc_name_var_map)
+    assert len(function_params) == len(
+        args_list
+    ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
     args_array = [f"(void*)&{arg}" for arg in args_list]
     call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n"
     kernel_launch_code += call_args
     # Using cudaLaunchCooperativeKernel to launch the kernel
     kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format(
         function_name, grid_str, block_str, function_name + "_args", smem_str)
 else:
-    call_args = ", ".join(func_call_args(declaration, function_args, desc_name_map))
+    args_list = func_call_args(declaration, function_args, function_params,
+                               desc_name_map, desc_name_var_map)
+    assert len(function_params) == len(
+        args_list
+    ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
+    call_args = ", ".join(args_list)
     kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
         function_name, grid_str, block_str, smem_str, call_args)
 kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
```
There's duplicated code for calling func_call_args and asserting the argument list length within the if/else block for cooperative groups. This can be refactored by moving the call and assertion before the if statement to improve maintainability and reduce redundancy.
```python
args_list = func_call_args(declaration, function_args, function_params,
                           desc_name_map, desc_name_var_map)
assert len(function_params) == len(
    args_list
), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
if self.use_cooperative_groups[function_name]:
    args_array = [f"(void*)&{arg}" for arg in args_list]
    call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n"
    kernel_launch_code += call_args
    # Using cudaLaunchCooperativeKernel to launch the kernel
    kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format(
        function_name, grid_str, block_str, function_name + "_args", smem_str)
else:
    call_args = ", ".join(args_list)
    kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
        function_name, grid_str, block_str, smem_str, call_args)
kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
```

```diff
-    def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
+    def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str],
+                                     desc_name_var_map: Dict[str, tvm.tir.Var]) -> str:
         tma_descripter_init = ""
```
```diff
+    @property
+    def host_func(self):
+        if len(self.host_mod.get_global_vars()) == 1:
+            return self.host_mod[self.host_mod.get_global_vars()[0]]
+        elif "main" in self.host_mod:
+            return self.host_mod["main"]
+        else:
+            for _, function in self.host_mod.functions.items():
+                attr = function.attrs
+                if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
+                    return function
+            raise ValueError("Cannot find primary function in the module.")
```
The logic in the new host_func property is duplicated from device_func and prim_func. To improve maintainability and avoid redundancy, you could extract this logic into a private helper method. This helper could then be used by all three properties.
```python
def _get_primary_func(self, mod: IRModule):
    """Helper to find the primary function in a module."""
    if len(mod.get_global_vars()) == 1:
        return mod[mod.get_global_vars()[0]]
    if "main" in mod:
        return mod["main"]
    for _, function in mod.functions.items():
        attr = function.attrs
        if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
            return function
    raise ValueError("Cannot find primary function in the module.")

@property
def host_func(self):
    return self._get_primary_func(self.host_mod)
```
Actionable comments posted: 0
🧹 Nitpick comments (2)
tilelang/jit/adapter/wrapper.py (2)
288-295: Verify assertion message accuracy.

The assertion correctly validates that arguments don't exceed parameters. However, the assertion message could be more precise since this check happens during argument collection, not at the final count.
Consider updating the assertion message for clarity:
```diff
-        assert len(call_args) <= len(
-            function_params
-        ), f"Function {function_name} has {len(function_params)} parameters, but {len(call_args)} arguments"
+        assert len(call_args) <= len(
+            function_params
+        ), f"Function has {len(function_params)} parameters, but attempting to add argument #{len(call_args) + 1}"
```

Note that `function_name` is not in scope here, so the current message would cause a NameError if triggered.
631-643: Good addition of `host_func` property accessor.

The new `host_func` property provides clean access to the host function, maintaining consistency with the existing `prim_func` and `device_func` properties. The implementation follows the established pattern for function resolution.

The static analysis tool suggests using `dict.get` on line 640, but the current implementation with the `in` check followed by direct access is actually clearer in this context, since we want to explicitly check for the "main" key first.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/transform/lower_hopper_intrin.cc (2 hunks)
- tilelang/jit/adapter/wrapper.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/lower_hopper_intrin.cc (4)
- src/transform/storage_rewrite.cc (2): var (1130-1153), var (1130-1130)
- src/transform/warp_specialized_rewriter.cc (2): var (347-353), var (347-347)
- src/transform/cluster_planning.cc (2): var (90-90), var (90-90)
- src/transform/merge_shared_memory_allocations.cc (6): var (304-306), var (304-304), var (615-617), var (615-615), var (1054-1066), var (1054-1054)
🪛 Ruff (0.13.1)
tilelang/jit/adapter/wrapper.py
580-582: Avoid specifying long messages outside the exception class
(TRY003)
640-640: Unnecessary key check before dictionary access
Replace with dict.get
(RUF019)
642-642: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: bot-task
🔇 Additional comments (8)
src/transform/lower_hopper_intrin.cc (2)
28-28: LGTM! Improved type safety for descriptor mapping.

Changing from `Map<String, Array<PrimExpr>>` to `Map<Var, Array<PrimExpr>>` provides better type safety and avoids potential naming conflicts. This ensures the descriptor arguments are directly associated with their corresponding variable objects rather than string-based name hints.
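To see why a Var-keyed map matters for this bug, here is a minimal Python sketch (not code from this PR; the descriptor names are made up): two distinct descriptor Vars may end up with the same name hint, plausibly what happens when a shared buffer is used for both a TMA load and a TMA store, so a string-keyed map silently drops one entry while a Var-keyed map keeps both.

```python
# Minimal illustrative sketch (requires a TVM install); names are hypothetical.
import tvm
from tvm import tir

# Two distinct descriptor variables that happen to carry the same name hint.
load_desc = tir.Var("A_shared_desc", "handle")
store_desc = tir.Var("A_shared_desc", "handle")

by_name = {}
by_name[load_desc.name] = ["load init args"]
by_name[store_desc.name] = ["store init args"]  # overwrites the load entry

by_var = {}
by_var[load_desc] = ["load init args"]
by_var[store_desc] = ["store init args"]        # distinct keys, both kept

print(len(by_name), len(by_var))  # 1 2
```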
49-49: Consistent with the new Var-based mapping approach.

The change from `Set(var->name_hint, init_desc_args)` to `Set(var, init_desc_args)` properly aligns with the updated map type, ensuring the descriptor arguments are stored with the actual Var object as the key.

tilelang/jit/adapter/wrapper.py (6)
264-268: Good addition of parameter tracking for descriptor mapping.

The enhanced `func_call_args` signature now accepts `function_params` and `desc_name_var_map` to properly track and map descriptor variables to their corresponding function parameters. This is essential for the TMA descriptor handling improvements.
338-342: Consistent parameter validation for cooperative kernel launch.

Good addition of the assertion to ensure parameter-argument consistency for cooperative kernel launches, matching the validation done for regular kernel launches.
350-354: Proper validation for non-cooperative kernel launches.

The assertion correctly validates that the number of function parameters matches the arguments list before generating the kernel launch code.
362-363: Descriptor mapping properly propagated to TMA initialization.

The `desc_name_var_map` is correctly passed to `generate_tma_descriptor_args`, enabling proper resolution of descriptor handles through mapped variables.
389-399: Robust descriptor validation with proper error handling.

Excellent validation logic that:
- Verifies handle names exist in the descriptor-to-variable map
- Maps descriptors to their corresponding Var objects
- Validates that Var objects exist in the TMA descriptor arguments

This multi-level validation ensures consistency throughout the descriptor handling pipeline; a sketch of the lookup chain follows.
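A rough sketch of that lookup chain (hedged: the dictionary names follow the review text above, but the actual wrapper.py code may be structured differently):

```python
# Hedged sketch of the multi-level validation described above; not the exact
# wrapper.py implementation. Assumed shapes: desc_name_map: handle name -> descriptor
# name, desc_name_var_map: handle name -> tir.Var, tma_descriptor_args: tir.Var -> init args.
def resolve_descriptor_args(desc_name_map, desc_name_var_map, tma_descriptor_args):
    resolved = {}
    for handle_name, desc_name in desc_name_map.items():
        # 1. The handle name must have a corresponding descriptor Var.
        assert handle_name in desc_name_var_map, (
            f"handle {handle_name} has no mapped descriptor Var")
        desc_var = desc_name_var_map[handle_name]
        # 2. That Var must carry initialization args recorded by LowerHopperIntrin.
        assert desc_var in tma_descriptor_args, (
            f"descriptor Var {desc_var} has no recorded TMA init args")
        resolved[desc_name] = tma_descriptor_args[desc_var]
    return resolved
```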
565-586: Well-structured parameter extraction from the host module.

The implementation correctly:
- Retrieves the device function from the module
- Uses post-order traversal to find `tvm_call_packed` calls
- Extracts function parameters from the call site
- Validates parameter counts match expectations

The visitor pattern is properly implemented with appropriate error handling; a minimal sketch of this traversal follows.
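A minimal sketch of that traversal pattern (hedged: the function and variable names here are illustrative, not the exact wrapper.py code):

```python
# Hedged sketch: walk a host PrimFunc body and collect the argument lists of
# tvm_call_packed calls (the packed calls that launch device kernels).
import tvm
from tvm import tir


def extract_packed_call_args(host_func: tir.PrimFunc):
    call_packed = tvm.ir.Op.get("tir.tvm_call_packed")
    packed_calls = []

    def visitor(node):
        if isinstance(node, tir.Call) and node.op.same_as(call_packed):
            # args[0] is the packed function name; the rest are the kernel arguments.
            packed_calls.append(list(node.args))

    tir.stmt_functor.post_order_visit(host_func.body, visitor)
    return packed_calls
```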