Merged
18 changes: 15 additions & 3 deletions tilelang/language/builtin.py
@@ -3,10 +3,13 @@
 from tilelang import tvm as tvm
 from tilelang.language import ptx_arrive_barrier, evaluate
 from tilelang.language.kernel import get_thread_bindings, get_block_extents
+from tilelang.utils.target import check_hip_availability
Contributor

high

The imported function check_hip_availability() is called every time one of the shuffle functions (shfl_xor, shfl_down, shfl_up) is invoked. The implementation of this check may perform file system lookups, which can be inefficient if called repeatedly. To improve performance, the result of this check should be cached at the module level.

For example:

# After imports
_IS_HIP_AVAILABLE = check_hip_availability()

# In shuffle functions
if _IS_HIP_AVAILABLE:
    # ...

Contributor

⚠️ Potential issue

Don't select HIP/CUDA via host availability; dispatch by TVM Target.

check_hip_availability() reflects the build host, not the compilation target. On machines with both ROCm and CUDA installed (or during cross-compilation), this can emit HIP intrinsics while targeting CUDA (or vice versa), leading to compile errors or miscompiled kernels.

Introduce a target-aware helper and use it in the shfl wrappers.

+def _is_hip_target() -> bool:
+    tgt = tvm.target.Target.current(allow_none=True)
+    if tgt is not None:
+        kind = getattr(tgt, "kind", None)
+        name = getattr(kind, "name", "")
+        return name in ("rocm", "hip", "amdgpu")
+    # Fallback for contexts where Target is not set yet.
+    return check_hip_availability()

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tilelang/language/builtin.py around line 6, the code imports and uses
check_hip_availability (which reflects the host) to choose HIP vs CUDA; replace
that with a TVM-target-aware helper. Add a small helper (e.g.,
is_target_hip(target) / resolve_accelerator_for_target(target)) that inspects
the provided TVM Target object or target string (checking target.kind.name or
target.arch/target.attrs or substring matching like "rocm"/"amdgcn" vs
"cuda"/"nvptx") and returns a boolean or enum indicating HIP vs CUDA; remove the
import of check_hip_availability and update the shfl wrapper functions to
accept/receive the TVM target (or derive it from context) and call the new
helper to decide which intrinsics to emit so selection is based on compilation
target rather than host availability.
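
The dispatch described above could look roughly like the sketch below. It is deliberately free of TVM imports so it runs on its own: `is_hip_kind` and `shuffle_extern_spec` are hypothetical names, and the set of kind names is copied from the suggestion above rather than verified against TVM. The helper returns the intrinsic name plus any leading arguments, mirroring the two call shapes in this diff (`__shfl_*` without a mask for HIP, `__shfl_*_sync` with a full-warp mask for CUDA).

```python
from typing import Tuple

# Kind names treated as HIP. Assumption: taken from the review suggestion,
# not checked against the target kinds TVM actually registers.
_HIP_KIND_NAMES = frozenset({"rocm", "hip", "amdgpu"})

# Full-warp participation mask used by the CUDA *_sync intrinsics in this diff.
_FULL_WARP_MASK = 0xFFFFFFFF

_SHUFFLE_OPS = ("xor", "down", "up")


def is_hip_kind(kind_name: str) -> bool:
    """Return True if a target kind name (e.g. "rocm") denotes a HIP target."""
    return kind_name.strip().lower() in _HIP_KIND_NAMES


def shuffle_extern_spec(op: str, kind_name: str) -> Tuple[str, Tuple[int, ...]]:
    """Pick the extern intrinsic name and leading arguments for a shuffle.

    op is one of "xor", "down", "up"; kind_name is the compilation
    target's kind name, not a property of the build host.
    """
    if op not in _SHUFFLE_OPS:
        raise ValueError(f"unknown shuffle op: {op!r}")
    if is_hip_kind(kind_name):
        # HIP shuffles take no participation mask.
        return f"__shfl_{op}", ()
    # CUDA *_sync shuffles take the mask as their first argument.
    return f"__shfl_{op}_sync", (_FULL_WARP_MASK,)
```

Inside a wrapper, the result would be passed on as `tir.call_extern(value.dtype, name, *leading_args, value, offset)`, with `kind_name` read from the current TVM target when one is set.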

 from tvm import tir
 from typing import Union, Any
 from tvm.tir import PrimExpr, Var, Call

+_IS_HIP_AVAILABLE = check_hip_availability()


 def create_list_of_mbarrier(*args: Any) -> Call:
     """
@@ -295,7 +298,10 @@ def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,
     Returns:
         tir.Call: A handle to the shuffle operation
     """
-    return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
+    if _IS_HIP_AVAILABLE:
+        return tir.call_extern(value.dtype, "__shfl_xor", value, offset)
+    else:
+        return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)


 def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
@@ -305,7 +311,10 @@ def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr
         value: Optional[int, PrimExpr]
             The value to shuffle
     """
-    return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
+    if _IS_HIP_AVAILABLE:
+        return tir.call_extern(value.dtype, "__shfl_down", value, offset)
+    else:
+        return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)


 def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
@@ -315,7 +324,10 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,
         value: Optional[int, PrimExpr]
             The value to shuffle
     """
-    return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
+    if _IS_HIP_AVAILABLE:
+        return tir.call_extern(value.dtype, "__shfl_up", value, offset)
+    else:
+        return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)


def sync_threads():