[Unified Checkpoint] update async save logic #9274

Merged: 2 commits, Oct 16, 2024
53 changes: 44 additions & 9 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -140,21 +140,25 @@
self._process_master_weight = None
self._process_optimizer_weight = None
self._lock = None
self._shared_save_path = None
self._shared_save_model_flag = None
self._shared_save_master_weight_flag = None
self._shared_save_optimizer_flag = None

if "async_save" in self.args.unified_checkpoint_config:
self._lock = multiprocessing.Lock()
self._shared_save_model_path = multiprocessing.Array("c", 100000)
self._shared_save_model_signal_path = multiprocessing.Array("c", 100000)
self._shared_save_master_weight_path = multiprocessing.Array("c", 100000)
self._shared_save_master_weight_signal_path = multiprocessing.Array("c", 100000)
self._shared_save_optimizer_path = multiprocessing.Array("c", 100000)
self._shared_save_optimizer_signal_path = multiprocessing.Array("c", 100000)
self._shared_save_model_flag = multiprocessing.Array("i", 1)
self._shared_save_master_weight_flag = multiprocessing.Array("i", 1)
self._shared_save_optimizer_flag = multiprocessing.Array("i", 1)

def _file_save_async_or_sync(self, state_dict, path, is_sync=True, state_dict_type="model_weight"):
def _file_save_async_or_sync(
self, state_dict, path, signal_path=None, is_sync=True, state_dict_type="model_weight"
):
if is_sync:
for k in list(state_dict.keys()):
if isinstance(state_dict[k], paddle.Tensor):
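The hunk above wires up the per-artifact shared state used for asynchronous saving: a character array holding the destination path, a new character array holding the signal directory, and a one-slot integer flag that the saver process polls. Below is a minimal standalone sketch of that state, under the assumption of a 0/1/-1 flag convention; the helper name make_shared_state is illustrative and not part of PaddleNLP.

import multiprocessing

def make_shared_state(path_buffer_size=100000):
    # One lock plus, per artifact (model / master weight / optimizer weight):
    # a byte buffer for the save path, a byte buffer for the signal directory,
    # and a one-slot integer flag (0 = idle, 1 = save requested, -1 = shut down).
    lock = multiprocessing.Lock()
    save_path = multiprocessing.Array("c", path_buffer_size)
    signal_path = multiprocessing.Array("c", path_buffer_size)
    save_flag = multiprocessing.Array("i", 1)
    return lock, save_path, signal_path, save_flag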
@@ -169,6 +173,7 @@
meta_dict = self._meta_dict_model
shared_save_flag = self._shared_save_model_flag
shared_save_path = self._shared_save_model_path
shared_save_signal_path = self._shared_save_model_signal_path
if self._process_model_weight is None:
self._process_model_weight = multiprocessing.Process(
target=self._save_file_async_in_process,
@@ -177,12 +182,14 @@
self._shm_model_weight.name,
self._shared_save_model_flag,
self._shared_save_model_path,
self._shared_save_model_signal_path,
self._lock,
state_dict_type,
self.global_rank,
),
)
self._process_model_weight.start()
process = self._process_model_weight
elif state_dict_type == "master_weight":
if self._shm_master_weight is None:
self._meta_dict_master_weight, buffer_size = create_meta_dict(state_dict)
@@ -191,6 +198,7 @@
meta_dict = self._meta_dict_master_weight
shared_save_flag = self._shared_save_master_weight_flag
shared_save_path = self._shared_save_master_weight_path
shared_save_signal_path = self._shared_save_master_weight_signal_path
if self._process_master_weight is None:
self._process_master_weight = multiprocessing.Process(
target=self._save_file_async_in_process,
Expand All @@ -199,6 +207,7 @@
self._shm_master_weight.name,
self._shared_save_master_weight_flag,
self._shared_save_master_weight_path,
self._shared_save_master_weight_signal_path,
self._lock,
"model_weight"
if "skip_save_model_weight" in self.args.unified_checkpoint_config
@@ -207,6 +216,7 @@
),
)
self._process_master_weight.start()
process = self._process_master_weight
elif state_dict_type == "optimizer_weight":
if self._shm_optimizer_weight is None:
self._meta_dict_optim, buffer_size = create_meta_dict(state_dict)
@@ -215,6 +225,7 @@
meta_dict = self._meta_dict_optim
shared_save_flag = self._shared_save_optimizer_flag
shared_save_path = self._shared_save_optimizer_path
shared_save_signal_path = self._shared_save_optimizer_signal_path
if self._process_optimizer_weight is None:
self._process_optimizer_weight = multiprocessing.Process(
target=self._save_file_async_in_process,
@@ -223,21 +234,26 @@
self._shm_optimizer_weight.name,
self._shared_save_optimizer_flag,
self._shared_save_optimizer_path,
self._shared_save_optimizer_signal_path,
self._lock,
state_dict_type,
self.global_rank,
),
)
self._process_optimizer_weight.start()
process = self._process_optimizer_weight

while True: # wait until no process is saving.
flag_value = shared_save_flag[0]
if flag_value == 0:
break
if not process.is_alive():
raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
time.sleep(0.5)
logger.info(f"Wait for the previous save process to finish saving {state_dict_type}")
# only save model weight or save master weight, we enter this loop.
self._reset_and_update(shared_save_path, path)
self._reset_and_update(shared_save_signal_path, signal_path)
_traverse_copy_to_shm(state_dict, meta_dict, shm_state_dict.buf)
with self._lock:
shared_save_flag[0] = 1
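For reference, a simplified sketch of the producer-side handoff performed in the block above: wait until the worker is idle, publish both destination paths into the shared byte arrays, copy the tensors into shared memory, then raise the flag under the lock so the worker starts writing. The name request_async_save and the copy_to_shm callback are hypothetical stand-ins, not the module's real helpers.

def request_async_save(lock, save_flag, save_path, signal_path, new_path, new_signal_dir, copy_to_shm):
    # Publish where the worker should write the file and where to drop the ".done" marker.
    for shared_array, value in ((save_path, new_path), (signal_path, new_signal_dir)):
        encoded = value.encode("utf-8")
        shared_array[:] = b"\x00" * len(shared_array)   # clear stale bytes first
        shared_array[: len(encoded)] = encoded
    copy_to_shm()                                        # e.g. copy the state dict into shared memory
    with lock:
        save_flag[0] = 1                                 # signal: a save is requested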
@@ -248,6 +264,7 @@
shm_name,
shared_save_flag,
shared_save_path,
shared_save_signal_path,
lock,
state_dict_type,
global_rank,
@@ -261,11 +278,12 @@
continue
if flag_value == 1: # need to save
path = shared_save_path[:].decode("utf-8").rstrip("\x00")
signal_path = shared_save_signal_path[:].decode("utf-8").rstrip("\x00")
logger.info(f"Start to async save {path}")
state_dict = _read_state_dict_from_shm(meta_dict, shm) # numpy array
safe_save_file(state_dict, path, {"format": "np"})
del state_dict
saved_signal_path = os.path.join(os.path.dirname(path), f".{state_dict_type}.done.{global_rank}")
saved_signal_path = os.path.join(signal_path, f".{state_dict_type}.done.{global_rank}")
paddle.save(global_rank, saved_signal_path)
with lock:
shared_save_flag[0] = 0
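A condensed sketch of the consumer side shown above, highlighting the PR's key change: the per-rank ".done" marker is now written into the directory decoded from the shared signal path instead of next to the weights file. The saver_loop name and the do_save callback are illustrative only; the real worker reads the state dict back from shared memory, writes it with safe_save_file, and records the rank with paddle.save.

import os
import time

def saver_loop(save_flag, save_path, signal_path, lock, state_dict_type, global_rank, do_save):
    while True:
        flag = save_flag[0]
        if flag == -1:                                   # shutdown sentinel from the trainer
            break
        if flag == 1:                                    # a save was requested
            path = save_path[:].decode("utf-8").rstrip("\x00")
            signal_dir = signal_path[:].decode("utf-8").rstrip("\x00")
            do_save(path)                                # persist the state dict (callback)
            marker = os.path.join(signal_dir, f".{state_dict_type}.done.{global_rank}")
            open(marker, "w").close()                    # drop the per-rank completion marker
            with lock:
                save_flag[0] = 0                         # mark idle again
        time.sleep(0.5)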
@@ -280,7 +298,7 @@
encoded_value = new_value.encode("utf-8")
shared_array[: len(encoded_value)] = encoded_value

def save_unified_checkpoint(self, model, optimizer, output_dir):
def save_unified_checkpoint(self, model, optimizer, output_dir, signal_dir=None):
"""save unified checkpoint

Args:
@@ -317,6 +335,8 @@

save_directory = output_dir
os.makedirs(save_directory, exist_ok=True)
if signal_dir is not None:
os.makedirs(signal_dir, exist_ok=True) # only for async save

# save model weights
if not skip_save_model_weight:
Expand All @@ -329,6 +349,7 @@
self._file_save_async_or_sync(
state_dict,
path=os.path.join(save_directory, shard_file),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="model_weight",
)
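A hypothetical call site illustrating how the new signal_dir argument separates the heavyweight checkpoint directory from the directory that only collects completion markers; the directory names and the handler variable are made up for illustration.

# Weights are written under output_dir, while async ".done" markers land in signal_dir,
# so progress tracking does not have to scan the checkpoint storage itself.
unified_checkpoint_handler.save_unified_checkpoint(
    model,
    optimizer,
    output_dir="checkpoints/checkpoint-1000",
    signal_dir="checkpoints/async_signals/checkpoint-1000",
)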
@@ -397,7 +418,7 @@
if self.args.dataset_rank == 0 or self.args.use_expert_parallel:
load_unified_checkpoint_locally(self.args, model, resume_from_checkpoint, safe_serialization=True)

def save_non_merge_optimizer(self, model, optimizer, output_dir):
def save_non_merge_optimizer(self, model, optimizer, output_dir, signal_dir):
paddle.device.cuda.empty_cache()
optim_state_dict = nested_copy(optimizer.state_dict())
master_weights = None
@@ -456,12 +477,14 @@
self._file_save_async_or_sync(
optim_state_dict,
path=os.path.join(output_dir, optimizer_name),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="optimizer_weight",
)
self._file_save_async_or_sync(
master_weights,
path=os.path.join(output_dir, master_weights_name),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="master_weight",
)
@@ -511,22 +534,23 @@

return returned_optim_state_dict

def save_unified_optimizer(self, model, optimizer, output_dir):
def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
"""save unified optimizer

Args:
model (PretrainedModel): model used to get key mapping.
optimizer (Optimizer): optimizer to save
output_dir (str): Save directory.
signal_dir (str): Asynchronous saving signal directory.

"""

if "ignore_merge_optimizer" in self.args.unified_checkpoint_config:
self.save_non_merge_optimizer(model, optimizer, output_dir)
self.save_non_merge_optimizer(model, optimizer, output_dir, signal_dir)
return

if paddle.distributed.get_world_size() <= 1:
self.save_single_card_optimizer(model, optimizer, output_dir)
self.save_single_card_optimizer(model, optimizer, output_dir) # no need to save signal
return

# Split into naive optimizer params and master weights.
@@ -542,20 +566,24 @@

save_directory = output_dir
os.makedirs(save_directory, exist_ok=True)
if signal_dir is not None:
os.makedirs(signal_dir, exist_ok=True)

is_sync_save = True
if "async_save" in self.args.unified_checkpoint_config:
is_sync_save = False
self._file_save_async_or_sync(
optim_state_dict,
path=os.path.join(save_directory, shard_optim_file),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="optimizer_weight",
)
if master_weight_state_dict is not None:
self._file_save_async_or_sync(
master_weight_state_dict,
path=os.path.join(save_directory, shard_master_weight_file),
signal_path=signal_dir,
is_sync=is_sync_save,
state_dict_type="master_weight",
)
@@ -747,14 +775,20 @@

if self._shared_save_model_flag is not None:
while self._shared_save_model_flag[0] > 0: # async process is saving
if not self._process_model_weight.is_alive():
raise RuntimeError("The process that saves model_weight has been killed unexpectedly.")
time.sleep(0.5)
self._shared_save_model_flag[0] = -1
if self._shared_save_master_weight_flag is not None:
while self._shared_save_master_weight_flag[0] > 0:
if not self._process_master_weight.is_alive():
raise RuntimeError("The process that saves master_weight has been killed unexpectedly.")
time.sleep(0.5)
self._shared_save_master_weight_flag[0] = -1
if self._shared_save_optimizer_flag is not None:
while self._shared_save_optimizer_flag[0] > 0:
if not self._process_optimizer_weight.is_alive():
raise RuntimeError("The process that saves optimizer_weight has been killed unexpectedly.")
time.sleep(0.5)
self._shared_save_optimizer_flag[0] = -1
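The teardown above now checks that each saver process is still alive while draining its flag before writing the -1 sentinel. Roughly, per artifact, and assuming the 0/1/-1 flag convention sketched earlier (the function name is hypothetical):

import time

def shutdown_saver(save_flag, process, state_dict_type, poll_interval=0.5):
    # Wait for any in-flight save, but fail fast if the worker died mid-save.
    while save_flag[0] > 0:
        if not process.is_alive():
            raise RuntimeError(f"The process that saves {state_dict_type} has been killed unexpectedly.")
        time.sleep(poll_interval)
    save_flag[0] = -1                                    # tell the worker loop to exit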

@@ -771,7 +805,8 @@
self._shm_optimizer_weight.unlink()
self._shm_optimizer_weight = None

dist.barrier()
if paddle.distributed.get_world_size() > 1:
dist.barrier()


def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, safe_serialization=False):