Skip to content

Commit

Permalink
upd (apache#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 committed Nov 22, 2021
1 parent e432404 commit e6e3232
Showing 1 changed file with 94 additions and 0 deletions.
94 changes: 94 additions & 0 deletions tests/python/sparsetir/test_tir_sparse_correctness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.tir as tir
import scipy.sparse as sp
import numpy as np
from tvm.script import tir as T


@T.prim_func
def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None:
n = T.var("int32")
m = T.var("int32")
k = T.var("int32")
nnz = T.var("int32")
I = T.dense_fixed(m)
J = T.sparse_variable((n, m + 1, nnz), (indptr, indices), "int32")
K = T.dense_fixed(k)
A = T.match_sparse_buffer(a, (I, J), nnz, "float32")
B = T.match_sparse_buffer(b, (T.to_dense(J), K), n * k, "float32")
C = T.match_sparse_buffer(c, (I, K), m * k, "float32")
with T.iter([T.cord(I), T.cord(J), T.cord(K)], "SRS", "csrmm") as [vi, vj, vk]:
with T.init():
C[vi, vk] = 0.0
C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk]


@T.prim_func
def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
n = T.var("int32")
m = T.var("int32")
k = T.var("int32")
nnz = T.var("int32")
A_data = T.match_buffer(a, (nnz,), "float32")
B = T.match_buffer(b, (n, k), "float32")
C = T.match_buffer(c, (m, k), "float32")
A_indptr = T.match_buffer(indptr, (m + 1,), "int32")
A_indices = T.match_buffer(indices, (nnz,), "int32")
for i, k in T.grid(m, k):
with T.block("spmm_outer"):
vi, vk = T.axis.remap("SS", [i, k])
with T.init():
C[vi, vk] = 0.
for j in T.serial(0, A_indptr[vi + 1] - A_indptr[vi]):
with T.block("spmm_inner"):
vj = T.axis.R(n, j + A_indptr[vi])
C[vi, vk] = C[vi, vk] + A_data[vj] * B[A_indices[vj], vk]


def test_csrmm():
# generate random input
A = sp.random(4096, 4096, dtype="float32", density=0.0125, format='csr')
x = np.random.rand(4096, 256).astype("float32")
y_ground_truth = A * x
y = np.zeros((4096, 256)).astype("float32")

# specialize function
sch = tir.Schedule(csrmm_tir)
blk_outer = sch.get_block("spmm_outer")
i, k = sch.get_loops(blk_outer)
sch.bind(i, "blockIdx.x")
sch.bind(k, "threadIdx.x")

# convert numpy tensor to tvm ndarray
A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=tvm.cuda(0))
A_indices = tvm.nd.array(A.indices.astype("int32"), device=tvm.cuda(0))
A_data = tvm.nd.array(A.data.astype("float32"), device=tvm.cuda(0))
X_nd = tvm.nd.array(x, device=tvm.cuda(0))
Y_nd = tvm.nd.array(y, device=tvm.cuda(0))

# build function
f = tvm.build(sch.mod, target='cuda')
f(A_data, X_nd, Y_nd, A_indptr, A_indices)

assert np.allclose(y_ground_truth, Y_nd.numpy())


if __name__ == "__main__":
test_csrmm()

0 comments on commit e6e3232

Please sign in to comment.