diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 3c9e4420124..c248b3e8abf 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2332,13 +2332,21 @@ def _save_checkpoint(self, model, trial, metrics=None):
 
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
-        self.save_model(output_dir, _internal_call=True)
+        if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 0:
+            logger.warning(
+                f"Checkpoint destination directory {output_dir} already exists and is non-empty."
+                "Saving will proceed but saved results may be invalid."
+            )
+            staging_output_dir = output_dir
+        else:
+            staging_output_dir = os.path.join(run_dir, f"tmp-{checkpoint_folder}")
+        self.save_model(staging_output_dir, _internal_call=True)
 
         if not self.args.save_only_model:
             # Save optimizer and scheduler
-            self._save_optimizer_and_scheduler(output_dir)
+            self._save_optimizer_and_scheduler(staging_output_dir)
             # Save RNG state
-            self._save_rng_state(output_dir)
+            self._save_rng_state(staging_output_dir)
 
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
@@ -2358,10 +2366,14 @@ def _save_checkpoint(self, model, trial, metrics=None):
 
         # Save the Trainer state
         if self.args.should_save:
-            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+            self.state.save_to_json(os.path.join(staging_output_dir, TRAINER_STATE_NAME))
 
         if self.args.push_to_hub:
-            self._push_from_checkpoint(output_dir)
+            self._push_from_checkpoint(staging_output_dir)
+
+        # Place checkpoint in final location after all saving is finished.
+        if staging_output_dir != output_dir:
+            os.rename(staging_output_dir, output_dir)
 
         # Maybe delete some older checkpoints.
         if self.args.should_save:
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 305ccb35d5b..129f40fc409 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -79,7 +79,8 @@
     slow,
     torch_device,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, HPSearchBackend, get_last_checkpoint
 from transformers.training_args import OptimizerNames
 from transformers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
@@ -1310,6 +1311,19 @@ def test_save_checkpoints(self):
             trainer.train()
             self.check_saved_checkpoints(tmpdir, 5, int(self.n_epochs * 64 / self.batch_size), False)
 
+    def test_save_checkpoints_is_atomic(self):
+        class UnsaveableTokenizer(PreTrainedTokenizerBase):
+            def save_pretrained(self, *args, **kwargs):
+                raise OSError("simulated file write error")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            trainer = get_regression_trainer(output_dir=tmpdir, save_steps=5)
+            # Attach unsaveable tokenizer to partially fail checkpointing
+            trainer.tokenizer = UnsaveableTokenizer()
+            with self.assertRaises(OSError) as _context:
+                trainer.train()
+            assert get_last_checkpoint(tmpdir) is None
+
     @require_safetensors
     def test_safe_checkpoints(self):
         for save_safetensors in [True, False]:
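
For reference, the trainer.py hunks above follow a staging-then-rename pattern: every checkpoint artifact is written into a temporary tmp-<checkpoint_folder> directory, and only after all writes succeed is that directory renamed to its final name, so an interrupted save never leaves behind a directory that get_last_checkpoint would mistake for a complete checkpoint. The sketch below shows the pattern in isolation, assuming a generic save_fn callback; save_checkpoint_atomically and save_fn are illustrative names, not transformers APIs.

# Minimal sketch of the staging-then-rename pattern used in the patch.
# `save_checkpoint_atomically` and `save_fn` are illustrative placeholders,
# not transformers APIs.
import os


def save_checkpoint_atomically(run_dir, checkpoint_folder, save_fn):
    output_dir = os.path.join(run_dir, checkpoint_folder)
    if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 0:
        # Destination is already populated: fall back to writing in place,
        # mirroring the warning branch in the patch.
        staging_output_dir = output_dir
    else:
        staging_output_dir = os.path.join(run_dir, f"tmp-{checkpoint_folder}")
    os.makedirs(staging_output_dir, exist_ok=True)

    # Write every artifact into the staging directory. If this raises,
    # only the tmp- directory is left behind and the final path never appears.
    save_fn(staging_output_dir)

    # Publish the checkpoint under its final name only once all writes finished.
    if staging_output_dir != output_dir:
        os.rename(staging_output_dir, output_dir)
    return output_dir

In the Trainer itself, the role of save_fn is played by save_model, _save_optimizer_and_scheduler, _save_rng_state, and the Trainer state save. The rename is atomic on POSIX filesystems because the staging directory lives inside the same run_dir as the final checkpoint path.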