diff --git a/tests/python/sparsetir/test_tir_sparse_correctness.py b/tests/python/sparsetir/test_tir_sparse_correctness.py index 38ce31afd73d..9457d17de0b9 100644 --- a/tests/python/sparsetir/test_tir_sparse_correctness.py +++ b/tests/python/sparsetir/test_tir_sparse_correctness.py @@ -22,11 +22,7 @@ @T.prim_func -def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None: - n = T.var("int32") - m = T.var("int32") - k = T.var("int32") - nnz = T.var("int32") +def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: I = T.dense_fixed(m) J = T.sparse_variable((n, m + 1, nnz), (indptr, indices), "int32") K = T.dense_fixed(k) @@ -40,26 +36,22 @@ def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.ha @T.prim_func -def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None: +def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, nnz: T.int32) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - n = T.var("int32") - m = T.var("int32") - k = T.var("int32") - nnz = T.var("int32") A_data = T.match_buffer(a, (nnz,), "float32") - B = T.match_buffer(b, (n, k), "float32") - C = T.match_buffer(c, (m, k), "float32") - A_indptr = T.match_buffer(indptr, (m + 1,), "int32") + B = T.match_buffer(b, (N * K,), "float32") + C = T.match_buffer(c, (M * K,), "float32") + A_indptr = T.match_buffer(indptr, (M + 1,), "int32") A_indices = T.match_buffer(indices, (nnz,), "int32") - for i, k in T.grid(m, k): + for i, k in T.grid(M, K): with T.block("spmm_outer"): vi, vk = T.axis.remap("SS", [i, k]) with T.init(): - C[vi, vk] = 0. + C[vi * K + vk] = 0. for j in T.serial(0, A_indptr[vi + 1] - A_indptr[vi]): with T.block("spmm_inner"): - vj = T.axis.R(n, j + A_indptr[vi]) - C[vi, vk] = C[vi, vk] + A_data[vj] * B[A_indices[vj], vk] + vj = T.axis.R(N, j + A_indptr[vi]) + C[vi * K + vk] = C[vi * K + vk] + A_data[vj] * B[A_indices[vj] * K + vk] def test_csrmm(): @@ -70,7 +62,12 @@ def test_csrmm(): y = np.zeros((4096, 256)).astype("float32") # specialize function - sch = tir.Schedule(csrmm_tir) + _, _, _, _, _, m, n, k, nnz = csrmm_tir.params + sch = tir.Schedule( + csrmm_tir.specialize( + {m: 4096, n: 4096, k: 256, nnz: A.nnz} + ) + ) blk_outer = sch.get_block("spmm_outer") i, k = sch.get_loops(blk_outer) sch.bind(i, "blockIdx.x") @@ -80,14 +77,15 @@ def test_csrmm(): A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=tvm.cuda(0)) A_indices = tvm.nd.array(A.indices.astype("int32"), device=tvm.cuda(0)) A_data = tvm.nd.array(A.data.astype("float32"), device=tvm.cuda(0)) - X_nd = tvm.nd.array(x, device=tvm.cuda(0)) - Y_nd = tvm.nd.array(y, device=tvm.cuda(0)) + X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0)) + Y_nd = tvm.nd.array(y.reshape(-1), device=tvm.cuda(0)) # build function f = tvm.build(sch.mod, target='cuda') f(A_data, X_nd, Y_nd, A_indptr, A_indices) - assert np.allclose(y_ground_truth, Y_nd.numpy()) + # assertion + assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy()) if __name__ == "__main__":