diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py index 46be5d32..a6edd7dc 100644 --- a/mesh_tensorflow/transformer/utils.py +++ b/mesh_tensorflow/transformer/utils.py @@ -1697,7 +1697,8 @@ def get_estimator(model_type, vocabulary, mesh_shape, tpu_config=my_tpu_config, session_config=session_config, save_checkpoints_steps=save_checkpoints_steps, - save_checkpoints_secs=None) + save_checkpoints_secs=None, + save_checkpoint_on_shutdown=False) transformer_model = build_model( model_type=model_type,