-
Notifications
You must be signed in to change notification settings - Fork 332
[AMD] fix bugs in warp shuffle #790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,10 +3,13 @@ | |
| from tilelang import tvm as tvm | ||
| from tilelang.language import ptx_arrive_barrier, evaluate | ||
| from tilelang.language.kernel import get_thread_bindings, get_block_extents | ||
| from tilelang.utils.target import check_hip_availability | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't select HIP/CUDA via host availability; dispatch by TVM Target. check_hip_availability() reflects the build host, not the compilation target. On machines with both ROCm and CUDA installed (or during cross-compilation), this can emit HIP intrinsics while targeting CUDA (or vice versa), leading to compile errors or miscompiled kernels. Introduce a target-aware helper and use it in the shfl wrappers. +def _is_hip_target() -> bool:
+ tgt = tvm.target.Target.current(allow_none=True)
+ if tgt is not None:
+ kind = getattr(tgt, "kind", None)
+ name = getattr(kind, "name", "")
+ return name in ("rocm", "hip", "amdgpu")
+ # Fallback for contexts where Target is not set yet.
+ return check_hip_availability()
🤖 Prompt for AI Agents |
||
| from tvm import tir | ||
| from typing import Union, Any | ||
| from tvm.tir import PrimExpr, Var, Call | ||
|
|
||
| _IS_HIP_AVAILABLE = check_hip_availability() | ||
|
|
||
|
|
||
| def create_list_of_mbarrier(*args: Any) -> Call: | ||
| """ | ||
|
|
@@ -295,7 +298,10 @@ def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, | |
| Returns: | ||
| tir.Call: A handle to the shuffle operation | ||
| """ | ||
| return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset) | ||
| if _IS_HIP_AVAILABLE: | ||
| return tir.call_extern(value.dtype, "__shfl_xor", value, offset) | ||
| else: | ||
| return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset) | ||
|
|
||
|
|
||
| def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]): | ||
|
|
@@ -305,7 +311,10 @@ def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr | |
| value: Optional[int, PrimExpr] | ||
| The value to shuffle | ||
| """ | ||
| return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset) | ||
| if _IS_HIP_AVAILABLE: | ||
| return tir.call_extern(value.dtype, "__shfl_down", value, offset) | ||
| else: | ||
| return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset) | ||
|
|
||
|
|
||
| def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]): | ||
|
|
@@ -315,7 +324,10 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, | |
| value: Optional[int, PrimExpr] | ||
| The value to shuffle | ||
| """ | ||
| return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset) | ||
| if _IS_HIP_AVAILABLE: | ||
| return tir.call_extern(value.dtype, "__shfl_up", value, offset) | ||
| else: | ||
| return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset) | ||
|
|
||
|
|
||
| def sync_threads(): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The imported function
check_hip_availability()is called every time one of the shuffle functions (shfl_xor,shfl_down,shfl_up) is invoked. The implementation of this check may perform file system lookups, which can be inefficient if called repeatedly. To improve performance, the result of this check should be cached at the module level.For example: