Skip to content

Commit

Permalink
Add support to export a particular checkpoint. Fixes tensorflow#181.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 314207905
  • Loading branch information
sharannarang authored and Mesh TensorFlow Team committed Jun 1, 2020
1 parent b66c7d0 commit 94edc68
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion mesh_tensorflow/transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1918,6 +1918,7 @@ def run(tpu_job_name,
dataset_split="train",
autostack=True,
eval_checkpoint_step=None,
export_checkpoint_step=None,
export_path="",
mode="train",
iterations_per_loop=100,
Expand Down Expand Up @@ -1957,6 +1958,8 @@ def run(tpu_job_name,
autostack: boolean, see `get_estimator` docstring for details.
eval_checkpoint_step: int, list of ints, or None, see `eval_model` doc
string for details.
export_checkpoint_step: int or None, see `export_model` doc string for
details.
export_path: a string, path to export the saved model
mode: string, train/eval/perplexity_eval/infer
perplexity_eval computes the perplexity of the dev set.
Expand Down Expand Up @@ -2121,8 +2124,18 @@ def _input_fn(params, eval_dataset):
score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
model_dir, eval_checkpoint_step, dataset_split)
elif mode == "export":
if export_checkpoint_step:
checkpoint_path = get_checkpoint_iterator(
export_checkpoint_step, model_dir)
if isinstance(checkpoint_path, list):
checkpoint_path = checkpoint_path[0]
else:
checkpoint_path = next(checkpoint_path)
else:
# Use the latest checkpoint in the model directory.
checkpoint_path = None
export_model(estimator, export_path, vocabulary, sequence_length,
batch_size)
batch_size, checkpoint_path)

else:
raise ValueError(
Expand Down

0 comments on commit 94edc68

Please sign in to comment.