Skip to content

Commit

Permalink
[Bugfix][TF] reset graph after getting tag of savedmodel (apache#4055)
Browse files Browse the repository at this point in the history
@zhiics @icemelon9
  • Loading branch information
yongwww authored and wweic committed Oct 18, 2019
1 parent 1b8fff6 commit f240ef2
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions python/tvm/relay/frontend/tensorflow_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,16 @@ def _get_output_names(self):
"InputConfiguration: Unable to import tensorflow which is "
"required to restore from saved model.")
tags = self._get_tag_set()
output_names = set()
with tf.Session() as sess:
meta_graph_def = tf.saved_model.loader.load(sess,
tags,
self._model_dir)
output_names = set()
for sig_def in meta_graph_def.signature_def.values():
for output_tensor in sig_def.outputs.values():
output_names.add(output_tensor.name.replace(":0", ""))
return ",".join(output_names)
tf.reset_default_graph()
return ",".join(output_names)

def _load_saved_model(self):
"""Load the tensorflow saved model."""
Expand Down

0 comments on commit f240ef2

Please sign in to comment.