Skip to content

Commit eb0462a

Browse files
gau-nernstLeiWang1999
authored andcommitted
[Bugfix] Check CUDA target before checking for TMA tile-ai#482
1 parent 1b2826e commit eb0462a

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tilelang/engine/phase.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88

99
def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
1010
target: Optional[Target] = None) -> bool:
11+
# avoid circular import
12+
from tilelang.jit.adapter.utils import is_cuda_target
13+
1114
if pass_ctx is None:
1215
pass_ctx = tilelang.transform.get_pass_context()
13-
if not have_tma(target):
16+
if not is_cuda_target(target) or not have_tma(target):
1417
return False
1518
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
1619
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
@@ -19,7 +22,10 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
1922

2023

2124
def allow_fence_proxy(target: Optional[Target] = None) -> bool:
22-
return have_tma(target)
25+
# avoid circular import
26+
from tilelang.jit.adapter.utils import is_cuda_target
27+
28+
return is_cuda_target(target) and have_tma(target)
2329

2430

2531
def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:

0 commit comments

Comments
 (0)