Skip to content

Commit 961df37

Browse files
authored
[Bugfix] Check CUDA target before checking for TMA #482
1 parent d607ee2 commit 961df37

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tilelang/engine/phase.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010

1111
def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
1212
target: Optional[Target] = None) -> bool:
13+
# avoid circular import
14+
from tilelang.jit.adapter.utils import is_cuda_target
15+
1316
if pass_ctx is None:
1417
pass_ctx = tilelang.transform.get_pass_context()
15-
if not have_tma(target):
18+
if not is_cuda_target(target) or not have_tma(target):
1619
return False
1720
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
1821
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
@@ -21,7 +24,10 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
2124

2225

2326
def allow_fence_proxy(target: Optional[Target] = None) -> bool:
24-
return have_tma(target)
27+
# avoid circular import
28+
from tilelang.jit.adapter.utils import is_cuda_target
29+
30+
return is_cuda_target(target) and have_tma(target)
2531

2632

2733
def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:

0 commit comments

Comments
 (0)