[Bugfix] Disable Memory Info Analysis for local.var
#851
- Updated the `AddWrapperForSingleBufStore` function to improve the handling of buffer stores by adding detailed checks for fragment buffer accesses and ensuring only index 0 is used.
- Introduced new helper functions for collecting buffer accesses and indices, enhancing code readability and maintainability.
- Refined the logic for determining tile operations and thread bindings to ensure accurate transformations without affecting existing parallel structures.
Walkthrough
Updates the memory tagging logic in src/transform/storage_rewrite.cc to exclude ".var" from special-tagged handling. Rewrites tilelang's AddWrapperForSingleBufStore to use depth-based tile loop analysis, explicit thread-binding tracking, buffer/indices inspection, and conditional BufferStore wrapping with a 0..1 parallel loop; registers the pass.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Pass as AddWrapperForSingleBufStore Pass
participant Func as PrimFunc
participant Visitor as IR Visitor
participant Stmt as Statement
participant Wrap as Wrapper Logic
Note over Pass,Func: Apply pass to function body
Pass->>Visitor: Initialize tile_operation_depth=0, thread_binding_vars=∅
loop Traverse IR
Visitor->>Stmt: PreVisit(node)
alt For loop
Note over Visitor: Detect tile-like loop (parallel or num_stages)
Visitor->>Visitor: Increment tile_operation_depth
Visitor->>Visitor: Track thread bindings (thread_extent)
end
Visitor-->>Stmt: Visit children
Visitor->>Stmt: PostVisit(node)
alt Node is BufferStore AND tile_operation_depth==0 AND no used vars in thread_binding_vars
Visitor->>Wrap: Analyze buffers and indices
alt Fragment buffers accessed AND all indices == 0
Wrap->>Stmt: Replace BufferStore with For(parallel, 0..1) { BufferStore }
else No fragment or invalid indices
Wrap-->>Stmt: Leave unchanged or raise error (non-zero fragment index)
end
else
Note over Visitor: No wrapping
end
alt Exiting tile-like loop
Visitor->>Visitor: Decrement tile_operation_depth
end
end
Visitor->>Pass: Return transformed body
Pass->>Func: func.with_body(new_body)
Note over Pass,Func: Pass registered via prim_func_pass(opt_level=0)
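The traversal in the diagram can be sketched in plain Python. This is a toy model under stated assumptions, not the code from the PR: `Store`, `For`, and `WrapSingleStores` are hypothetical stand-ins rather than TVM or TileLang classes, thread bindings are modeled as a loop kind, and the traversal state lives on a class instance instead of in `nonlocal` variables.

```python
from dataclasses import dataclass, field


# Hypothetical stand-ins for TIR nodes; none of these are TVM classes.
@dataclass
class Store:
    scope: str                            # storage scope of the written buffer
    indices: tuple                        # constant indices of the write
    used_vars: frozenset = frozenset()    # loop/thread variables the store reads


@dataclass
class For:
    kind: str                             # "serial", "parallel", or "thread_binding"
    var: str
    body: list
    extent: int = 1
    annotations: dict = field(default_factory=dict)


class WrapSingleStores:
    """Toy rewriter that keeps its traversal state on the instance."""

    def __init__(self):
        self.tile_operation_depth = 0
        self.thread_binding_vars = set()

    @staticmethod
    def is_tile_loop(loop):
        # Assumption: a loop is tile-like if it is parallel or carries num_stages.
        return loop.kind == "parallel" or "num_stages" in loop.annotations

    def visit(self, node):
        if not isinstance(node, For):
            return self.visit_store(node)
        tile_like = self.is_tile_loop(node)
        if tile_like:
            self.tile_operation_depth += 1
        if node.kind == "thread_binding":
            self.thread_binding_vars.add(node.var)
        body = [self.visit(child) for child in node.body]
        if tile_like:
            self.tile_operation_depth -= 1
        return For(node.kind, node.var, body, node.extent, node.annotations)

    def visit_store(self, store):
        # Leave stores inside tile-like loops or using thread variables untouched.
        if self.tile_operation_depth > 0:
            return store
        if store.used_vars & self.thread_binding_vars:
            return store
        if store.scope != "local.fragment":
            return store
        if any(index != 0 for index in store.indices):
            raise ValueError("Only fragment[0] access is allowed.")
        # Wrap the bare store in a parallel loop over 0..1.
        return For("parallel", "_", [store], extent=1)
```

Calling `WrapSingleStores().visit(Store("local.fragment", (0,)))` returns a one-iteration parallel `For` around the store, while the same store nested inside an existing parallel loop is returned unchanged.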
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks: ❌ Failed checks (1 warning); ✅ Passed checks (2 passed)
Summary of Changes
This pull request addresses a bug by disabling memory information analysis for `local.var`.
Code Review
This pull request provides a bugfix to disable memory info analysis for the local.var storage scope by excluding it from special tagged memory handling in src/transform/storage_rewrite.cc. Additionally, it includes a significant and valuable refactoring of the AddWrapperForSingleBufStore transform in tilelang/transform/add_bufstore_wrapper.py. The refactoring improves the transform's precision by targeting only specific BufferStore operations on local.fragment buffers, and adds better documentation and more robust checks. My review includes a couple of suggestions to further improve the Python code's clarity and safety.
    if isinstance(index, IntImm) and index != 0:
        raise ValueError(
            f"Fragment buffer access with non-zero index [{index}] is not supported. "
            "Only fragment[0] access is allowed.")
The validation for fragment buffer indices seems too permissive. It only raises an error for constant non-zero indices, but allows variable indices to pass through. This could lead to incorrect transformations if a variable index evaluates to a non-zero value at runtime. According to the docstring "Don't access fragment buffers with non-zero indices", the check should be stricter to only allow provably zero indices.
Suggested change:
-   if isinstance(index, IntImm) and index != 0:
-       raise ValueError(
-           f"Fragment buffer access with non-zero index [{index}] is not supported. "
-           "Only fragment[0] access is allowed.")
+   if not (isinstance(index, IntImm) and index.value == 0):
+       raise ValueError(
+           f"Fragment buffer access with non-constant or non-zero index [{index}] is not supported. "
+           "Only fragment[0] access is allowed.")
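The difference between the two checks is easiest to see on a symbolic index. The following is a minimal sketch, assuming stand-in `IntImm` and `Var` dataclasses (hypothetical, not the TVM classes) and comparing on `.value` because the stand-ins have no operator overloading; `check_permissive`, `check_strict`, and `rejects` are illustrative names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class IntImm:
    """Stand-in for a constant-integer IR node (not tvm.tir.IntImm)."""
    value: int


@dataclass(frozen=True)
class Var:
    """Stand-in for a symbolic index whose value is unknown at compile time."""
    name: str


def check_permissive(index):
    # Mirrors the original condition: only constant non-zero indices are rejected.
    if isinstance(index, IntImm) and index.value != 0:
        raise ValueError(f"non-zero index [{index}] is not supported")


def check_strict(index):
    # Mirrors the suggested condition: anything not provably zero is rejected.
    if not (isinstance(index, IntImm) and index.value == 0):
        raise ValueError(f"non-constant or non-zero index [{index}] is not supported")


def rejects(check, index) -> bool:
    """Return True if the given check raises ValueError for this index."""
    try:
        check(index)
    except ValueError:
        return True
    return False
```

With these stand-ins, `rejects(check_permissive, Var("i"))` is `False` while `rejects(check_strict, Var("i"))` is `True`; both checks accept `IntImm(0)` and reject `IntImm(1)`.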
    def collect_buffer_accesses(statement) -> tuple[list[Buffer], list[Buffer]]:
        """
        Categorizes buffers accessed in the statement by their scope.

        Args:
            statement: The TIR statement to analyze

        Returns:
            Tuple of (local_buffers, fragment_buffers)
        """
        accessed_buffers = set()

        def visit_buffer_access(node):
            if isinstance(node, (BufferLoad, BufferStore)):
                accessed_buffers.add(node.buffer)

        post_order_visit(statement, visit_buffer_access)

        local_buffers = []
        fragment_buffers = []
        for buffer in accessed_buffers:
            if buffer.scope() == "local.fragment":
                fragment_buffers.append(buffer)
            elif buffer.scope().startswith("local"):
                local_buffers.append(buffer)
        return local_buffers, fragment_buffers
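Because `accessed_buffers` is a Python `set`, the order of the two returned lists is not specified. Whether that matters for this pass is not stated in the thread; if a stable order were wanted, one option is to deduplicate with an insertion-ordered `dict`. A minimal sketch, assuming a hypothetical `FakeBuffer` stand-in that only mimics the `scope()` method, and a hypothetical helper name `split_by_scope`:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeBuffer:
    """Stand-in for a TIR buffer; only the scope() method is modeled."""
    name: str
    scope_name: str

    def scope(self) -> str:
        return self.scope_name


def split_by_scope(buffers):
    """Return (local_buffers, fragment_buffers) in first-seen order."""
    unique = list(dict.fromkeys(buffers))  # dedupe while keeping visit order
    local_buffers = []
    fragment_buffers = []
    for buffer in unique:
        scope = buffer.scope()
        if scope == "local.fragment":
            fragment_buffers.append(buffer)
        elif scope.startswith("local"):
            local_buffers.append(buffer)
    return local_buffers, fragment_buffers
```

Buffers in other scopes (for example `"shared"`) fall through both branches and are dropped, as in the original helper.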
Actionable comments posted: 1
🧹 Nitpick comments (7)
tilelang/transform/add_bufstore_wrapper.py (7)
1-1: Remove unused imports after refactoring. The import list has been expanded, but ensure all imports are actually used. Based on the code, all appear to be utilized except potentially some that might be indirectly referenced.

19-19: Remove unused function parameters. The `mod` and `ctx` parameters are not used in the pass function. While these are part of the standard pass interface, you could use underscore prefixes to indicate they're intentionally unused.
-   def pass_fn(func: PrimFunc, mod, ctx):
+   def pass_fn(func: PrimFunc, _mod, _ctx):

25-42: Consider performance optimization for variable collection. The `get_used_variables` function creates a new set and visitor function for each call. For better performance in large ASTs, consider caching or reusing visitors.

44-69: Add null safety check for buffer scope. The `buffer.scope()` method call could potentially return None. Consider adding a defensive check.
    for buffer in accessed_buffers:
-       if buffer.scope() == "local.fragment":
+       scope = buffer.scope()
+       if scope == "local.fragment":
            fragment_buffers.append(buffer)
-       elif buffer.scope().startswith("local"):
+       elif scope and scope.startswith("local"):
            local_buffers.append(buffer)

71-88: Potential issue with buffer indices collection. The function overwrites indices if the same buffer is accessed multiple times. Consider collecting all indices per buffer if multiple accesses are expected.
-   buffer_to_indices = {}
+   from collections import defaultdict
+   buffer_to_indices = defaultdict(list)
    def visit_buffer_access(node):
        if isinstance(node, (BufferLoad, BufferStore)):
-           buffer_to_indices[node.buffer] = node.indices
+           buffer_to_indices[node.buffer].append(node.indices)
Then adjust the validation logic accordingly in lines 134-141.

139-141: Consider a more specific exception type. Instead of a generic ValueError, consider using a more specific exception or creating a custom exception class for transformation errors.

102-120: Thread safety consideration for nonlocal variables. The use of `nonlocal` variables (`tile_operation_depth`) in visitor functions could be problematic if this pass is ever used in a multi-threaded context. Consider encapsulating the state in a class.
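The index-overwrite nitpick (71-88) can be illustrated without TVM. This is a sketch under stated assumptions: accesses are modeled as plain `(buffer_name, indices)` pairs in visit order, and `collect_last_indices` and `collect_all_indices` are hypothetical names for the two strategies.

```python
from collections import defaultdict


def collect_last_indices(accesses):
    """Plain-dict strategy: a later access overwrites an earlier one."""
    buffer_to_indices = {}
    for buffer, indices in accesses:
        buffer_to_indices[buffer] = tuple(indices)
    return buffer_to_indices


def collect_all_indices(accesses):
    """defaultdict strategy: every access to a buffer is kept."""
    buffer_to_indices = defaultdict(list)
    for buffer, indices in accesses:
        buffer_to_indices[buffer].append(tuple(indices))
    return dict(buffer_to_indices)


def has_nonzero_access(all_indices, buffer):
    """Validation over the list form: check every recorded access."""
    return any(index != 0 for indices in all_indices[buffer] for index in indices)


# The same buffer is accessed twice; only the first access is non-zero.
accesses = [("frag", (1,)), ("frag", (0,))]
```

Here `collect_last_indices(accesses)` keeps only `(0,)` for `"frag"`, so a validation pass over it would miss the `(1,)` access, whereas `has_nonzero_access(collect_all_indices(accesses), "frag")` sees it.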
📜 Review details
📒 Files selected for processing (2)
- src/transform/storage_rewrite.cc (1 hunks)
- tilelang/transform/add_bufstore_wrapper.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
tilelang/transform/add_bufstore_wrapper.py
19-19: Unused function argument: mod (ARG001)
19-19: Unused function argument: ctx (ARG001)
139-141: Avoid specifying long messages outside the exception class (TRY003)
🔇 Additional comments (3)
tilelang/transform/add_bufstore_wrapper.py (2)
6-17: Well-structured docstring! The addition of a comprehensive docstring clearly explains the pass's purpose and behavior, making the code more maintainable.
143-144: Well-designed transformation! The wrapping of fragment buffer accesses with a parallel loop is clean and follows TVM idioms correctly.
src/transform/storage_rewrite.cc (1)
674-677: Resolve — excluding ".var" is intentional and correct. The codebase treats ".var" as a local/scalar scope: IsSpecialTaggedMemory excludes ".var" (src/transform/storage_rewrite.cc), NewAllocTagMerged and lower_device_storage_access_info.cc skip GetMemoryInfo for ".var", and codegen handles "local.var" as a scalar/local variable (src/target/codegen_cuda.cc); therefore ".var" must not participate in special-memory merging or MemoryInfo bounds checks.
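The effect of the exclusion can be sketched as a small predicate. This is an illustration in Python of logic that lives in C++ in the repository, under loud assumptions: `split_scope` and `is_special_tagged_memory` are hypothetical names, the scope string is split at its first dot, and only the ".var" exclusion is taken from this thread (the real predicate may exclude other tags that are not shown here).

```python
# Only ".var" comes from the review thread; any other excluded tags are unknown here.
EXCLUDED_TAGS = frozenset({".var"})


def split_scope(scope: str):
    """Split a scope string such as "local.var" into ("local", ".var")."""
    rank, dot, tag = scope.partition(".")
    return rank, dot + tag


def is_special_tagged_memory(scope: str, excluded=EXCLUDED_TAGS) -> bool:
    """A scope is special-tagged if it has a tag that is not excluded."""
    _, tag = split_scope(scope)
    return tag != "" and tag not in excluded
```

Under these assumptions `is_special_tagged_memory("local.var")` is `False`, so such a buffer would skip the memory-info lookup, while a scope with some other made-up tag, such as `"local.some_tag"`, would still be treated as special.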
    buffer_indices = collect_buffer_indices(statement)
    for buffer, indices in buffer_indices.items():
        if buffer.scope() == "local.fragment":
            for index in indices:
                if isinstance(index, IntImm) and index != 0:
                    raise ValueError(
                        f"Fragment buffer access with non-zero index [{index}] is not supported. "
                        "Only fragment[0] access is allowed.")
Integer comparison issue with IntImm.
The comparison index != 0 compares an IntImm object with an integer. This should compare the value instead.
Suggested change:
        if buffer.scope() == "local.fragment":
            for index in indices:
-               if isinstance(index, IntImm) and index != 0:
+               if isinstance(index, IntImm) and index.value != 0:
                    raise ValueError(
                        f"Fragment buffer access with non-zero index [{index}] is not supported. "
                        "Only fragment[0] access is allowed.")
🤖 Prompt for AI Agents
In tilelang/transform/add_bufstore_wrapper.py around lines 134 to 141, the
comparison `index != 0` is comparing an IntImm object to an int; change the
check to compare the IntImm's numeric value (e.g., `index.value != 0`) and
update the error message to include the numeric value (use `index.value`) so the
condition and message operate on the actual integer rather than the IntImm
object.
bug fix for issue #836