[Unified Checkpoint] update async save logic (#9274)
* update async save signal

* fix async save hang
DesmonDay authored Oct 16, 2024
1 parent b090f18 commit 697a4cc
Showing 4 changed files with 103 additions and 34 deletions.
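Taken together, the hunks below make two changes to the async-save plumbing. First, the background saver no longer derives its ".done" signal file from the checkpoint path: the parent process publishes a dedicated signal path through a new shared array, and save_unified_checkpoint / save_unified_optimizer / save_non_merge_optimizer gain a signal_dir argument. Second, every loop that waits on a background saver now checks that the saver process is still alive and raises a RuntimeError instead of sleeping forever, which is the hang fix. The following is a condensed, illustrative sketch of those two patterns; wait_for_previous_save and done_marker_path are simplified helper names invented for this sketch, not functions from the diff.

import os
import time


def wait_for_previous_save(shared_save_flag, process, state_dict_type, poll=0.5):
    """Block until the background saver is idle (flag value 0).

    The process.is_alive() check is the hang fix: previously this loop
    could sleep forever if the saver process had been killed while the
    flag was still set to 1.
    """
    while True:
        if shared_save_flag[0] == 0:
            break
        if not process.is_alive():
            raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
        time.sleep(poll)


def done_marker_path(signal_dir, state_dict_type, global_rank):
    """Signal change: the ".done" marker now lives under signal_dir
    rather than next to the checkpoint shard (os.path.dirname(path))."""
    return os.path.join(signal_dir, f".{state_dict_type}.done.{global_rank}")
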
53 changes: 44 additions & 9 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -140,21 +140,25 @@ def __init__(self, args):
self._process_master_weight = None
self._process_optimizer_weight = None
self._lock = None
self._shared_save_path = None
self._shared_save_model_flag = None
self._shared_save_master_weight_flag = None
self._shared_save_optimizer_flag = None

if "async_save" in self.args.unified_checkpoint_config:
self._lock = multiprocessing.Lock()
self._shared_save_model_path = multiprocessing.Array("c", 100000)
self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
self._shared_save_model_flag = multiprocessing.Array("i", 1)
self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)

def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"):
def _file_save_async_or_sync(
self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
):
if is_sync:
for k in list(state_dict.keys()):
if isinstance(state_dict[k], paddle.Tensor):
@@ -169,6 +173,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
meta_dict = self._meta_dict_model
shared_save_flag = self._shared_save_model_flag
shared_save_path = self._shared_save_model_path
shared_save_signal_path = self._shared_save_model_signal_path
if self._process_model_weight is None:
self._process_model_weight = multiprocessing.Process(
target=self._save_file_async_in_process,
@@ -177,12 +182,14 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
self._shm_model_weight.name,
self._shared_save_model_flag,
self._shared_save_model_path,
self._shared_save_model_signal_path,
self._lock,
state_dict_type,
self.global_rank,
),
)
self._process_model_weight.start()
process = self._process_model_weight
elif state_dict_type == "master_weight":
if self._shm_master_weight is None:
self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
@@ -191,6 +198,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
meta_dict = self._meta_dict_master_weight
shared_save_flag = self._shared_save_master_weight_flag
shared_save_path = self._shared_save_master_weight_path
shared_save_signal_path = self._shared_save_master_weight_signal_path
if self._process_master_weight is None:
self._process_master_weight = multiprocessing.Process(
target=self._save_file_async_in_process,
@@ -199,6 +207,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
self._shm_master_weight.name,
self._shared_save_master_weight_flag,
self._shared_save_master_weight_path,
self._shared_save_master_weight_signal_path,
self._lock,
"model_weight"
if "skip_save_model_weight" in self.args.unified_checkpoint_config
Expand All @@ -207,6 +216,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
),
)
self._process_master_weight.start()
process = self._process_master_weight
elif state_dict_type == "optimizer_weight":
if self._shm_optimizer_weight is None:
self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
@@ -215,6 +225,7 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
meta_dict = self._meta_dict_optim
shared_save_flag = self._shared_save_optimizer_flag
shared_save_path = self._shared_save_optimizer_path
shared_save_signal_path = self._shared_save_optimizer_signal_path
if self._process_optimizer_weight is None:
self._process_optimizer_weight = multiprocessing.Process(
target=self._save_file_async_in_process,
@@ -223,21 +234,26 @@ def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_ty
self._shm_optimizer_weight.name,
self._shared_save_optimizer_flag,
self._shared_save_optimizer_path,
self._shared_save_optimizer_signal_path,
self._lock,
state_dict_type,
self.global_rank,
),
)
self._process_optimizer_weight.start()
process = self._process_optimizer_weight

while True: # wait until no process is saving.
flag_value = shared_save_flag[0]
if flag_value == 0:
break
if not process.is_alive():
raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
time.sleep(0.5)
logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
        # We enter this loop only when saving model weight or master weight.
self._reset_and_update(shared_save_path, path)
self._reset_and_update(shared_save_signal_path, signal_path)
_traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
with self._lock:
shared_save_flag[0] = 1
@@ -248,6 +264,7 @@ def _save_file_async_in_process(
shm_name,
shared_save_flag,
shared_save_path,
shared_save_signal_path,
lock,
state_dict_type,
global_rank,
@@ -261,11 +278,12 @@ def _save_file_async_in_process(
continue
if flag_value == 1: # need to save
path = shared_save_path[:].decode("utf-8").rstrip("\x00")
signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
logger.info(f"Start to async save {path}")
state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array
safe_save_file(state_dict, path, {"format": "np"})
del state_dict
saved_signal_path = os.path.join(os.path.dirname(path), f".{state_dict_type}.done.{global_rank}")
saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
paddle.save(global_rank, saved_signal_path)
with lock:
shared_save_flag[0] = 0
@@ -280,7 +298,7 @@ def _reset_and_update(self, shared_array, new_value):
encoded_value = new_value.encode("utf-8")
shared_array[: len(encoded_value)] = encoded_value

def save_unified_checkpoint(self, model, optimizer, output_dir):
def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None):
"""save unified checkpoint
Args:
@@ -317,6 +335,8 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):

save_directory = output_dir
os.makedirs(save_directory, exist_ok=True)
if signal_dir is not None:
os.makedirs(signal_dir, exist_ok=True) # only for async save

# save model weights
if not skip_save_model_weight:
Expand All @@ -329,6 +349,7 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
self._file_save_async_or_sync(
state_dict,
path=os.path.join(save_directory, shard_file),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="model_weight",
)
@@ -400,7 +421,7 @@ def load_unified_checkpoint(self, model, optimizer, resume_from_checkpoint: str)
if self.args.dataset_rank == 0 or self.args.use_expert_parallel:
load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)

def save_non_merge_optimizer(self, model, optimizer, output_dir):
def save_non_merge_optimizer(self, model, optimizer, output_dir, signal_dir):
paddle.device.cuda.empty_cache()
optim_state_dict = nested_copy(optimizer.state_dict())
master_weights = None
@@ -459,12 +480,14 @@ def save_non_merge_optimizer(self, model, optimizer, output_dir):
self._file_save_async_or_sync(
optim_state_dict,
path=os.path.join(output_dir, optimizer_name),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="optimizer_weight",
)
self._file_save_async_or_sync(
master_weights,
path=os.path.join(output_dir, master_weights_name),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="master_weight",
)
@@ -514,22 +537,23 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):

return returned_optim_state_dict

def save_unified_optimizer(self, model, optimizer, output_dir):
def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
"""save unified optimizer
Args:
model (PretrainedModel): model used to get key mapping.
optimizer (Optimizer): optimizer to save
output_dir (str): Save directory.
signal_dir (str): Asynchronous saving signal directory.
"""

if "ignore_merge_optimizer" in self.args.unified_checkpoint_config:
self.save_non_merge_optimizer(model, optimizer, output_dir)
self.save_non_merge_optimizer(model, optimizer, output_dir, signal_dir)
return

if paddle.distributed.get_world_size() <= 1:
self.save_single_card_optimizer(model, optimizer, output_dir)
self.save_single_card_optimizer(model, optimizer, output_dir) # no need to save signal
return

# Split into naive optimizer params and master weights.
@@ -545,20 +569,24 @@ def save_unified_optimizer(self, model, optimizer, output_dir):

save_directory = output_dir
os.makedirs(save_directory, exist_ok=True)
if signal_dir is not None:
os.makedirs(signal_dir, exist_ok=True)

is_sync_save = True
if "async_save" in self.args.unified_checkpoint_config:
is_sync_save = False
self._file_save_async_or_sync(
optim_state_dict,
path=os.path.join(save_directory, shard_optim_file),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="optimizer_weight",
)
if master_weight_state_dict is not None:
self._file_save_async_or_sync(
master_weight_state_dict,
path=os.path.join(save_directory, shard_master_weight_file),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="master_weight",
)
@@ -754,14 +782,20 @@ def unlink_shared_memory(self):

if self._shared_save_model_flag is not None:
while self._shared_save_model_flag[0] > 0: # async process is saving
if not self._process_model_weight.is_alive():
raise RuntimeError("The process that saves model_weight has been killed unexpectedly.")
time.sleep(0.5)
self._shared_save_model_flag[0] = -1
if self._shared_save_master_weight_flag is not None:
while self._shared_save_master_weight_flag[0] > 0:
if not self._process_master_weight.is_alive():
raise RuntimeError("The process that saves master_weight has been killed unexpectedly.")
time.sleep(0.5)
self._shared_save_master_weight_flag[0] = -1
if self._shared_save_optimizer_flag is not None:
while self._shared_save_optimizer_flag[0] > 0:
if not self._process_optimizer_weight.is_alive():
raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.")
time.sleep(0.5)
self._shared_save_optimizer_flag[0] = -1

Expand All @@ -778,7 +812,8 @@ def unlink_shared_memory(self):
self._shm_optimizer_weight.unlink()
self._shm_optimizer_weight = None

dist.barrier()
if paddle.distributed.get_world_size() > 1:
dist.barrier()


def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False):
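
The remaining three changed files are not expanded in this view, but the new signal_dir parameters above imply a caller-side update roughly like the sketch below. This is an assumed usage example, not code from the commit: save_checkpoint_with_signals, checkpoint_handler, and the "async_signals" directory name are placeholders, and the real trainer may lay out the signal directory differently.

import os


def save_checkpoint_with_signals(checkpoint_handler, model, optimizer, output_dir):
    # Assumption: ".done" signal markers live in their own directory so they can
    # be polled and cleaned up independently of the checkpoint shards.
    signal_dir = os.path.join(output_dir, "async_signals")  # placeholder layout
    checkpoint_handler.save_unified_checkpoint(model, optimizer, output_dir, signal_dir=signal_dir)
    checkpoint_handler.save_unified_optimizer(model, optimizer, output_dir, signal_dir)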