diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py
index 6fc086b54584..81dec37b611e 100644
--- a/paddlenlp/trainer/auto_trainer.py
+++ b/paddlenlp/trainer/auto_trainer.py
@@ -687,7 +687,13 @@ def _save_checkpoint(self, model, metrics=None):
         # For ckpt integrity
         paddle.save(self.state.global_step, os.path.join(output_dir, ".checkpoint_done"))
 
-    def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_parallel=False):
+    def _save(
+        self,
+        output_dir: Optional[str] = None,
+        state_dict=None,
+        merge_tensor_parallel=False,
+        signal_dir: Optional[str] = None,
+    ):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
         logger.info(f"Saving model checkpoint to {output_dir}")
diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 7c7ad82ccc11..ddc872ad6173 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -581,7 +581,9 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
         # Load potential model checkpoint
         if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
             uc_async_save = self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config
-            resume_from_checkpoint = get_last_checkpoint(self.args.output_dir, uc_async_save)
+            resume_from_checkpoint = get_last_checkpoint(
+                self.args.output_dir, signal_folder=self.args.output_signal_dir, uc_async_save=uc_async_save
+            )
 
             if resume_from_checkpoint is None:
                 raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
@@ -2509,7 +2511,7 @@ def _save_checkpoint(self, model, metrics=None):
         need_to_rotate_checkpoints = need_to_rotate_checkpoints and self.args.local_rank == 0
         if need_to_rotate_checkpoints:
             self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
-            self._rotate_checkpoints(use_mtime=False, output_dir=run_signal_dir)
+            self._rotate_checkpoints(use_mtime=True, output_dir=run_signal_dir)
 
         if strtobool(os.getenv("FLAG_LLM_PDC", "False")) and not ("async_save" in self.args.unified_checkpoint_config):
             # save checkpoint_done file to ensure checkpoint is complete
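
Note: below is a minimal sketch of how the updated resume path is exercised, assuming a
TrainingArguments-style object exposing output_dir, output_signal_dir, unified_checkpoint,
and unified_checkpoint_config (the attribute and keyword names are taken from the hunks
above; the standalone Args scaffolding and the import path are assumptions for illustration):

    from paddlenlp.trainer.trainer_utils import get_last_checkpoint  # assumed import path

    class Args:  # hypothetical stand-in for self.args inside the Trainer
        output_dir = "./checkpoints"
        output_signal_dir = "./checkpoints_signal"  # assumed location of .checkpoint_done signal files
        unified_checkpoint = True
        unified_checkpoint_config = ["async_save"]

    args = Args()
    uc_async_save = args.unified_checkpoint and "async_save" in args.unified_checkpoint_config
    # get_last_checkpoint now also receives the signal folder, presumably so that
    # checkpoint completeness is judged by the signal files rather than by the
    # checkpoint folder alone (an inference from the hunks above, not confirmed).
    resume_from_checkpoint = get_last_checkpoint(
        args.output_dir, signal_folder=args.output_signal_dir, uc_async_save=uc_async_save
    )
    if resume_from_checkpoint is None:
        raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")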