TF frontend support added
ANSHUMAN TRIPATHY authored and ANSHUMAN TRIPATHY committed Feb 17, 2021
1 parent 83b5d2e commit d7790fa
Showing 6 changed files with 113 additions and 11 deletions.
33 changes: 33 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -1020,6 +1020,38 @@ def _impl(inputs, attr, params, mod):
    return _impl


def _sparse_tensor_dense_add():
    # Sparse utility from scipy
    from scipy.sparse import csr_matrix

    def _impl(inputs, attr, params, mod):
        assert len(inputs) == 4, "There should be 4 input tensors"

        indices_tensor = _infer_value(inputs[0], params, mod).asnumpy()
        values_tensor = _infer_value(inputs[1], params, mod).asnumpy()
        dense_shape_tensor = _infer_value(inputs[2], params, mod).asnumpy()

        data = inputs[3]

        rows = [x[0] for x in indices_tensor]
        cols = [x[1] for x in indices_tensor]

        # Create scipy sparse Tensor(CSR)
        weight_sp = csr_matrix(
            (values_tensor, (rows, cols)), shape=tuple(dense_shape_tensor.tolist())
        )

        weight_data = _expr.const(weight_sp.data, weight_sp.data.dtype)
        weight_indptrs = _expr.const(weight_sp.indptr, weight_sp.indptr.dtype)
        weight_indices = _expr.const(weight_sp.indices, weight_sp.indices.dtype)

        ret = _op.nn.sparse_add(data, [weight_data, weight_indices, weight_indptrs])

        return ret

    return _impl


def _identity():
    def _impl(inputs, attr, params, mod):
        return inputs[0]
@@ -2478,6 +2510,7 @@ def _impl(inputs, attr, params, mod):
    "SparseToDense": _sparse_to_dense(),
    "SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
    "SparseFillEmptyRows": _sparse_fill_empty_rows(),
    "SparseTensorDenseAdd": _sparse_tensor_dense_add(),
    "Split": _split(False),
    "SplitV": _split(True),
    "Sqrt": AttrCvt("sqrt"),
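
For reference, here is a minimal, self-contained sketch of the COO-to-CSR conversion that the new converter performs before calling nn.sparse_add, using only NumPy and SciPy. The indices, values, and dense shape below are illustrative placeholders, not values taken from the commit.

import numpy as np
from scipy.sparse import csr_matrix

# Illustrative COO inputs, in the layout a TF SparseTensor provides.
indices = np.array([[0, 0], [1, 2]])             # [nnz, 2] row/col pairs
values = np.array([4.0, 8.0], dtype="float32")   # one value per index pair
dense_shape = np.array([3, 4])                   # target dense shape

rows = [x[0] for x in indices]
cols = [x[1] for x in indices]
weight_sp = csr_matrix((values, (rows, cols)), shape=tuple(dense_shape.tolist()))

# These three arrays are what the converter wraps in Relay constants and
# passes to nn.sparse_add as [data, indices, indptr].
print(weight_sp.data)     # non-zero values
print(weight_sp.indices)  # column index of each non-zero
print(weight_sp.indptr)   # row pointers, length rows + 1
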
5 changes: 5 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -142,6 +142,11 @@ def alter_op_layout_sparse_dense(attrs, inputs, tinfos, out_type):
    return topi.nn.sparse_dense_alter_layout(attrs, inputs, tinfos, out_type)


# sparse_add
reg.register_strategy("nn.sparse_add", strategy.sparse_add_strategy)
reg.register_pattern("nn.sparse_add", reg.OpPattern.ELEMWISE)


@reg.register_compute("nn.internal.sparse_dense_padded")
def compute_sparse_dense_padded(attrs, inputs, out_type):
"""Compute definition of sparse_dense_padded"""
Expand Down
33 changes: 33 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -2148,6 +2148,39 @@ def sparse_transpose(x):
    return expr.TupleWrapper(_make.sparse_transpose(x[0], x[1], x[2]), 3)


# pylint: disable=no-else-return,inconsistent-return-statements
def sparse_add(dense_mat, sparse_mat):
    r"""
    Computes the matrix addition of `dense_mat` and `sparse_mat`, where `dense_mat` is
    a dense matrix and `sparse_mat` is a sparse (CSR) namedtuple with
    fields `data`, `indices`, and `indptr`.

    .. math::

        \mbox{sparse_add}(dense_mat, sparse_mat)[m, n] = \mbox{add}(\mbox{as_dense}(S), D)[m, n]

    where `as_dense` returns the dense equivalent of the sparse matrix S, which is then
    added elementwise to the dense matrix D.

    Parameters
    ----------
    dense_mat : tvm.relay.Expr
        The input dense matrix for the matrix addition

    sparse_mat : Union[namedtuple, Tuple[ndarray, ndarray, ndarray]]
        The input sparse matrix (CSR) for the matrix addition.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    if hasattr(sparse_mat, "indices"):
        return _make.sparse_add(dense_mat, sparse_mat.data, sparse_mat.indices, sparse_mat.indptr)
    else:
        return _make.sparse_add(dense_mat, sparse_mat[0], sparse_mat[1], sparse_mat[2])


def contrib_conv2d_winograd_without_weight_transform(
    data,
    weight,
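
A hedged usage sketch of the Python API registered above: it builds a small Relay graph that adds a CSR constant to a dense variable. The shapes and the scipy.sparse.random weight are assumptions for illustration; sparse_add equally accepts a namedtuple with data/indices/indptr fields instead of the list shown here.

import scipy.sparse as sp
from tvm import relay

dense = relay.var("dense", shape=(3, 4), dtype="float32")
weight = sp.random(3, 4, density=0.25, format="csr", dtype="float32")

# Pass the three CSR arrays as Relay constants: [data, indices, indptr].
out = relay.nn.sparse_add(
    dense,
    [
        relay.const(weight.data, weight.data.dtype),
        relay.const(weight.indices, weight.indices.dtype),
        relay.const(weight.indptr, weight.indptr.dtype),
    ],
)
func = relay.Function([dense], out)
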
4 changes: 2 additions & 2 deletions python/tvm/topi/nn/sparse.py
@@ -365,7 +365,7 @@ def sparse_add(dense_data, sparse_data, sparse_indices, sparse_indptr):
    Parameters
    ----------
    dense_data : tvm.te.Tensor
-        2-D with shape [M, K], float32
+        2-D with shape [M, N], float32
    sparse_data : tvm.te.Tensor
        1-D with shape [nnz] (CSR) or
@@ -374,7 +374,7 @@ def sparse_add(dense_data, sparse_data, sparse_indices, sparse_indptr):
        1-D with shape [nnz] (CSR) or
    sparse_indptr : tvm.te.Tensor
-        1-D with shape [N + 1] (CSR) or
+        1-D with shape [M + 1] (CSR) or

    Returns
    -------
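
A quick NumPy/SciPy check of the corrected shapes above: the dense operand is [M, N] and a CSR indptr has M + 1 entries (one per row, plus one). The concrete values are illustrative only.

import numpy as np
from scipy.sparse import csr_matrix

M, N = 3, 4
dense = np.arange(M * N, dtype="float32").reshape(M, N)
sparse = csr_matrix(
    (np.array([4.0, 8.0], dtype="float32"), ([0, 1], [0, 2])), shape=(M, N)
)

assert sparse.indptr.shape[0] == M + 1  # indptr is [M + 1], not [N + 1]
expected = dense + sparse.toarray()     # reference result for sparse_add
print(expected)
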
10 changes: 4 additions & 6 deletions src/relay/op/nn/sparse.cc
@@ -205,7 +205,6 @@ bool SparseAddRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
  ICHECK_EQ(sparse_data->shape.size(), 1);
  const auto* sparse_indices = types[2].as<TensorTypeNode>();
  ICHECK_EQ(sparse_indices->shape.size(), 1);
-  const auto* sparse_indptr = types[3].as<TensorTypeNode>();

  reporter->Assign(types[4], TensorType(dense_data->shape, dense_data->dtype));
  return true;
@@ -219,15 +218,14 @@ Expr MakeSparseAdd(Expr dense_data, Expr sparse_data, Expr sparse_indices, Expr
TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_add").set_body_typed(MakeSparseAdd);

RELAY_REGISTER_OP("nn.sparse_add")
-    .describe(R"code(Add a dense matrix X with sparse matrix Y. Only support square sparse matrix
+    .describe(R"code(Add a dense matrix X with sparse matrix Y.
-- **dense**: `(N, N)`
-- **sparse**: `(N, N)`
+- **dense**: `(M, N)`
+- **sparse**: `(M, N)`
-- **out**: `(N, N)`.
+- **out**: `(M, N)`.
)code" TVM_ADD_FILELINE)
-    .set_attrs_type<SparseTransposeAttrs>()
    .set_num_inputs(4)
    .add_argument("dense_data", "2D Tensor", "Dense data matrix.")
    .add_argument("sparse_data", "1D Tensor", "Sparse data matrix.")
39 changes: 36 additions & 3 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -1899,6 +1899,39 @@ def test_forward_sparse_fill_empty_rows(
    sparse_indices_np, sparse_values_np, dense_shape_np, default_value_int, use_dyn
):
    """ sparse_fill_empty_rows op test"""
    _test_sparse_fill_empty_rows(
        sparse_indices_np, sparse_values_np, dense_shape_np, default_value_int, use_dyn
    )


#######################################################################
# tensorflow.sparse.add
# ----------------------------------


def _test_sparse_add(indices, values, A_shape, B_shape, dtype, flip=False):
    """ One iteration of tf.sparse.add """

    # TODO(ANSHUMAN87): support cuda
    # TODO(ANSHUMAN87): support flip case
    # TODO(ANSHUMAN87): support both sparse input case

    with tf.Graph().as_default():
        A_sp = tf.sparse.SparseTensor(indices=indices, values=values, dense_shape=A_shape)
        B = tf.placeholder(shape=B_shape, dtype=dtype, name="B")

        if flip:
            result = tf.sparse.add(B, A_sp, threshold=0)
        else:
            result = tf.sparse.add(A_sp, B, threshold=0)

        B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)

        compare_tf_with_tvm([B_np], [B.name], result.name, no_gpu=True)


def test_sparse_add():
    """ sparse.add op test"""
    ###################################################################
    #
    # In order to create a SparseTensor, it requires 3 input as below:
@@ -1910,9 +1943,9 @@ def test_forward_sparse_fill_empty_rows(
    #    [0, 0, 0, 0]]
    #
    # ------------------------------------------------------------------
    _test_sparse_fill_empty_rows(
        sparse_indices_np, sparse_values_np, dense_shape_np, default_value_int, use_dyn
    )

    # TODO(ANSHUMAN87): add more test case
    _test_sparse_add([[0, 0], [1, 2]], [4.0, 8.0], [3, 4], [3, 4], "float32")
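
For intuition, the expected output of the test case above can be reproduced with plain NumPy (a sketch, not part of the test suite): densify the SparseTensor built from indices [[0, 0], [1, 2]] and values [4.0, 8.0], then add it to B.

import numpy as np

A = np.zeros((3, 4), dtype="float32")
A[0, 0], A[1, 2] = 4.0, 8.0  # densified SparseTensor
B = np.random.uniform(high=5.0, size=(3, 4)).astype("float32")

expected = A + B  # what both tf.sparse.add and the TVM frontend should return
print(expected)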


#######################################################################
