Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change unittest to use flattened buffer. #18

Merged
merged 1 commit into from
Nov 13, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 19 additions & 21 deletions tests/python/sparsetir/test_tir_sparse_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@


@T.prim_func
def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None:
n = T.var("int32")
m = T.var("int32")
k = T.var("int32")
nnz = T.var("int32")
def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None:
I = T.dense_fixed(m)
J = T.sparse_variable((n, m + 1, nnz), (indptr, indices), "int32")
K = T.dense_fixed(k)
Expand All @@ -40,26 +36,22 @@ def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.ha


@T.prim_func
def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None:
def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, nnz: T.int32) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
n = T.var("int32")
m = T.var("int32")
k = T.var("int32")
nnz = T.var("int32")
A_data = T.match_buffer(a, (nnz,), "float32")
B = T.match_buffer(b, (n, k), "float32")
C = T.match_buffer(c, (m, k), "float32")
A_indptr = T.match_buffer(indptr, (m + 1,), "int32")
B = T.match_buffer(b, (N * K,), "float32")
C = T.match_buffer(c, (M * K,), "float32")
A_indptr = T.match_buffer(indptr, (M + 1,), "int32")
A_indices = T.match_buffer(indices, (nnz,), "int32")
for i, k in T.grid(m, k):
for i, k in T.grid(M, K):
with T.block("spmm_outer"):
vi, vk = T.axis.remap("SS", [i, k])
with T.init():
C[vi, vk] = 0.
C[vi * K + vk] = 0.
for j in T.serial(0, A_indptr[vi + 1] - A_indptr[vi]):
with T.block("spmm_inner"):
vj = T.axis.R(n, j + A_indptr[vi])
C[vi, vk] = C[vi, vk] + A_data[vj] * B[A_indices[vj], vk]
vj = T.axis.R(N, j + A_indptr[vi])
C[vi * K + vk] = C[vi * K + vk] + A_data[vj] * B[A_indices[vj] * K + vk]


def test_csrmm():
Expand All @@ -70,7 +62,12 @@ def test_csrmm():
y = np.zeros((4096, 256)).astype("float32")

# specialize function
sch = tir.Schedule(csrmm_tir)
_, _, _, _, _, m, n, k, nnz = csrmm_tir.params
sch = tir.Schedule(
csrmm_tir.specialize(
{m: 4096, n: 4096, k: 256, nnz: A.nnz}
)
)
blk_outer = sch.get_block("spmm_outer")
i, k = sch.get_loops(blk_outer)
sch.bind(i, "blockIdx.x")
Expand All @@ -80,14 +77,15 @@ def test_csrmm():
A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=tvm.cuda(0))
A_indices = tvm.nd.array(A.indices.astype("int32"), device=tvm.cuda(0))
A_data = tvm.nd.array(A.data.astype("float32"), device=tvm.cuda(0))
X_nd = tvm.nd.array(x, device=tvm.cuda(0))
Y_nd = tvm.nd.array(y, device=tvm.cuda(0))
X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0))
Y_nd = tvm.nd.array(y.reshape(-1), device=tvm.cuda(0))

# build function
f = tvm.build(sch.mod, target='cuda')
f(A_data, X_nd, Y_nd, A_indptr, A_indices)

assert np.allclose(y_ground_truth, Y_nd.numpy())
# assertion
assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy())


if __name__ == "__main__":
Expand Down