Skip to content

Commit d26fd59

Browse files
committed
[AMD] fix bugs in warp shuffle
1 parent 6e0c350 commit d26fd59

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tilelang/language/builtin.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from tilelang import tvm as tvm
44
from tilelang.language import ptx_arrive_barrier, evaluate
55
from tilelang.language.kernel import get_thread_bindings, get_block_extents
6+
from tilelang.utils.target import check_hip_availability
67
from tvm import tir
78
from typing import Union, Any
89
from tvm.tir import PrimExpr, Var, Call
@@ -295,7 +296,7 @@ def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,
295296
Returns:
296297
tir.Call: A handle to the shuffle operation
297298
"""
298-
return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
299+
return tir.call_extern(value.dtype, "__shfl_xor", value, offset) if check_hip_availability() else tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
299300

300301

301302
def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
@@ -305,7 +306,7 @@ def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr
305306
value: Optional[int, PrimExpr]
306307
The value to shuffle
307308
"""
308-
return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
309+
return tir.call_extern(value.dtype, "__shfl_down", value, offset) if check_hip_availability() else tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
309310

310311

311312
def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
@@ -315,7 +316,7 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,
315316
value: Optional[int, PrimExpr]
316317
The value to shuffle
317318
"""
318-
return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
319+
return tir.call_extern(value.dtype, "__shfl_up", value, offset) if check_hip_availability() else tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
319320

320321

321322
def sync_threads():

0 commit comments

Comments
 (0)