Skip to content

Commit

Permalink
Merge pull request #1000 from jignparm/jignparm/servetag
Browse files Browse the repository at this point in the history
Make 'serve' tag default again.
  • Loading branch information
jignparm authored Jul 6, 2020
2 parents 011cf5e + 7498113 commit 4381b2b
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions tf2onnx/tf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,16 @@ def from_checkpoint(model_path, input_names, output_names):
def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signatures):
"""Load tensorflow graph from saved_model."""

wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
wrn_empty_tag = "'--tag' value is empty string. Using tag =[[]]"

if tag is None:
tag = [tf.saved_model.tag_constants.SERVING]
logger.warning(wrn_no_tag)

if tag == '':
tag = [[]]
logger.warning(wrn_empty_tag)

if not isinstance(tag, list):
tag = [tag]
Expand Down Expand Up @@ -218,7 +226,8 @@ def _from_saved_model_v1(sess, model_path, input_names, output_names, tag, signa
def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_def, concrete_function_index):
"""Load tensorflow graph from saved_model."""

wrn_no_tag = "'--tag' not specified for saved_model. Using empty tag [[]]"
wrn_no_tag = "'--tag' not specified for saved_model. Using --tag serve"
wrn_empty_tag = "'--tag' value is empty string. Using tag =[[]]"
wrn_sig_1 = "'--signature_def' not specified, using first signature: %s"
err_many_sig = "Cannot load multiple signature defs in TF2.x: %s"
err_no_call = "Model doesn't contain usable concrete functions under __call__. Try --signature-def instead."
Expand All @@ -227,8 +236,13 @@ def _from_saved_model_v2(model_path, input_names, output_names, tag, signature_d
err_sig_nomatch = "Specified signature not in model %s"

if tag is None:
tag = [[]]
tag = ['serve']
logger.warning(wrn_no_tag)

if tag == '':
tag = [[]]
logger.warning(wrn_empty_tag)

utils.make_sure(len(signature_def) < 2, err_many_sig, str(signature_def))
imported = tf.saved_model.load(model_path, tags=tag) # pylint: disable=no-value-for-parameter

Expand Down

0 comments on commit 4381b2b

Please sign in to comment.