fix async save hang
DesmonDay committed Oct 15, 2024
1 parent 2b62766 commit ddb768e
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -189,6 +189,7 @@ def _file_save_async_or_sync(
),
)
self._process_model_weight.start()
process = self._process_model_weight

elif state_dict_type == "master_weight":
if self._shm_master_weight is None:
self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
@@ -215,6 +216,7 @@ def _file_save_async_or_sync(
),
)
self._process_master_weight.start()
process = self._process_master_weight

elif state_dict_type == "optimizer_weight":
if self._shm_optimizer_weight is None:
self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
@@ -239,11 +241,14 @@ def _file_save_async_or_sync(
),
)
self._process_optimizer_weight.start()
process = self._process_optimizer_weight

while True: # wait until no process is saving.
flag_value = shared_save_flag[0]
if flag_value == 0:
break
if not process.is_alive():
raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")

Check warning on line 251 in paddlenlp/trainer/plugins/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/plugins/unified_checkpoint.py#L250-L251

Added lines #L250 - L251 were not covered by tests
time.sleep(0.5)
logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
# only save model weight or save master weight, we enter this loop.
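
The three hunks above all make the same change: keep a handle to whichever writer process was just launched (`process = self._process_*`), then poll `process.is_alive()` inside the wait loop, so a writer that dies without clearing the shared flag raises a RuntimeError instead of leaving the trainer spinning on `time.sleep(0.5)` forever. Below is a minimal, standard-library-only sketch of that pattern; the `_slow_save` worker and the `multiprocessing.Value` flag are illustrative stand-ins for the plugin's shared-memory buffers and paddle-specific save helpers, not its actual API.

```python
import multiprocessing
import time


def _slow_save(flag):
    # Illustrative stand-in for the async checkpoint writer: it clears the
    # shared flag once the (simulated) write finishes.
    time.sleep(2)
    flag.value = 0


if __name__ == "__main__":
    flag = multiprocessing.Value("i", 1)  # 1 means "a save is in flight"
    process = multiprocessing.Process(target=_slow_save, args=(flag,))
    process.start()

    while True:  # wait until no process is saving
        if flag.value == 0:
            break
        if not process.is_alive():
            # Before this commit nothing reset the flag when the writer died,
            # so the parent would wait here indefinitely.
            raise RuntimeError("The save process has been killed unexpectedly.")
        time.sleep(0.5)
```

If the writer is killed externally before it resets the flag, the loop now surfaces the failure immediately, which appears to be the hang the commit message refers to.
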
@@ -771,14 +776,20 @@ def unlink_shared_memory(self):

if self._shared_save_model_flag is not None:
while self._shared_save_model_flag[0] > 0: # async process is saving
if not self._process_model_weight.is_alive():
raise RuntimeError("The process that saves model_weight has been killed unexpectedly.")
time.sleep(0.5)
self._shared_save_model_flag[0] = -1
if self._shared_save_master_weight_flag is not None:
while self._shared_save_master_weight_flag[0] > 0:
if not self._process_master_weight.is_alive():
raise RuntimeError("The process that saves master_weight has been killed unexpectedly.")

Check warning on line 786 in paddlenlp/trainer/plugins/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/plugins/unified_checkpoint.py#L785-L786

Added lines #L785 - L786 were not covered by tests
time.sleep(0.5)
self._shared_save_master_weight_flag[0] = -1
if self._shared_save_optimizer_flag is not None:
while self._shared_save_optimizer_flag[0] > 0:
if not self._process_optimizer_weight.is_alive():
raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.")

Check warning on line 792 in paddlenlp/trainer/plugins/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/plugins/unified_checkpoint.py#L791-L792

Added lines #L791 - L792 were not covered by tests
time.sleep(0.5)
self._shared_save_optimizer_flag[0] = -1

@@ -795,7 +806,8 @@ def unlink_shared_memory(self):
self._shm_optimizer_weight.unlink()
self._shm_optimizer_weight = None

-dist.barrier()
+if paddle.distributed.get_world_size() > 1:
+    dist.barrier()



def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False):
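
The last two hunks apply the same liveness check inside `unlink_shared_memory` for each of the three saver processes, and additionally guard the final `dist.barrier()` so it is only issued when more than one rank is running. A rough sketch of the patched cleanup shape, with illustrative parameter names (`flag`, `process`, `shm`) rather than the plugin's real attributes:

```python
import time

import paddle
import paddle.distributed as dist


def unlink_shared_memory_sketch(flag, process, shm):
    # flag: indexable shared value (> 0 while an async save is in flight),
    # process: the writer process that owns it, shm: its shared-memory block.
    # All three parameters are illustrative, not the plugin's real attributes.
    if flag is not None:
        while flag[0] > 0:  # an async save is still writing
            if not process.is_alive():
                # Fail fast instead of waiting forever on a flag that a dead
                # writer can no longer reset.
                raise RuntimeError("The save process has been killed unexpectedly.")
            time.sleep(0.5)
        flag[0] = -1

    if shm is not None:
        shm.close()
        shm.unlink()

    # Only synchronize ranks in multi-process runs, so a single-card job no
    # longer calls the collective barrier during checkpoint cleanup.
    if paddle.distributed.get_world_size() > 1:
        dist.barrier()
```
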
