Skip to content

Commit 2bd2d69

Browse files
authored
[Carver][Bugfix] Correct score function for warp tile selection in tensorcore policy (#724)
* [Carver][Bugfix] Correct score function for warp tile selection in tensorcore policy
* [Typo] Correct architecture selection for CUDA and CDNA
1 parent 8e1b88f commit 2bd2d69

File tree

6 files changed

+7
-8
lines changed

6 files changed

+7
-8
lines changed

benchmark/matmul/benchmark_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ def get_configs(args, kwargs):
5353
from tilelang.carver.roller.rasterization import NoRasterization
5454
import torch
5555

56-
arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
56+
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
5757
topk = 10
5858

5959
carve_template = MatmulTemplate(

benchmark/matmul/benchmark_matmul_intrinsic.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -187,7 +187,7 @@ def get_configs(args, kwargs):
187187
from tilelang.carver.roller.rasterization import NoRasterization
188188
import torch
189189

190-
arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
190+
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
191191
topk = 10
192192

193193
carve_template = MatmulTemplate(

examples/analyze/example_conv_analyze.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,7 @@ def conv(
9696

9797
def main():
9898
my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
99-
cuda_device = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
99+
cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
100100
result = Analyzer.analysis(my_func, cuda_device)
101101
print(result)
102102
print(f"Analyzed FLOPs: {result.total_flops}")

examples/analyze/example_gemm_analyze.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ def matmul(
4949
def main():
5050
my_func = kernel(128, 128, 32, 3, 128, True)
5151

52-
cuda_device = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
52+
cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
5353
result = Analyzer.analysis(my_func, cuda_device)
5454

5555
print(f"Analyzed FLOPs: {result.total_flops}")

examples/gemm/example_gemm_autotune.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ def ref_program(A, B):
1616

1717
def get_configs(M, N, K, with_roller=False, topk=20):
1818
if with_roller:
19-
arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
19+
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
2020
carve_template = MatmulTemplate(
2121
M=M,
2222
N=N,

tilelang/carver/roller/policy/tensorcore.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -281,10 +281,9 @@ def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int):
281281

282282
factors = factorize(np.prod(space) // warps)
283283

284-
def _score(node, thread): # small is better
284+
def _score(node, warp_tile): # small is better
285285
score = 0
286-
block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)]
287-
shape = node.propagate_inputs_on_reduction(block_tile)
286+
shape = node.propagate_inputs_on_reduction(warp_tile)
288287
input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
289288
for i, _ in enumerate(input_buffers):
290289
score += np.prod(shape[i]) / self.arch.bandwidth[1]

0 commit comments

Comments (0)