From d526be29d26af187ba5aaa348eccacbb5b135254 Mon Sep 17 00:00:00 2001 From: Siming Dai <908660116@qq.com> Date: Thu, 7 Nov 2024 11:39:41 +0800 Subject: [PATCH] [Unified Checkpoint] Support empty state_dict saving (#9380) * fix empty state_dict * update sharding split_param --- paddlenlp/trainer/training_args.py | 9 +++++++++ paddlenlp/trainer/unified_checkpoint/async_handler.py | 5 +++++ 2 files changed, 14 insertions(+) diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 6b780f19417fe9..ed81f055bd80f1 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -1164,6 +1164,15 @@ def split_parallel_config(parallel_config): raise ValueError( "If `enable_sharding_comm_overlap` in pipeline_parallel_configs, `amp_master_grad` must be True." ) + if ( + enable_sharding_comm_overlap + and self.unified_checkpoint + and "split_param" in split_parallel_config(self.sharding_parallel_config) + ): + logger.warning( + "Currently unified checkpoint do not support using `sharding_comm_overlap` and `split_param` at the same time, delete `sharding_comm_overlap`." 
+ ) + enable_sharding_comm_overlap = False dygraph_pp_configs = { "delay_scale_loss": True if "enable_delay_scale_loss" in pipeline_parallel_config else False, diff --git a/paddlenlp/trainer/unified_checkpoint/async_handler.py b/paddlenlp/trainer/unified_checkpoint/async_handler.py index 4206821b50e532..942ea41508bfa4 100644 --- a/paddlenlp/trainer/unified_checkpoint/async_handler.py +++ b/paddlenlp/trainer/unified_checkpoint/async_handler.py @@ -77,6 +77,11 @@ def _file_save_async_or_sync( state_dict[k] = state_dict.pop(k).cpu().numpy() safe_save_file(state_dict, path, metadata={"format": "np"}) else: + if len(state_dict.keys()) == 0: + saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{self.global_rank}") + paddle.save(self.global_rank, saved_signal_path) + return + if state_dict_type == "model_weight": if self._shm_model_weight is None: self._meta_dict_model, buffer_size = create_meta_dict(state_dict)