From a91303ba99c8495cb2b98a05df3387c19c2872a3 Mon Sep 17 00:00:00 2001 From: LiangHao Date: Wed, 15 Jan 2020 22:03:58 +0800 Subject: [PATCH] [Relay][Frontend][TF] fix _parse_param bug (#4711) --- python/tvm/relay/frontend/tensorflow.py | 2 +- .../frontend/tensorflow/test_debugging.py | 18 +++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index e7f4682e7eb2..408f88aa08d4 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -2391,7 +2391,7 @@ def _parse_param(self, key, value, name, shape): if np_array.dtype == np.dtype(object): # Object types are generally tensorflow DT_STRING (DecodeJpeg op). # Just leave it as placeholder. - if shape: + if shape and name in shape: var_shape = shape[name] else: var_shape = tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape) diff --git a/tests/python/frontend/tensorflow/test_debugging.py b/tests/python/frontend/tensorflow/test_debugging.py index c7da636e28aa..d992e9a0d453 100644 --- a/tests/python/frontend/tensorflow/test_debugging.py +++ b/tests/python/frontend/tensorflow/test_debugging.py @@ -20,19 +20,22 @@ from tvm import relay from tvm.relay.frontend.tensorflow import from_tensorflow -def run_relay(graph, *vars): - mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) +def run_relay(graph, shape_dict=None, *vars): + mod, params = from_tensorflow( + graph.as_graph_def(add_shapes=True), + shape=shape_dict) ex = relay.create_executor('debug', mod=mod) return ex.evaluate()(*vars) def test_assert_true(): g = tf.Graph() + shape = (1, 2) with g.as_default(): - x = tf.placeholder(tf.float32, shape=()) - assert_op = tf.Assert(tf.less_equal(x, x), ["it failed"]) + x = tf.placeholder(tf.float32, shape=shape, name="input") + assert_op = tf.Assert(tf.reduce_all(tf.less_equal(x, x)), ["it failed"]) with tf.Session() as sess: - x_value = np.random.rand() + x_value = np.random.rand(*shape) assert sess.run(assert_op, feed_dict={x: x_value}) is None # In TVM, tf.assert is converted to a no-op which is actually a 0, @@ -44,7 +47,7 @@ def test_assert_true(): # do that, it's happening in Relay, and that optimization shouldn't # affect the arity of the main function. We should have to pass in # x_value here. - np.testing.assert_allclose(0, run_relay(g).asnumpy()) + np.testing.assert_allclose(0, run_relay(g, {'input':shape}).asnumpy()) def test_assert_true_var_capture(): g = tf.Graph() @@ -65,7 +68,8 @@ def test_assert_true_var_capture(): # the graph as a boolean, which is not correct - as you can see above, # TF believes that the value of this graph is None. In addition, the # arity of the translated function should be 1, not 2. - np.testing.assert_allclose(True, run_relay(g, x_value, x_value).asnumpy()) + np.testing.assert_allclose(True, + run_relay(g, None, x_value, x_value).asnumpy()) def test_assert_false(): g = tf.Graph()