Skip to content

Commit

Permalink
[Relay][TensorFlow] Remove 'input_0d_mismatch' special handling (#3087)
Browse files Browse the repository at this point in the history
* [Relay][TensorFlow] Remove 'input_0d_mismatch' special handling

* Add more tests.

* Cover the case that strided_slice outputs a scalar
  • Loading branch information
lixiaoquan authored and yzhliu committed Apr 26, 2019
1 parent 6a956fb commit 036294c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 29 deletions.
35 changes: 6 additions & 29 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def __call__(self, inputs, attrs, *args):
self._ignores.append('_node_name')
self._ignores.append('is_training')
self._ignores.append('_target_layout')
self._ignores.append('_input_0d_mismatch')

# apply custom check
if self._custom_check:
Expand Down Expand Up @@ -458,9 +457,9 @@ def _impl(inputs, attr, params):
def _expand_dims():
def _impl(inputs, attr, params):
dim_input = inputs.pop(1)
axis = params[dim_input.name_hint]
params.pop(dim_input.name_hint)
return _expand_dims_0d_aware(inputs[0], attr, axis=axis.asnumpy()[0])
axis = params.pop(_get_name_hint(dim_input)).asnumpy()[0]
return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
extras={'axis': int(axis), 'num_newaxis': 1})(inputs, attr)
return _impl

def _resize_bilinear():
Expand Down Expand Up @@ -528,7 +527,7 @@ def _impl(inputs, attr, params):
def _pack():
def _impl(inputs, attr, params):
axis = int(attr["axis"])
inputs_reshaped = [_expand_dims_0d_aware(i, attr, axis=axis, num_newaxis=1) for i in inputs]
inputs_reshaped = [_op.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
return _op.concatenate(inputs_reshaped, axis)
return _impl

Expand Down Expand Up @@ -820,9 +819,9 @@ def _transform_mask(stride_dim, ellipsis_mask):
pass
else:
final_output.append(out_shape[gather_index])
# Prevent 0-dim tensors which are not accepted by Relay

if not final_output:
final_output.append(1)
return out
return _op.reshape(out, newshape=tuple(final_output))
return _impl

Expand Down Expand Up @@ -984,16 +983,6 @@ def _impl(inputs, attr, params):
for split_item in splitted]), len(splitted))
return _impl

def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
if data in attr['_input_0d_mismatch']:
return data if num_newaxis == 1 else \
AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
extras={'axis': int(axis), 'num_newaxis': int(num_newaxis-1)})([data], attr)

return AttrCvt(op_name="expand_dims", ignores=['Tdim', 'N'],
extras={'axis': int(axis), 'num_newaxis': int(num_newaxis)})([data], attr)


def _softmax():
def _impl(inputs, attr, params):
return AttrCvt(op_name='softmax',
Expand Down Expand Up @@ -1647,7 +1636,6 @@ def __init__(self):
self._output_shapes = {}
self._num_param = 0
self._num_rnn_layer = False
self._outputs_are_0d = {}
self._input_shapes = {}
self._loops = {}
self._branches = {}
Expand Down Expand Up @@ -1737,7 +1725,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
# Operator name 'Const' is treated as a parameter to build params dict.

input_shapes = {}
input_0d_mismatch = set()
attr = self._parse_attr(node.attr)

# Variable converted to Const will not have only value attr
Expand All @@ -1753,10 +1740,6 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
# Will infer shapes if the graph is not frozen with add_shapes=True
self._output_shapes[node.name] = [None]

self._outputs_are_0d[node.name] = [ \
not shape if isinstance(tshape, list) else False \
for tshape in self._output_shapes[node.name]]

if node.op == "Const":
# All Const nodes are Param nodes, lets parse
self._num_param += 1
Expand Down Expand Up @@ -1810,14 +1793,8 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
input_shape = self._output_shapes[node_name][0]
inputs.append(in_sym[0])
input_shapes[in_sym[0]] = input_shape
# This means the node is 1d in Relay and 0d in TF.
# See `_expand_dims_0d_aware`.
if node_name in self._outputs_are_0d \
and self._outputs_are_0d[node_name][tensor_slot] and input_shape:
input_0d_mismatch.add(in_sym[0])

attr['_input_shapes'] = input_shapes
attr['_input_0d_mismatch'] = input_0d_mismatch

if node.op in _control_flow_nodes:
op = self._convert_control_flow_operator(node, inputs,
Expand Down
17 changes: 17 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ def _test_stridedslice(ip_shape, begin, end, stride, dtype,
def test_forward_stridedslice():
'''test StridedSlice'''

_test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1)
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], 'float32', ellipsis_mask=8)
_test_stridedslice((3, 4, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2)
Expand Down Expand Up @@ -1475,6 +1476,21 @@ def test_forward_rel_ops():
_test_forward_rel_op([t1, t2], math_ops.equal)
_test_forward_rel_op([t1, t2], math_ops.not_equal)

#######################################################################
# ExpandDims
# ----------
def _test_forward_expand_dims(data, axis):
in1 = tf.placeholder(shape=data.shape, dtype=data.dtype, name='in1')
out = tf.expand_dims(in1, axis)
compare_tf_with_tvm([data], [in1.name], out.name)

def test_forward_expand_dims():
_test_forward_expand_dims(np.int32(1), 0)
_test_forward_expand_dims(np.array([1]), 0)
_test_forward_expand_dims(np.array([1]), -1)
_test_forward_expand_dims(np.array([[1], [2]]), 0)
_test_forward_expand_dims(np.array([[1], [2]]), 1)
_test_forward_expand_dims(np.array([[1], [2]]), -1)

#######################################################################
# Main
Expand Down Expand Up @@ -1509,6 +1525,7 @@ def test_forward_rel_ops():
test_forward_reverse_v2()
test_forward_pow_exp()
test_forward_sign()
test_forward_expand_dims()

# Reductions
test_forward_argminmax()
Expand Down

0 comments on commit 036294c

Please sign in to comment.