Skip to content

Commit 94610e8

Browse files
committed
fix some remaining import issue
1 parent 9acab61 commit 94610e8

File tree

10 files changed

+47
-40
lines changed

10 files changed

+47
-40
lines changed

tilelang/cache/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""The cache utils with class and database persistence - Init file"""
22
from __future__ import annotations
33

4-
from typing import List, Union, Literal, Optional
4+
from typing import Literal
55
from tvm.target import Target
66
from tvm.tir import PrimFunc
77
from tilelang.jit import JITKernel

tilelang/carver/__init__.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,10 @@
11
"""Base infra"""
22

3-
from .analysis import (
4-
BlockInfo, # noqa: F401
5-
IterInfo, # noqa: F401
6-
collect_block_iter_vars_used_in_access_region, # noqa: F401
7-
collect_vars_used_in_prim_expr, # noqa: F401
8-
detect_dominant_read, # noqa: F401
9-
is_broadcast_epilogue, # noqa: F401
10-
normalize_prim_func, # noqa: F401
11-
) # noqa: F401
12-
from .common_schedules import (
13-
get_block,
14-
get_output_blocks,
15-
try_inline,
16-
try_inline_contiguous_spatial,
17-
) # noqa: F401
3+
from .analysis import *
4+
from .common_schedules import *
185
from .roller import *
19-
from .arch import CUDA, CDNA # noqa: F401
20-
from .template import (
21-
MatmulTemplate,
22-
GEMVTemplate,
23-
ElementwiseTemplate,
24-
GeneralReductionTemplate,
25-
FlashAttentionTemplate,
26-
) # noqa: F401
6+
from .arch import (
7+
CUDA as CUDA,
8+
CDNA as CDNA,
9+
)
10+
from .template import *

tilelang/carver/analysis.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,14 @@ def get_coalesced_veclen(block_stmt: tir.Block, target_bits: int = 128) -> int:
297297
for buffer in buffers:
298298
max_dtype_bits = max(max_dtype_bits, DataType(buffer.dtype).bits)
299299
return target_bits // max_dtype_bits
300+
301+
302+
__all__ = [
303+
BlockInfo,
304+
IterInfo,
305+
collect_block_iter_vars_used_in_access_region,
306+
collect_vars_used_in_prim_expr,
307+
detect_dominant_read,
308+
is_broadcast_epilogue,
309+
normalize_prim_func,
310+
]

tilelang/carver/arch/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from .cpu import *
55
from .cdna import *
66
from .metal import *
7-
from typing import Union
87
from tvm.target import Target
98
import torch
109

tilelang/carver/common_schedules.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,11 @@ def try_inline_contiguous_spatial(
163163
if spatial_blocks:
164164
results.extend(try_inline(sch, spatial_blocks))
165165
return results
166+
167+
168+
__all__ = [
169+
get_block,
170+
get_output_blocks,
171+
try_inline,
172+
try_inline_contiguous_spatial,
173+
]
Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
"""Template for the TileLang Carver."""
22

3-
from .base import BaseTemplate # noqa: F401
4-
from .matmul import MatmulTemplate # noqa: F401
5-
from .gemv import GEMVTemplate # noqa: F401
6-
from .elementwise import ElementwiseTemplate # noqa: F401
7-
from .general_reduce import GeneralReductionTemplate # noqa: F401
8-
from .flashattention import FlashAttentionTemplate # noqa: F401
9-
from .conv import ConvTemplate # noqa: F401
3+
from .base import BaseTemplate
4+
from .matmul import MatmulTemplate
5+
from .gemv import GEMVTemplate
6+
from .elementwise import ElementwiseTemplate
7+
from .general_reduce import GeneralReductionTemplate
8+
from .flashattention import FlashAttentionTemplate
9+
from .conv import ConvTemplate
10+
11+
__all__ = [
12+
'BaseTemplate',
13+
'MatmulTemplate',
14+
'GEMVTemplate',
15+
'ElementwiseTemplate',
16+
'GeneralReductionTemplate',
17+
'FlashAttentionTemplate',
18+
'ConvTemplate',
19+
]

tilelang/jit/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,9 @@
77

88
from typing import (
99
Any,
10-
List,
11-
Union,
1210
Callable,
13-
Tuple,
1411
overload,
1512
Literal,
16-
Dict, # For type hinting dicts
17-
Optional,
1813
)
1914
from tilelang import tvm as tvm
2015
from tilelang.jit.adapter.utils import is_metal_target

tilelang/language/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""The language interface for tl programs."""
22
from __future__ import annotations
33

4-
from typing import Optional, Callable, Dict
4+
from typing import Callable
55

66
# from .parser import *
77
# now is fully compatible with the upstream

tilelang/primitives/gemm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import annotations
2-
from typing import Optional
2+
33
from tvm import tir
44
from tilelang.utils import is_local, is_fragment, is_shared
55
from tilelang.primitives.gemm.base import GemmWarpPolicy

tilelang/profiler/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""The profiler and convert to torch utils"""
22
from __future__ import annotations
33

4-
from typing import List, Optional, Callable, Any, Literal
4+
from typing import Callable, Any, Literal
55
from functools import partial
66
import torch
77
from contextlib import suppress

0 commit comments

Comments
 (0)