File tree Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change 1010
1111def allow_tma_and_warp_specialized (pass_ctx : Optional [PassContext ] = None ,
1212 target : Optional [Target ] = None ) -> bool :
13+ # avoid circular import
14+ from tilelang .jit .adapter .utils import is_cuda_target
15+
1316 if pass_ctx is None :
1417 pass_ctx = tilelang .transform .get_pass_context ()
15- if not have_tma (target ):
18+ if not is_cuda_target ( target ) or not have_tma (target ):
1619 return False
1720 disable_tma_lower = pass_ctx .config .get ("tl.disable_tma_lower" , False )
1821 disable_tma_lower = pass_ctx .config .get ("tl.disable_tma_lower" , False )
@@ -21,7 +24,10 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
2124
2225
2326def allow_fence_proxy (target : Optional [Target ] = None ) -> bool :
24- return have_tma (target )
27+ # avoid circular import
28+ from tilelang .jit .adapter .utils import is_cuda_target
29+
30+ return is_cuda_target (target ) and have_tma (target )
2531
2632
2733def allow_vectorize (pass_ctx : Optional [PassContext ] = None ) -> bool :
You can’t perform that action at this time.
0 commit comments