File tree (Expand file tree / Collapse file tree): 2 files changed, +2 -3 lines changed
Expand file tree / Collapse file tree: 2 files changed, +2 -3 lines changed
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,6 @@ def gemm_sp(
3131 C (Union[tir.Buffer, tir.Var]): Output matrix for results
3232 transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
3333 transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
34- transpose_E (bool, optional): Whether to transpose matrix E. Defaults to False.
3534 policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
3635 clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
3736 k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
Original file line number | Diff line number | Diff line change
@@ -77,8 +77,8 @@ def gen_stride(shape_ik, order):
7777 shape_i , shape_k = shape_ik [:3 ], shape_ik [3 :]
7878 stride_i , stride_k = stride_ik [:3 ], stride_ik [3 :]
7979 elif bits_map [mma_dtype ] == 8 :
80- shape_i , shape_k = [64 ], [BlockK ]
81- stride_i , stride_k = [BlockK ], [1 ]
80+ shape_i , shape_k = [64 ], [block_k // 8 ]
81+ stride_i , stride_k = [block_k // 8 ], [1 ]
8282 else :
8383 raise NotImplementedError (f"Unknown mma type { mma_dtype } " )
8484
You can’t perform that action at this time.
0 commit comments