Skip to content

Commit 34d1af6

Browse files
committed
[bugfix] fix wrong var
1 parent befe30e commit 34d1af6

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

tilelang/language/experimental/gemm_sp.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def gemm_sp(
3131
C (Union[tir.Buffer, tir.Var]): Output matrix for results
3232
transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
3333
transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
34-
transpose_E (bool, optional): Whether to transpose matrix E. Defaults to False.
3534
policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
3635
clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
3736
k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.

tilelang/layout/gemm_sp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def gen_stride(shape_ik, order):
7777
shape_i, shape_k = shape_ik[:3], shape_ik[3:]
7878
stride_i, stride_k = stride_ik[:3], stride_ik[3:]
7979
elif bits_map[mma_dtype] == 8:
80-
shape_i, shape_k = [64], [BlockK]
81-
stride_i, stride_k = [BlockK], [1]
80+
shape_i, shape_k = [64], [block_k // 8]
81+
stride_i, stride_k = [block_k // 8], [1]
8282
else:
8383
raise NotImplementedError(f"Unknown mma type {mma_dtype}")
8484

0 commit comments

Comments
 (0)