From d526be29d26af187ba5aaa348eccacbb5b135254 Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Thu, 7 Nov 2024 11:39:41 +0800
Subject: [PATCH] [Unified Checkpoint] Support empty state_dict saving (#9380)

* fix empty state_dict

* update sharding split_param
---
 paddlenlp/trainer/training_args.py                    | 9 +++++++++
 paddlenlp/trainer/unified_checkpoint/async_handler.py | 5 +++++
 2 files changed, 14 insertions(+)
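
A minimal, self-contained sketch of the control flow this patch adds to
_file_save_async_or_sync: a rank whose state_dict is empty after parameter
splitting skips the shared-memory save path and only drops its per-rank
".{state_dict_type}.done.{global_rank}" signal file, so processes waiting on
signal files still see the rank finish. The helper name save_with_signal and
the plain-file write standing in for paddle.save are illustrative assumptions,
not PaddleNLP's API.

import json
import os

def save_with_signal(state_dict, save_path, signal_path, state_dict_type, global_rank):
    # Every rank emits its done-signal file, but a rank with an empty shard
    # skips the actual weight save entirely.
    os.makedirs(signal_path, exist_ok=True)
    signal_file = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
    if len(state_dict) == 0:
        # Empty shard: write only the completion marker and return early,
        # mirroring the early return added in _file_save_async_or_sync.
        with open(signal_file, "w") as f:
            f.write(str(global_rank))
        return
    # Non-empty shard: stand-in for the real shared-memory / async tensor save.
    with open(save_path, "w") as f:
        json.dump(sorted(state_dict.keys()), f)
    with open(signal_file, "w") as f:
        f.write(str(global_rank))

The training_args change is a complementary guard: when unified checkpoint is
enabled together with "split_param" in the sharding config, the patch disables
enable_sharding_comm_overlap with a warning instead of letting the two
incompatible options run together.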

diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 6b780f19417fe9..ed81f055bd80f1 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1164,6 +1164,15 @@ def split_parallel_config(parallel_config):
                         raise ValueError(
                             "If `enable_sharding_comm_overlap` in pipeline_parallel_configs, `amp_master_grad` must be True."
                         )
+                    if (
+                        enable_sharding_comm_overlap
+                        and self.unified_checkpoint
+                        and "split_param" in split_parallel_config(self.sharding_parallel_config)
+                    ):
+                        logger.warning(
+                            "Currently, unified checkpoint does not support using `sharding_comm_overlap` and `split_param` at the same time; disabling `sharding_comm_overlap`."
+                        )
+                        enable_sharding_comm_overlap = False
 
                     dygraph_pp_configs = {
                         "delay_scale_loss": True if "enable_delay_scale_loss" in pipeline_parallel_config else False,
diff --git a/paddlenlp/trainer/unified_checkpoint/async_handler.py b/paddlenlp/trainer/unified_checkpoint/async_handler.py
index 4206821b50e532..942ea41508bfa4 100644
--- a/paddlenlp/trainer/unified_checkpoint/async_handler.py
+++ b/paddlenlp/trainer/unified_checkpoint/async_handler.py
@@ -77,6 +77,11 @@ def _file_save_async_or_sync(
                     state_dict[k] = state_dict.pop(k).cpu().numpy()
             safe_save_file(state_dict, path, metadata={"format": "np"})
         else:
+            if len(state_dict) == 0:
+                saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{self.global_rank}")
+                paddle.save(self.global_rank, saved_signal_path)
+                return
+
             if state_dict_type == "model_weight":
                 if self._shm_model_weight is None:
                     self._meta_dict_model, buffer_size = create_meta_dict(state_dict)