[CI] Test Fix: Handle BufferLoad nodes when T.gemm input has a stride #843
Caution: Review failed. The pull request is closed.
Walkthrough
Adds tir.BufferLoad support to the GEMM helper functions by deriving regions for shape, stride, pointer, and offset retrieval. Updates error messages to explicit ValueError variants. Applies the changes to both gemm and gemm_v2; gemm_v2 switches its final intrinsic to tl.gemm_py. Imports get_buffer_region_from_load.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant U as User Code
participant G as gemm/gemm_v2
participant H as retrieve_* helpers
participant B as Buffer/Region
participant I as tl.gemm / tl.gemm_py
U->>G: call with A,B,C (may be tir.BufferLoad)
G->>H: retrieve_shape/stride/ptr/offset(A,B,C)
alt arg is tir.BufferLoad
H->>B: get_buffer_region_from_load(arg)
Note right of H: Derive region mins/extents<br/>Compute shape/strides/offset
H-->>G: shapes, strides, ptr via access_ptr(offset)
else arg is Buffer/var
H-->>G: shapes, strides, ptr, offset (existing paths)
end
G->>I: invoke intrinsic (gemm_v2 -> gemm_py)
I-->>G: result / side-effects
G-->>U: completion
Summary of Changes
Hello @LeiWang1999, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request resolves a test failure by implementing comprehensive support for tir.BufferLoad nodes within the gemm module's helper functions. The modifications ensure that when T.gemm receives inputs that are BufferLoad objects, their shapes, strides, pointers, and offsets are correctly identified and handled, particularly when dealing with strided memory access patterns. This improves the robustness and correctness of the gemm operation under more complex input conditions.
Highlights
- Enhanced gemm Utility Functions: The utility functions retrieve_shape, retrieve_stride, retrieve_ptr, and retrieve_offset within tilelang/language/gemm.py have been extended to correctly process tir.BufferLoad objects.
- Support for Strided Buffer Inputs: This change specifically addresses scenarios where T.gemm inputs are tir.BufferLoad nodes that might represent strided buffers, ensuring accurate calculations for their properties.
- New Import for BufferLoad Handling: The get_buffer_region_from_load utility function has been imported and integrated to facilitate the extraction of buffer region information from tir.BufferLoad objects.
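To make the region-based derivation concrete, here is a minimal sketch in plain Python. It does not use TVM: Range is a stand-in that models only min and extent, and the example region is an assumption about what a load such as A[2, 0:64, 0:128] might resolve to, not output captured from get_buffer_region_from_load.

```python
from collections import namedtuple

# Stand-in for a TVM range (hypothetical; only `min` and `extent` are modeled).
Range = namedtuple("Range", ["min", "extent"])


def shape_from_region(region):
    """Shape of the accessed tile: one extent per dimension."""
    return [r.extent for r in region]


def offsets_from_region(region):
    """Start index of the accessed tile: one min per dimension."""
    return [r.min for r in region]


# Assumed region for a load like A[2, 0:64, 0:128] on a 3-D buffer.
region = [Range(2, 1), Range(0, 64), Range(0, 128)]
tile_shape = shape_from_region(region)
tile_offsets = offsets_from_region(region)
print(tile_shape, tile_offsets)
```

The extents give the tile shape that the helper would report, and the mins give the per-dimension start indices that later feed the pointer offset.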
Code Review
This pull request adds support for tir.BufferLoad nodes as inputs to T.gemm by extending the internal helper functions in both gemm and gemm_v2.
While the change correctly extends functionality, I've identified several areas for improvement:
- Potential AttributeError: The new logic in retrieve_shape, retrieve_ptr, and retrieve_offset does not handle the case where get_buffer_region_from_load returns None, which can lead to an AttributeError. I've added comments with suggestions to fix this potential bug.
- Code Duplication: There is significant code duplication within the helper functions (e.g., between the tir.BufferRegion and tir.BufferLoad cases) and across gemm and gemm_v2 (the entire set of helper functions is duplicated).
I highly recommend refactoring to address the duplication. A good first step would be to move the helper functions (retrieve_shape, retrieve_stride, etc.) outside of gemm and gemm_v2 so they can be shared. This will greatly improve the code's maintainability and readability. My specific comments focus on the added code blocks and provide suggestions for immediate fixes and local refactoring.
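One way the shared, module-level helpers could look is sketched below. This is only an illustration under stated assumptions: require_region is a hypothetical name, the namedtuples are stand-ins for the real tir classes, and the region lookup is passed in as a parameter so the sketch runs without TVM. It also uses an explicit None check rather than the truthiness test shown in the review suggestions.

```python
from collections import namedtuple

# Stand-ins for TVM objects (hypothetical; not the real tir classes).
Range = namedtuple("Range", ["min", "extent"])
BufferRegion = namedtuple("BufferRegion", ["buffer", "region"])


def require_region(load, get_region):
    """Resolve a load to a region, raising ValueError instead of AttributeError."""
    buffer_region = get_region(load)
    if buffer_region is None:
        raise ValueError(f"Could not get BufferRegion from BufferLoad: {load}")
    return buffer_region


def retrieve_shape(load, get_region):
    """Extents of the resolved region, one per dimension."""
    return [r.extent for r in require_region(load, get_region).region]


def retrieve_offset(load, get_region):
    """Mins of the resolved region, one per dimension."""
    return [r.min for r in require_region(load, get_region).region]


# Fake lookup: resolves one known load and returns None for anything else.
known = {"A[2]": BufferRegion("A", [Range(2, 1), Range(0, 64), Range(0, 128)])}
lookup = known.get
```

Because the guard lives in one place, gemm and gemm_v2 could both call the same helpers and get the same error message when a region cannot be derived.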
    elif isinstance(object, tir.BufferLoad):
        region = get_buffer_region_from_load(object).region
        shape = []
        for r in region:
            shape.append(r.extent)
        return shape
This new block to handle tir.BufferLoad introduces a potential AttributeError and duplicates code.
- Potential Bug: The function get_buffer_region_from_load can return None. If it does, get_buffer_region_from_load(object).region will raise an AttributeError. You should add a check for this.
- Code Duplication: The logic in this block is nearly identical to the logic in the elif isinstance(object, tir.BufferRegion) block. This makes the code harder to maintain. Ideally, the common logic should be extracted.
The suggestion below fixes the potential bug and simplifies the list creation using a list comprehension.
    - elif isinstance(object, tir.BufferLoad):
    -     region = get_buffer_region_from_load(object).region
    -     shape = []
    -     for r in region:
    -         shape.append(r.extent)
    -     return shape
    + elif isinstance(object, tir.BufferLoad):
    +     buffer_region = get_buffer_region_from_load(object)
    +     if not buffer_region:
    +         raise ValueError(f"Could not get BufferRegion from BufferLoad: {object}")
    +     return [r.extent for r in buffer_region.region]
    elif isinstance(object, tir.BufferLoad):
        buffer = object.buffer
        region = get_buffer_region_from_load(object).region
        indices = []
        for r in region:
            indices.append(r.min)
        strides = []
        stride = 1
        for s in reversed(buffer.shape):
            strides.insert(0, stride)
            stride *= s
        offset = 0
        for i in range(len(indices) - 2):
            offset += indices[i] * strides[i]
        return buffer.access_ptr(access_mask=access_type, offset=offset)
This large block of code has a few issues:
- Potential Bug: get_buffer_region_from_load(object) can return None, which would cause an AttributeError on the next line (...region). A None check is required.
- Code Duplication: This block is almost a copy-paste of the tir.BufferRegion case. Furthermore, it duplicates logic from retrieve_stride (for calculating strides) and retrieve_offset (for calculating indices).
This significant duplication makes the code difficult to maintain. I strongly recommend refactoring. You could start by fixing the bug by adding a None check after calling get_buffer_region_from_load. Then, consider a broader refactoring to reduce duplication, for example by extracting common parts into helper functions.
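The two computations that the retrieve_ptr block inlines could be extracted roughly as follows. This is a sketch using plain integers rather than TVM expressions; row_major_strides and leading_offset are hypothetical names, and the offset deliberately skips the last two dimensions to mirror the range(len(indices) - 2) loop in the block above.

```python
def row_major_strides(shape):
    """Strides, in elements, of a contiguous row-major buffer."""
    strides = []
    stride = 1
    for extent in reversed(shape):
        strides.insert(0, stride)
        stride *= extent
    return strides


def leading_offset(indices, strides):
    """Element offset contributed by every dimension except the last two."""
    return sum(i * s for i, s in zip(indices[:-2], strides[:-2]))


# A (4, 64, 128) buffer accessed at leading index 2.
strides = row_major_strides([4, 64, 128])
offset = leading_offset([2, 0, 0], strides)
print(strides, offset)
```

With these in place, the BufferLoad and BufferRegion branches would differ only in how they obtain the buffer and the region before passing the resulting offset to access_ptr.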
    elif isinstance(object, tir.BufferLoad):
        region = get_buffer_region_from_load(object).region
        indices = []
        for r in region:
            indices.append(r.min)
        return indices
Similar to retrieve_shape, this block has a potential AttributeError if get_buffer_region_from_load returns None, and it duplicates logic from the tir.BufferRegion case.
Please add a None check and consider refactoring to reduce code duplication. The suggestion below fixes the bug and uses a list comprehension for conciseness.
    - elif isinstance(object, tir.BufferLoad):
    -     region = get_buffer_region_from_load(object).region
    -     indices = []
    -     for r in region:
    -         indices.append(r.min)
    -     return indices
    + elif isinstance(object, tir.BufferLoad):
    +     buffer_region = get_buffer_region_from_load(object)
    +     if not buffer_region:
    +         raise ValueError(f"Could not get BufferRegion from BufferLoad: {object}")
    +     return [r.min for r in buffer_region.region]
    elif isinstance(object, tir.BufferLoad):
        buffer = object.buffer
        strides = []
        stride = 1
        for s in reversed(buffer.shape):
            strides.insert(0, stride)
            stride *= s
        return strides
This block for tir.BufferLoad duplicates the stride calculation logic that already exists for tir.Buffer and tir.BufferRegion. This increases maintenance overhead.
To improve maintainability, you could refactor retrieve_stride to extract this common logic into a helper function or restructure the conditional to handle all buffer-like types that have a .shape attribute in a single path.
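A single-path version of retrieve_stride might look like the sketch below. The dataclasses are stand-ins for tir.Buffer, tir.BufferRegion and tir.BufferLoad so that the example runs without TVM, underlying_buffer is a hypothetical helper name, and the error message is illustrative rather than taken from the PR.

```python
from dataclasses import dataclass
from typing import List


# Stand-ins for tir.Buffer, tir.BufferRegion and tir.BufferLoad (hypothetical).
@dataclass
class Buffer:
    shape: List[int]


@dataclass
class BufferRegion:
    buffer: Buffer


@dataclass
class BufferLoad:
    buffer: Buffer


def underlying_buffer(obj):
    """Map every supported input kind to the buffer whose shape defines the strides."""
    if isinstance(obj, Buffer):
        return obj
    if isinstance(obj, (BufferRegion, BufferLoad)):
        return obj.buffer
    raise ValueError(f"Unsupported retrieve_stride argument type: {type(obj)}")


def retrieve_stride(obj):
    """Row-major strides computed in one place, whatever the input kind."""
    strides, stride = [], 1
    for extent in reversed(underlying_buffer(obj).shape):
        strides.insert(0, stride)
        stride *= extent
    return strides


buf = Buffer([64, 128])
```

The type dispatch shrinks to choosing the buffer, so supporting another buffer-like input would mean adding one case to underlying_buffer instead of another copy of the stride loop.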
…tile-ai#843) * bugfix * fix * test fix