Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TENSORFLOW]Sparse2Dense support #5767

Merged
merged 2 commits into from
Nov 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,18 @@ def _impl(inputs, attr, params, mod):
return _impl


def _sparse_to_dense():
def _impl(inputs, attr, params, mod):
sparse_indices = inputs[0]
sparse_values = inputs[2]
default_value = inputs[3]
output_shape = attr["_output_shapes"][0]

return _op.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value)

return _impl


def _bias_add():
def _impl(inputs, attr, params, mod):
# Must expand for proper broadcasting in NCHW.
Expand Down Expand Up @@ -2383,6 +2395,7 @@ def _impl(inputs, attr, params, mod):
"Softplus": _softplus(),
"SpaceToBatchND": _space_to_batch_nd(),
"SpaceToDepth": _space_to_depth(),
"SparseToDense": _sparse_to_dense(),
"Split": _split(False),
"SplitV": _split(True),
"Sqrt": AttrCvt("sqrt"),
Expand Down
77 changes: 77 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3936,6 +3936,83 @@ def test_forward_dilation():
_test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 1, 2, 1], "VALID")


#######################################################################
# Sparse To Dense
# ---------------
def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape):
with tf.Graph().as_default():
indices = tf.placeholder(
shape=sparse_indices.shape, dtype=str(sparse_indices.dtype), name="indices"
)
values = tf.placeholder(
shape=sparse_values.shape, dtype=str(sparse_values.dtype), name="values"
)
oshape = tf.constant(output_shape, shape=output_shape.shape, dtype=str(output_shape.dtype))

if default_value == None:
output = tf.sparse_to_dense(indices, oshape, values)
compare_tf_with_tvm(
[sparse_indices, sparse_values], ["indices:0", "values:0"], output.name
)
else:
dv = tf.placeholder(shape=(), dtype=str(default_value.dtype), name="default_value")
output = tf.sparse_to_dense(indices, oshape, values, dv)
compare_tf_with_tvm(
[sparse_indices, sparse_values, default_value],
["indices:0", "values:0", "default_value:0"],
output.name,
)


def test_forward_sparse_to_dense():
# scalar
_test_sparse_to_dense(
sparse_indices=np.int32(1),
sparse_values=np.int32(3),
default_value=np.int32(0),
output_shape=np.array([5]).astype("int32"),
)

# vector
_test_sparse_to_dense(
sparse_indices=np.array([0, 1, 4]).astype("int32"),
sparse_values=np.array([3, 3, 3]).astype("int32"),
default_value=np.int32(0),
output_shape=np.array([5]).astype("int32"),
)

# vector nXd
_test_sparse_to_dense(
sparse_indices=np.array([[0, 0], [1, 2]]).astype("int32"),
sparse_values=np.array([1, 2]).astype("int32"),
default_value=np.int32(0),
output_shape=np.array([3, 4]).astype("int32"),
)

_test_sparse_to_dense(
sparse_indices=np.array([[0, 0, 0], [1, 2, 3]]).astype("int32"),
sparse_values=np.array([1, 2]).astype("int32"),
default_value=np.int32(4),
output_shape=np.array([2, 3, 4]).astype("int32"),
)

# floats
_test_sparse_to_dense(
sparse_indices=np.array([0, 1, 4]).astype("int32"),
sparse_values=np.array([3.1, 3.1, 3.1]).astype("float32"),
default_value=np.float32(3.5),
output_shape=np.array([5]).astype("int32"),
)

# default value not specified
_test_sparse_to_dense(
sparse_indices=np.array([0, 1, 4]).astype("int32"),
sparse_values=np.array([3.1, 3.1, 3.1]).astype("float32"),
default_value=None,
output_shape=np.array([5]).astype("int32"),
)


#######################################################################
# infinity ops
# ------------
Expand Down