diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py index e64ca8a2..7d78bd47 100644 --- a/mesh_tensorflow/transformer/utils.py +++ b/mesh_tensorflow/transformer/utils.py @@ -2398,18 +2398,24 @@ def eval_model(estimator, eval_dataset.postprocess_fn(d, example=ex) for d, ex in zip(outputs[:dataset_size], examples) ] - # Remove the used decodes. - del outputs[:dataset_size] global_step = int(get_step_from_checkpoint_path(checkpoint_path)) if output_eval_examples: + outputs_filename = os.path.join( + eval_summary_dir, + "{}_{}_outputs".format(eval_dataset.name, global_step), + ) + write_lines_to_file(outputs[:dataset_size], outputs_filename) predictions_filename = os.path.join( eval_summary_dir, "{}_{}_predictions".format(eval_dataset.name, global_step), ) write_lines_to_file(predictions, predictions_filename) + # Remove the used decodes. + del outputs[:dataset_size] + for metric_fn in eval_dataset.metric_fns: summary = tf.Summary() targets = cached_targets[eval_dataset.name]