
Commit 8361eb5

[Examples] Add the support of rocm arch detecting (#661)
zhangnju authored
Co-authored-by: zhangnju <ningzhan@SMC-SC-DI08-33.dh144.dcgpu>
1 parent: d764dca

File tree: 8 files changed, +50 −10 lines

benchmark/matmul/benchmark_matmul.py

Lines changed: 7 additions & 1 deletion
@@ -49,8 +49,14 @@ def get_configs(args, kwargs):
     if with_roller:
         from tilelang.carver.template import MatmulTemplate
         from tilelang.carver.arch import CUDA
+        from tilelang.carver.arch import CDNA
         from tilelang.carver.roller.rasterization import NoRasterization
-        arch = CUDA("cuda")
+        import torch
+
+        if torch.version.hip is not None:
+            arch = CDNA("hip")
+        else:
+            arch = CUDA("cuda")
         topk = 10
 
         carve_template = MatmulTemplate(
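
The same check recurs in every file of this commit: ROCm builds of PyTorch expose torch.version.hip as a version string, while CUDA and CPU-only builds leave it as None, so testing it cleanly separates the two runtimes. A minimal standalone sketch of the pattern (the describe_backend helper is ours, for illustration only, not part of tilelang):

import torch

def describe_backend() -> str:
    # On a ROCm build, torch.version.hip is a version string (e.g. "6.2");
    # on a CUDA or CPU-only build it is None.
    if torch.version.hip is not None:
        return f"ROCm/HIP build, version {torch.version.hip}"
    return "CUDA or CPU-only build (torch.version.hip is None)"

print(describe_backend())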

benchmark/matmul/benchmark_matmul_intrinsic.py

Lines changed: 7 additions & 1 deletion
@@ -183,8 +183,14 @@ def get_configs(args, kwargs):
     if with_roller:
         from tilelang.carver.template import MatmulTemplate
         from tilelang.carver.arch import CUDA
+        from tilelang.carver.arch import CDNA
         from tilelang.carver.roller.rasterization import NoRasterization
-        arch = CUDA("cuda")
+        import torch
+
+        if torch.version.hip is not None:
+            arch = CDNA("hip")
+        else:
+            arch = CUDA("cuda")
         topk = 10
 
         carve_template = MatmulTemplate(

benchmark/matmul_fp8/benchmark_matmul.py

Lines changed: 7 additions & 1 deletion
@@ -50,8 +50,14 @@ def get_configs(args, kwargs):
     if with_roller:
         from tilelang.carver.template import MatmulTemplate
         from tilelang.carver.arch import CUDA
+        from tilelang.carver.arch import CDNA
         from tilelang.carver.roller.rasterization import NoRasterization
-        arch = CUDA("cuda")
+        import torch
+
+        if torch.version.hip is not None:
+            arch = CDNA("hip")
+        else:
+            arch = CUDA("cuda")
         topk = 10
 
         carve_template = MatmulTemplate(

examples/analyze/example_conv_analyze.py

Lines changed: 6 additions & 2 deletions
@@ -1,8 +1,9 @@
 import tilelang.language as T
 from tilelang.tools import Analyzer
 from tilelang.carver.arch import CUDA
+from tilelang.carver.arch import CDNA
 from tilelang.layout import make_swizzled_layout
-
+import torch
 N = 64
 C = 256
 H = 512
@@ -94,7 +95,10 @@ def conv(
 
 def main():
     my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
-    cuda_device = CUDA("cuda")
+    if torch.version.hip is not None:
+        cuda_device = CDNA("hip")
+    else:
+        cuda_device = CUDA("cuda")
     result = Analyzer.analysis(my_func, cuda_device)
     print(result)
     print(f"Analyzed FLOPs: {result.total_flops}")

examples/analyze/example_gemm_analyze.py

Lines changed: 6 additions & 1 deletion
@@ -1,6 +1,8 @@
 import tilelang.language as T
 from tilelang.tools import Analyzer
 from tilelang.carver.arch import CUDA
+from tilelang.carver.arch import CDNA
+import torch
 
 M = N = K = 1024
 
@@ -47,7 +49,10 @@ def matmul(
 def main():
     my_func = kernel(128, 128, 32, 3, 128, True)
 
-    cuda_device = CUDA("cuda")
+    if torch.version.hip is not None:
+        cuda_device = CDNA("hip")
+    else:
+        cuda_device = CUDA("cuda")
     result = Analyzer.analysis(my_func, cuda_device)
 
     print(f"Analyzed FLOPs: {result.total_flops}")

examples/convolution/example_convolution_autotune.py

Lines changed: 5 additions & 1 deletion
@@ -6,6 +6,7 @@
 from tilelang.autotuner import AutoTuner
 from tilelang.carver.template import ConvTemplate
 from tilelang.carver.arch import CUDA
+from tilelang.carver.arch import CDNA
 from tilelang.carver.roller.rasterization import NoRasterization
 
 
@@ -31,7 +32,10 @@ def main(A, B):
 
 def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15):
     if with_roller:
-        arch = CUDA("cuda")
+        if torch.version.hip is not None:
+            arch = CDNA("hip")
+        else:
+            arch = CUDA("cuda")
         carve_template = ConvTemplate(
             N=N,
             C=C,

examples/gemm/example_gemm_autotune.py

Lines changed: 5 additions & 1 deletion
@@ -6,6 +6,7 @@
 from tilelang.autotuner import AutoTuner
 from tilelang.carver.template import MatmulTemplate
 from tilelang.carver.arch import CUDA
+from tilelang.carver.arch import CDNA
 from tilelang.carver.roller.rasterization import NoRasterization
 
 
@@ -15,7 +16,10 @@ def ref_program(A, B):
 
 def get_configs(M, N, K, with_roller=False, topk=20):
     if with_roller:
-        arch = CUDA("cuda")
+        if torch.version.hip is not None:
+            arch = CDNA("hip")
+        else:
+            arch = CUDA("cuda")
         carve_template = MatmulTemplate(
             M=M,
             N=N,

tilelang/carver/arch/__init__.py

Lines changed: 7 additions & 2 deletions
@@ -4,7 +4,7 @@
 from .cdna import CDNA
 from typing import Union
 from tvm.target import Target
-
+import torch
 
 def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
     if isinstance(target, str):
@@ -23,7 +23,12 @@ def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
 def auto_infer_current_arch() -> TileDevice:
     # TODO(lei): This is a temporary solution to infer the current architecture
     # Can be replaced by a more sophisticated method in the future
-    return get_arch("cuda")
+    if torch.version.hip is not None:
+        return get_arch("hip")
+    if torch.cuda.is_available():
+        return get_arch("cuda")
+    else:
+        return get_arch("llvm")
 
 
 from .cpu import is_cpu_arch  # noqa: F401
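
Note the ordering: the HIP test must come before torch.cuda.is_available(), because ROCm builds of PyTorch also report CUDA as available (HIP is exposed through the CUDA API surface). With this change the inference order is HIP, then CUDA, then an LLVM (CPU) fallback. A short usage sketch, assuming only the public path shown above:

from tilelang.carver.arch import auto_infer_current_arch

arch = auto_infer_current_arch()
# e.g. a CDNA instance on ROCm, a CUDA instance on NVIDIA builds,
# and the LLVM/CPU arch when no GPU runtime is present.
print(type(arch).__name__)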
