[Cython] Remove an incorrect check #880
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough
The change disables the overlapping-tensor validation in tilelang/jit/adapter/cython/cython_wrapper.pyx by converting the check into a triple-quoted string, leaving the input preparation logic otherwise intact.
Sequence Diagram(s)

sequenceDiagram
participant Caller
participant CythonWrapper as Cython Wrapper
participant Kernel as Kernel
Caller->>CythonWrapper: prepare_inputs(args)
%% Previously performed validation (now disabled)
rect rgba(200,200,200,0.2)
note over CythonWrapper: Overlapping-tensor validation (removed)
end
CythonWrapper->>CythonWrapper: Build/prep tensor views
CythonWrapper->>Kernel: launch(args_prepared)
Kernel-->>Caller: result
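To make the walkthrough concrete: in Python (and therefore in Cython source), a bare string literal used as a statement is evaluated and discarded, so wrapping a block in ''' ... ''' disables it without deleting the text. Below is a minimal, hypothetical sketch of the idea; the function and names are illustrative, not the actual cython_wrapper.pyx contents.

```python
def prepare_inputs(args):
    prepared = list(args)          # hypothetical preparation step
    # The old check is now wrapped in a bare triple-quoted string, i.e. a
    # constant expression that is evaluated and thrown away, so it never runs.
    '''
    for tensor in prepared:
        validate_no_overlap(tensor)    # hypothetical name; never executed
    '''
    return prepared
```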
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks and finishing touches
✅ Passed checks (3 passed)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run […]. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes

Hello @LJC00118, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses an issue with an existing validation check for […].
Code Review
This pull request removes a check for overlapping tensors by commenting it out. While disabling the check seems correct based on the PR title, the method used (a multi-line string) is not ideal. I've suggested removing the code block and the associated TODO comment entirely to improve code clarity and maintainability.
# TODO(chenggang): remove this check or rewrite by ourselves?
'''
if isinstance(tensor, torch.Tensor) and tensor._base is not None and not tensor.is_contiguous():
    base_tensor = tensor._base.as_strided(tensor._base.shape, tensor.stride())
    if torch._debug_has_internal_overlap(base_tensor):
        raise ValueError(f"Cannot use an overlapping tensor"
                         f"(shape={tensor.shape}, strides={tensor.stride()}, "
                         f"overlap={torch._debug_has_internal_overlap(base_tensor)}) as the kernel input")
'''
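For context, here is a small sketch of how the old construction could misbehave on a harmless view. It assumes current PyTorch semantics (torch._debug_has_internal_overlap returning a small int code, and as_strided bound-checking the requested layout); it is an illustration, not a test from the PR.

```python
import torch

# A plain transpose is a view of its base and has no internal overlap.
x = torch.randn(3, 4)
t = x.t()                        # shape (4, 3), strides (1, 4), t._base is x
assert t._base is not None and not t.is_contiguous()

# The removed check combined the base's *shape* with the view's *strides*.
# For a transpose this mismatched layout can need more storage than the base
# owns (here max offset 2*1 + 3*4 = 14, but the base has only 12 elements),
# so the as_strided call itself may raise instead of detecting anything.
try:
    base_tensor = t._base.as_strided(t._base.shape, t.stride())
    print(torch._debug_has_internal_overlap(base_tensor))
except RuntimeError as err:
    print(f"old check errors out on a plain transpose: {err}")

# Checking the view directly behaves sensibly (0 is assumed to mean "no overlap").
print(torch._debug_has_internal_overlap(t))
```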
Actionable comments posted: 1
'''
if isinstance(tensor, torch.Tensor) and tensor._base is not None and not tensor.is_contiguous():
    base_tensor = tensor._base.as_strided(tensor._base.shape, tensor.stride())
    if torch._debug_has_internal_overlap(base_tensor):
        raise ValueError(f"Cannot use an overlapping tensor"
                         f"(shape={tensor.shape}, strides={tensor.stride()}, "
                         f"overlap={torch._debug_has_internal_overlap(base_tensor)}) as the kernel input")
'''
Restore the overlapping-tensor guard
Commenting out this block lets overlapping views (e.g., results of `expand`, `as_strided`, certain transposes) reach the kernel. Those views alias the same storage, so once the kernel writes, threads race and the results become undefined. The underlying check was flawed because of the `as_strided` call on `_base`, but we still need a guard; we should just check the tensor directly and treat failures conservatively. Please rewrite instead of removing.
- '''
- if isinstance(tensor, torch.Tensor) and tensor._base is not None and not tensor.is_contiguous():
- base_tensor = tensor._base.as_strided(tensor._base.shape, tensor.stride())
- if torch._debug_has_internal_overlap(base_tensor):
- raise ValueError(f"Cannot use an overlapping tensor"
- f"(shape={tensor.shape}, strides={tensor.stride()}, "
- f"overlap={torch._debug_has_internal_overlap(base_tensor)}) as the kernel input")
- '''
+ if (
+ isinstance(tensor, torch.Tensor)
+ and tensor._base is not None
+ and not tensor.is_contiguous()
+ ):
+ try:
+ has_overlap = torch._debug_has_internal_overlap(tensor)
+ except RuntimeError as err:
+ raise ValueError(
+ f"Cannot use an overlapping tensor "
+ f"(shape={tensor.shape}, strides={tensor.stride()}) as the kernel input"
+ ) from err
+ if has_overlap:
+ raise ValueError(
+ f"Cannot use an overlapping tensor "
+ f"(shape={tensor.shape}, strides={tensor.stride()}, "
+ f"overlap={has_overlap}) as the kernel input"
+            )
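To ground the concern above, here is a minimal sketch (assuming current PyTorch behavior; not part of the PR) of how an overlapping view arises and how checking the tensor directly, as the suggested rewrite does, flags it.

```python
import torch

# expand() creates a view whose broadcast dimension has stride 0, so many
# logical elements alias the same storage location.
base = torch.zeros(4)
view = base.expand(3, 4)         # shape (3, 4), strides (0, 1), view._base is base

print(view.is_contiguous())                     # False
# The internal helper reports overlap for the view itself; the return value is
# assumed to be a small int code where a nonzero value means overlap.
print(torch._debug_has_internal_overlap(view))  # nonzero on current PyTorch builds

# If a kernel wrote through this view, every thread would be hitting the same
# four storage elements, racing with the others -- hence the need for a guard.
```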
LGTM, Merged.
Summary by CodeRabbit