Skip to content

Commit

Permalink
* review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Mar 12, 2019
1 parent 686da01 commit e85905e
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1472,8 +1472,11 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

# Infer shapes even without specifying "add_shapes=True"
if output_shapes == [None]:
out_type = ir_pass.infer_type(self._nodes[node.name][0])
self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)]
out_shapes = []
for node_item in node_output:
out_type = ir_pass.infer_type(node_item)
out_shapes.append(get_const_tuple(out_type.checked_type.shape))
self._output_shapes[node.name] = out_shapes

if self._output_shapes[node.name] and shape and node.name in shape:
assert self._output_shapes[node.name] == list(shape[node.name])
Expand All @@ -1482,15 +1485,11 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
node_output = self._nodes[node.name]
if shape and (not self._output_shapes[node.name][0]
or -1 in self._output_shapes[node.name][0]):
if isinstance(node_output, _expr.TupleWrapper):
out_shapes = []
for tuple_item in node_output:
out_type = ir_pass.infer_type(tuple_item)
out_shapes.append(get_const_tuple(out_type.checked_type.shape))
self._output_shapes[node.name] = out_shapes
else:
out_type = ir_pass.infer_type(node_output[0])
self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)]
out_shapes = []
for node_item in node_output:
out_type = ir_pass.infer_type(node_item)
out_shapes.append(get_const_tuple(out_type.checked_type.shape))
self._output_shapes[node.name] = out_shapes

out = []
if outputs is None:
Expand Down

0 comments on commit e85905e

Please sign in to comment.