[AMD] fix bugs in warp shuffle #790
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a HIP availability check and a module flag, used by the warp-shuffle wrappers to choose between HIP and CUDA intrinsics.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant C as Caller
    participant B as builtin.shuffle (wrapper)
    participant F as _IS_HIP_AVAILABLE (module flag)
    participant H as HIP intrinsic (__shfl_*)
    participant U as CUDA intrinsic (__shfl_*_sync)
    C->>B: shfl_{xor,down,up}(value, offset)
    B->>F: read _IS_HIP_AVAILABLE
    alt HIP available (true)
        B->>H: __shfl_*(value, offset)
        H-->>B: result
    else HIP not available (false)
        B->>U: __shfl_*_sync(0xffffffff, value, offset)
        U-->>B: result
    end
    B-->>C: result
```
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's format checks before requesting review. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes
Hello @txs19991, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request resolves critical stability issues on AMD GPUs related to warp shuffle operations, which previously led to core dumps. By intelligently adapting the shuffle intrinsic calls based on the target platform, it ensures correct and more performant execution on AMD hardware, leveraging their unique architectural guarantees.
Highlights
- AMD Warp Shuffle Fix: Addresses core dumps on AMD GPUs when using `shfl_xor`, `shfl_down`, and `shfl_up` by adapting the shuffle operations for AMD's 64-lane warps.
- Conditional Shuffle Implementation: Implements conditional logic to use non-synchronizing shuffle intrinsics (`__shfl_xor`, `__shfl_down`, `__shfl_up`) when running on AMD platforms (checked via `check_hip_availability`), while retaining the `_sync` versions for other platforms.
- Performance Improvement: Leverages AMD GPUs' lockstep execution guarantee to use shuffle operations without additional synchronization, leading to improved performance.
Code Review
This pull request correctly fixes a bug with warp shuffle operations on AMD GPUs by using the appropriate non-synchronized shuffle intrinsics for HIP environments. My review includes suggestions to improve code readability and performance.
Specifically, I've pointed out that check_hip_availability() is called multiple times, which could be inefficient. Caching the result at the module level would be a better approach. Additionally, the conditional logic in the shuffle functions is written using long ternary expressions, which harm readability. I've suggested refactoring these into if/else blocks.
Addressing these points will make the code more maintainable and performant.
```python
from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier, evaluate
from tilelang.language.kernel import get_thread_bindings, get_block_extents
from tilelang.utils.target import check_hip_availability
```
The imported function check_hip_availability() is called every time one of the shuffle functions (shfl_xor, shfl_down, shfl_up) is invoked. The implementation of this check may perform file system lookups, which can be inefficient if called repeatedly. To improve performance, the result of this check should be cached at the module level.
For example:

```python
# After imports
_IS_HIP_AVAILABLE = check_hip_availability()

# In shuffle functions
if _IS_HIP_AVAILABLE:
    # ...
```
tilelang/language/builtin.py (outdated)
```diff
         tir.Call: A handle to the shuffle operation
     """
-    return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
+    return tir.call_extern(value.dtype, "__shfl_xor", value, offset) if check_hip_availability() else tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
```
This line is quite long and can be difficult to read. Using an if/else block would improve code clarity and maintainability.
Suggested change:

```python
    if check_hip_availability():
        return tir.call_extern(value.dtype, "__shfl_xor", value, offset)
    else:
        return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
```
tilelang/language/builtin.py (outdated)
```diff
         The value to shuffle
     """
-    return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
+    return tir.call_extern(value.dtype, "__shfl_down", value, offset) if check_hip_availability() else tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
```
For consistency and readability, this line should also be refactored into an if/else block, similar to the suggestion for shfl_xor.
Suggested change:

```python
    if check_hip_availability():
        return tir.call_extern(value.dtype, "__shfl_down", value, offset)
    else:
        return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
```
tilelang/language/builtin.py (outdated)
```diff
         The value to shuffle
     """
-    return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
+    return tir.call_extern(value.dtype, "__shfl_up", value, offset) if check_hip_availability() else tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
```
To maintain a consistent and readable style across all shuffle functions, please refactor this line into an if/else block.
Suggested change:

```python
    if check_hip_availability():
        return tir.call_extern(value.dtype, "__shfl_up", value, offset)
    else:
        return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
```
Actionable comments posted: 4
🧹 Nitpick comments (1)
tilelang/language/builtin.py (1)
299-319: Consider an optional width to preserve 32-lane algorithms on AMD (wave64).

HIP defaults width=warpSize (64 on AMD). If existing kernels assume 32-lane reductions, expose an optional width to force 32 when needed, e.g., __shfl_down(val, off, 32). Backward-compatible API:

```diff
-def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
+def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call],
+             width: Union[int, PrimExpr, None] = None):
@@
-    if _is_hip_target():
-        return tir.call_extern(value.dtype, "__shfl_xor", value, offset)
+    if _is_hip_target():
+        return tir.call_extern(value.dtype, "__shfl_xor", value, offset) if width is None \
+            else tir.call_extern(value.dtype, "__shfl_xor", value, offset, width)
@@
-    return tir.call_extern(value.dtype, "__shfl_xor_sync", active, value, offset)
+    return tir.call_extern(value.dtype, "__shfl_xor_sync", active, value, offset) if width is None \
+        else tir.call_extern(value.dtype, "__shfl_xor_sync", active, value, offset, width)
```

Apply similarly to shfl_down and shfl_up. If you prefer to keep the public API unchanged, at least audit call sites that rely on 32-wide behavior. Would you like a quick grep script to find likely 32-lane reduction loops?
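The wave32/wave64 width concern can be illustrated with a lane-level simulation; `shfl_down_sim` and `warp_reduce_sum` are hypothetical pure-Python helpers that model the intrinsic's documented behavior, not TileLang code:

```python
def shfl_down_sim(lanes, offset, width):
    """Model __shfl_down(value, offset, width): lane i reads lane i+offset
    within its width-sized sub-group, otherwise keeps its own value."""
    out = []
    for i, v in enumerate(lanes):
        src = i + offset
        # A shuffle never crosses a width-aligned sub-group boundary.
        if src < len(lanes) and src // width == i // width:
            out.append(lanes[src])
        else:
            out.append(v)
    return out


def warp_reduce_sum(lanes, width):
    """Tree reduction via shfl_down; lane 0 of each sub-group ends with its sum."""
    off = width // 2
    while off > 0:
        shifted = shfl_down_sim(lanes, off, width)
        lanes = [a + b for a, b in zip(lanes, shifted)]
        off //= 2
    return lanes


wave = list(range(64))                                 # one AMD wave64 of lanes
assert warp_reduce_sum(wave, 32)[0] == sum(range(32))  # 32-lane algorithm: half the wave
assert warp_reduce_sum(wave, 64)[0] == sum(range(64))  # HIP default width=warpSize=64
```

This is exactly why a reduction written for 32-lane warps silently changes meaning when the default width becomes 64.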
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
tilelang/language/builtin.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/language/builtin.py (2)
tilelang/utils/target.py (1)
check_hip_availability (30-40)
tilelang/language/tir/op.py (1)
call_extern (172-194)
```python
from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier, evaluate
from tilelang.language.kernel import get_thread_bindings, get_block_extents
from tilelang.utils.target import check_hip_availability
```
Don't select HIP/CUDA via host availability; dispatch by TVM Target.
check_hip_availability() reflects the build host, not the compilation target. On machines with both ROCm and CUDA installed (or during cross-compilation), this can emit HIP intrinsics while targeting CUDA (or vice versa), leading to compile errors or miscompiled kernels.
Introduce a target-aware helper and use it in the shfl wrappers.
```diff
+def _is_hip_target() -> bool:
+    tgt = tvm.target.Target.current(allow_none=True)
+    if tgt is not None:
+        kind = getattr(tgt, "kind", None)
+        name = getattr(kind, "name", "")
+        return name in ("rocm", "hip", "amdgpu")
+    # Fallback for contexts where Target is not set yet.
+    return check_hip_availability()
```

Committable suggestion skipped: line range outside the PR's diff.
tilelang/language/builtin.py (outdated)
```diff
         tir.Call: A handle to the shuffle operation
     """
-    return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
+    return tir.call_extern(value.dtype, "__shfl_xor", value, offset) if check_hip_availability() else tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
```
💡 Verification agent
🧩 Analysis chain
Use target-aware dispatch and pass the active mask (not 0xffffffff) on CUDA.
- Switch to _is_hip_target() (see previous comment).
- Use __activemask() to avoid deadlocks in divergent warps; 0xffffffff assumes all lanes active.
```diff
-    return tir.call_extern(value.dtype, "__shfl_xor", value, offset) if check_hip_availability() else tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
+    if _is_hip_target():
+        return tir.call_extern(value.dtype, "__shfl_xor", value, offset)
+    active = tir.call_extern("uint32", "__activemask")
+    return tir.call_extern(value.dtype, "__shfl_xor_sync", active, value, offset)
```

Note: If callers pass Python ints for value, value.dtype will fail. If that's intended, consider wrapping ints with tir.const(..) at call sites or enforcing PrimExpr here.
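The int-vs-PrimExpr note above can be sketched as a minimal coercion guard; `Const` and `as_typed` are hypothetical stand-ins for `tir.const` and typed expressions, not the TVM API:

```python
class Const:
    """Minimal stand-in for a typed constant expression (cf. tir.const)."""

    def __init__(self, value, dtype):
        self.value = value
        self.dtype = dtype


def as_typed(value, default_dtype="int32"):
    """Coerce a plain Python int into a typed wrapper so callers can always
    read .dtype; anything else is assumed to expose .dtype already."""
    if isinstance(value, int):
        return Const(value, default_dtype)
    return value


assert as_typed(7).dtype == "int32"                        # plain int gets a dtype
assert as_typed(Const(1.5, "float32")).dtype == "float32"  # typed values pass through
```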
🏁 Script executed:
```bash
#!/bin/bash
set -e
echo "Searching for check_hip_availability definitions and usages"
rg -n "def check_hip_availability" -C2 .
rg -n "check_hip_availability\(" -C2 .
echo "Searching for _is_hip_target definitions and usages"
rg -n "def _is_hip_target" -C2 .
rg -n "_is_hip_target\(" -C2 .
echo "Searching for __activemask usages"
rg -n "__activemask" -C2 .
echo "Inspecting shfl_xor implementation context"
rg -n "shfl_xor" -C3 tilelang/language/builtin.py
```

Length of output: 593
🏁 Script executed:
```bash
#!/bin/bash
set -e
echo "=== Opening tilelang/utils/target.py (lines 1-200) ==="
sed -n '1,200p' tilelang/utils/target.py || true
echo "=== Searching for is_hip_target or hip_target aliases ==="
rg -n "hip_target" -C2 tilelang/utils/target.py || true
rg -n "is_hip_target" -C2 tilelang/utils/target.py || true
echo "=== Searching for CUDA/ROCm availability utilities ==="
rg -n "check_cuda" -C2 tilelang/utils/target.py || true
rg -n "cuda" -C2 tilelang/utils/target.py || true
echo "=== Inspecting shfl_xor context in builtin.py (lines 250-350) ==="
sed -n '250,350p' tilelang/language/builtin.py || true
echo "=== Searching for other shuffle/intrinsic patterns ==="
rg -n "__shfl" -C5 tilelang/language/builtin.py || true
rg -n "activemask" -C2 tilelang/language/builtin.py || true
```

Length of output: 9113
🏁 Script executed:
```bash
#!/bin/bash
set -e
echo "=== Opening tilelang/language/builtin.py (lines 1-50) ==="
sed -n '1,50p' tilelang/language/builtin.py
```

Length of output: 1632
Use compile-time target dispatch and real active mask
Switch from the host‐availability check to inspecting the compilation target via tvm.target.Target.current(), and replace the hard-coded 0xffffffff with __activemask(). For example:
```diff
-from tilelang.utils.target import check_hip_availability
+from tilelang import tvm

 def shfl_xor(value: Union[int, PrimExpr, tir.Call],
              offset: Union[int, PrimExpr, tir.Call]):
-    return tir.call_extern(value.dtype, "__shfl_xor", value, offset) \
-        if check_hip_availability() \
-        else tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
+    tgt = tvm.target.Target.current(allow_none=False)
+    if tgt.kind.name == "rocm":
+        return tir.call_extern(value.dtype, "__shfl_xor", value, offset)
+    active = tir.call_extern("uint32", "__activemask")
+    return tir.call_extern(value.dtype, "__shfl_xor_sync", active, value, offset)
```

Apply the same pattern to shfl_down and shfl_up.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Suggested change:

```python
from tilelang import tvm

def shfl_xor(value: Union[int, PrimExpr, tir.Call],
             offset: Union[int, PrimExpr, tir.Call]):
    tgt = tvm.target.Target.current(allow_none=False)
    if tgt.kind.name == "rocm":
        return tir.call_extern(value.dtype, "__shfl_xor", value, offset)
    active = tir.call_extern("uint32", "__activemask")
    return tir.call_extern(value.dtype, "__shfl_xor_sync", active, value, offset)
```
tilelang/language/builtin.py (outdated)
```diff
         The value to shuffle
     """
-    return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
+    return tir.call_extern(value.dtype, "__shfl_down", value, offset) if check_hip_availability() else tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
```
🛠️ Refactor suggestion
Mirror the XOR fix for shfl_down.
Apply target-aware dispatch and active mask usage.
```diff
-    return tir.call_extern(value.dtype, "__shfl_down", value, offset) if check_hip_availability() else tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
+    if _is_hip_target():
+        return tir.call_extern(value.dtype, "__shfl_down", value, offset)
+    active = tir.call_extern("uint32", "__activemask")
+    return tir.call_extern(value.dtype, "__shfl_down_sync", active, value, offset)
```

📝 Committable suggestion
Suggested change:

```python
    if _is_hip_target():
        return tir.call_extern(value.dtype, "__shfl_down", value, offset)
    active = tir.call_extern("uint32", "__activemask")
    return tir.call_extern(value.dtype, "__shfl_down_sync", active, value, offset)
```
tilelang/language/builtin.py (outdated)
```diff
         The value to shuffle
     """
-    return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
+    return tir.call_extern(value.dtype, "__shfl_up", value, offset) if check_hip_availability() else tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
```
🛠️ Refactor suggestion
Mirror the XOR fix for shfl_up.
Apply target-aware dispatch and active mask usage.
```diff
-    return tir.call_extern(value.dtype, "__shfl_up", value, offset) if check_hip_availability() else tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
+    if _is_hip_target():
+        return tir.call_extern(value.dtype, "__shfl_up", value, offset)
+    active = tir.call_extern("uint32", "__activemask")
+    return tir.call_extern(value.dtype, "__shfl_up_sync", active, value, offset)
```

📝 Committable suggestion
Suggested change:

```python
    if _is_hip_target():
        return tir.call_extern(value.dtype, "__shfl_up", value, offset)
    active = tir.call_extern("uint32", "__activemask")
    return tir.call_extern(value.dtype, "__shfl_up_sync", active, value, offset)
```
* [AMD] fix bugs in warp shuffle
* format

---------

Co-authored-by: tangxinsheng.txs <tangxinsheng.txs@alibaba-inc.com>
Each warp on AMD contains 64 lanes, so calling T.shfl_xor, T.shfl_down, and T.shfl_up caused a core dump. Moreover, AMD GPUs guarantee that all warp lanes execute in lockstep; therefore, we use shuffle operations without additional synchronization, which provides better performance.
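The butterfly exchange behind T.shfl_xor can be sketched with a lane-level simulation; `shfl_xor_sim` and `xor_reduce_sum` are illustrative pure-Python helpers, not the TileLang API:

```python
def shfl_xor_sim(lanes, mask):
    """Model __shfl_xor: lane i reads the value held by lane i ^ mask."""
    return [lanes[i ^ mask] for i in range(len(lanes))]


def xor_reduce_sum(lanes):
    """Butterfly all-reduce: after log2(n) steps every lane holds the sum.
    The step count is fixed by the warp width (32 on NVIDIA, 64 on AMD),
    which is why a wave64 warp needs one more exchange than a 32-lane one."""
    mask = len(lanes) // 2
    while mask > 0:
        partner = shfl_xor_sim(lanes, mask)
        lanes = [a + b for a, b in zip(lanes, partner)]
        mask //= 2
    return lanes


wave = list(range(64))  # AMD wave64
assert all(v == sum(range(64)) for v in xor_reduce_sum(wave))
```

Because all 64 lanes run these exchange steps in lockstep on AMD hardware, no extra synchronization is needed between steps.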