diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 10f23a49b5de..47aca3816e6f 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -646,6 +646,9 @@ def _transform_mask(stride_dim, ellipsis_mask): pass else: final_output.append(out_shape[gather_index]) + # Prevent 0-dim tensors which are not accepted by nnvm + if not final_output: + final_output.append(1) return _sym.reshape(out, shape=tuple(final_output)) return _impl @@ -1187,8 +1190,8 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): raise NotImplementedError( \ "Please freeze the graph with add_shapes=True") self._outputs_are_0d[node.name] = [ \ - not shape if isinstance(shape, list) else False \ - for shape in self._output_shapes[node.name]] + not tshape if isinstance(tshape, list) else False \ + for tshape in self._output_shapes[node.name]] if node.op == "Placeholder": self._nodes[node.name] = _sym.Variable(name=node.name, diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index ed3d0272b4fc..5b8f11695790 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -463,6 +463,7 @@ def test_forward_stridedslice(): _test_stridedslice((3, 4, 5, 4, 5, 6), [1, 2, 0, -3], [4, 5, 3, 3], [2, 2, 1, 1], 'float32', shrink_axis_mask=8, new_axis_mask=1, ellipsis_mask=2, begin_mask=5, end_mask=8) + _test_stridedslice((1), [0], [1], [1], 'float32', shrink_axis_mask=1) #######################################################################