Skip to content

Commit

Permalink
[7] Review comments handled
Browse files Browse the repository at this point in the history
  • Loading branch information
ANSHUMAN TRIPATHY authored and ANSHUMAN TRIPATHY committed Nov 20, 2020
1 parent 6825222 commit a466aed
Showing 1 changed file with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1994,17 +1994,26 @@ def batch_matmul(x, y):


# pylint: disable=no-else-return,inconsistent-return-statements
def sparse_dense(data, weight, sparse_lhs=False):
def sparse_dense(dense_mat, sparse_mat, sparse_lhs=False):
r"""
Computes the matrix multiplication of `data` and `weight`, where `data` is
a dense matrix and `weight` is a sparse (either BSR or CSR) namedtuple with
Computes the matrix multiplication of `dense_mat` and `sparse_mat`, where `dense_mat` is
a dense matrix and `sparse_mat` is a sparse (either BSR or CSR) namedtuple with
fields `data`, `indices`, and `indptr`.
.. math::
\mbox{sparse_dense}(data, weight)[m, n] = \mbox{matmul}(x, \mbox{as_dense}(weight)^T)[m, n]
if sparse_lhs=True
where `as_dense` returns dense equivalent of the given sparse matrix.
\mbox{sparse_dense}(dense_mat, sparse_mat)[m, n]
= \mbox{matmul}(D, \mbox{as_dense}(S)^T)[m, n]
if sparse_lhs=False
\mbox{sparse_dense}(dense_mat, sparse_mat)[m, n]
= \mbox{matmul}(\mbox{as_dense}(S), (D)^T)[m, n]
where `as_dense` returns dense equivalent of the given S(sparse matrix)
while performing matmul with given D(dense matrix).
See
https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html
Expand All @@ -2014,11 +2023,11 @@ def sparse_dense(data, weight, sparse_lhs=False):
Parameters
----------
data : tvm.relay.Expr
The input data for the matrix multiplication
dense_mat : tvm.relay.Expr
The input dense matrix for the matrix multiplication
weight : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
The sparse weight matrix for the matrix multiplication.
sparse_mat : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]].
The input sparse matrix for the matrix multiplication.
sparse_lhs : bool, optional
Indicates whether lhs or rhs matrix is sparse. Default value is False.
Expand All @@ -2028,10 +2037,14 @@ def sparse_dense(data, weight, sparse_lhs=False):
result: tvm.relay.Expr
The computed result.
"""
if hasattr(weight, "indices"):
return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr, sparse_lhs)
if hasattr(sparse_mat, "indices"):
return _make.sparse_dense(
dense_mat, sparse_mat.data, sparse_mat.indices, sparse_mat.indptr, sparse_lhs
)
else:
return _make.sparse_dense(data, weight[0], weight[1], weight[2], sparse_lhs)
return _make.sparse_dense(
dense_mat, sparse_mat[0], sparse_mat[1], sparse_mat[2], sparse_lhs
)


def sparse_transpose(x):
Expand Down

0 comments on commit a466aed

Please sign in to comment.