From a2cd47e59fbd525d87f55012f2f7788ab9222915 Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Wed, 10 Jun 2020 16:40:38 +0530 Subject: [PATCH] [TENSORFLOW]Sparse2Dense support --- python/tvm/relay/frontend/tensorflow.py | 13 +++++ .../frontend/tensorflow/test_forward.py | 58 +++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 1f5786f911cb..e61857cf9ad9 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1281,6 +1281,18 @@ def _impl(inputs, attr, params, mod): return _impl +def _sparse_to_dense(): + def _impl(inputs, attr, params, mod): + sparse_indices = inputs[0] + sparse_values = inputs[2] + default_value = inputs[3] + output_shape = attr['_output_shapes'][0] + + return _op.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value) + + return _impl + + def _bias_add(): def _impl(inputs, attr, params, mod): # Must expand for proper broadcasting in NCHW. @@ -2383,6 +2395,7 @@ def _impl(inputs, attr, params, mod): "Softplus": _softplus(), "SpaceToBatchND": _space_to_batch_nd(), "SpaceToDepth": _space_to_depth(), + 'SparseToDense': _sparse_to_dense(), "Split": _split(False), "SplitV": _split(True), "Sqrt": AttrCvt("sqrt"), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 713993ea40a4..1228634fa1a6 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -3936,6 +3936,64 @@ def test_forward_dilation(): _test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 1, 2, 1], "VALID") +####################################################################### +# Sparse To Dense +# --------------- +def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape): + with tf.Graph().as_default(): + indices = tf.placeholder(shape=sparse_indices.shape, dtype=str(sparse_indices.dtype), name="indices") + values = tf.placeholder(shape=sparse_values.shape, dtype=str(sparse_values.dtype), name="values") + oshape = tf.constant(output_shape, shape=output_shape.shape, dtype=str(output_shape.dtype)) + + if default_value == None: + output = tf.sparse_to_dense(indices, oshape, values) + compare_tf_with_tvm([sparse_indices, sparse_values], + ["indices:0", "values:0"], + output.name) + else: + dv = tf.placeholder(shape=(), dtype=str(default_value.dtype), name="default_value") + output = tf.sparse_to_dense(indices, oshape, values, dv) + compare_tf_with_tvm([sparse_indices, sparse_values, default_value], + ["indices:0", "values:0", "default_value:0"], + output.name) + +def test_forward_sparse_to_dense(): + # scalar + _test_sparse_to_dense(sparse_indices=np.int32(1), + sparse_values=np.int32(3), + default_value=np.int32(0), + output_shape=np.array([5]).astype("int32")) + + # vector + _test_sparse_to_dense(sparse_indices=np.array([0, 1, 4]).astype("int32"), + sparse_values=np.array([3, 3, 3]).astype("int32"), + default_value=np.int32(0), + output_shape=np.array([5]).astype("int32")) + + # vector nXd + _test_sparse_to_dense(sparse_indices=np.array([[0, 0], [1, 2]]).astype("int32"), + sparse_values=np.array([1, 2]).astype("int32"), + default_value=np.int32(0), + output_shape=np.array([3, 4]).astype("int32")) + + _test_sparse_to_dense(sparse_indices=np.array([[0, 0, 0], [1, 2, 3]]).astype("int32"), + sparse_values=np.array([1, 2]).astype("int32"), + default_value=np.int32(4), + output_shape=np.array([2, 3, 4]).astype("int32")) + + # floats + _test_sparse_to_dense(sparse_indices=np.array([0, 1, 4]).astype("int32"), + sparse_values=np.array([3.1, 3.1, 3.1]).astype("float32"), + default_value=np.float32(3.5), + output_shape=np.array([5]).astype("int32")) + + # default value not specified + _test_sparse_to_dense(sparse_indices=np.array([0, 1, 4]).astype("int32"), + sparse_values=np.array([3.1, 3.1, 3.1]).astype("float32"), + default_value=None, + output_shape=np.array([5]).astype("int32")) + + ####################################################################### # infinity ops # ------------