Commit fd5dea0

fix

DesmonDay committed Oct 28, 2024
1 parent 2dd22ca commit fd5dea0
Showing 2 changed files with 3 additions and 14 deletions.
4 changes: 2 additions & 2 deletions paddlenlp/trainer/trainer.py
@@ -2308,7 +2308,7 @@ def save_model(
         if output_dir is None:
             output_dir = self.args.output_dir
 
-        if PREFIX_CHECKPOINT_DIR in output_dir and self.is_in_train:
+        if PREFIX_CHECKPOINT_DIR in os.path.split(output_dir)[-1]:
             signal_dir = os.path.join(self.args.output_signal_dir, os.path.split(output_dir)[-1])
         else:
             signal_dir = self.args.output_signal_dir
@@ -2606,7 +2606,7 @@ def _save(
         # signal_dir is used for asynchronous saving situations.
         signal_dir = self.args.output_signal_dir
         if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
-            if PREFIX_CHECKPOINT_DIR in output_dir and self.is_in_train:
+            if PREFIX_CHECKPOINT_DIR in os.path.split(output_dir)[-1]:
                 signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1])
             os.makedirs(signal_dir, exist_ok=True)
             logger.info(f"Saving model checkpoint finish signal to {signal_dir}")
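
The point of the condition change is easiest to see on concrete paths: the old test matched the prefix anywhere in the full path (and additionally required `self.is_in_train`), while the new test looks only at the directory's own name. A minimal sketch, assuming `PREFIX_CHECKPOINT_DIR` is `"checkpoint"` (the helper name and example paths below are illustrative, not from the repository):

import os

# Assumed value; PaddleNLP names step checkpoints "checkpoint-<step>".
PREFIX_CHECKPOINT_DIR = "checkpoint"

def is_checkpoint_subdir(output_dir: str) -> bool:
    # New check: match against the last path component only.
    return PREFIX_CHECKPOINT_DIR in os.path.split(output_dir)[-1]

# The old substring check fired on the whole path, so a parent directory
# that merely contains "checkpoint" was misclassified; this one is not fooled.
print(is_checkpoint_subdir("/data/checkpoints/final_model"))  # False (old check: True)
print(is_checkpoint_subdir("/data/run1/checkpoint-2000"))     # True

A side effect of dropping `self.is_in_train` is presumably that a save into a checkpoint-named directory now routes its signal files the same way whether or not it happens during training.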
13 changes: 1 addition & 12 deletions paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -15,7 +15,6 @@
 import copy
 import json
 import os
-import sys
 
 import paddle
 from paddle.distributed import fleet
@@ -34,7 +33,7 @@
     load_state_dict,
     unwrap_model,
 )
-from paddlenlp.transformers.utils import dtype_byte_size, is_safetensors_available
+from paddlenlp.transformers.utils import dtype_byte_size
 from paddlenlp.utils.env import (
     LORA_WEIGHTS_NAME,
     PADDLE_MASTER_WEIGHTS_NAME,
@@ -53,12 +52,6 @@
 from paddlenlp.utils.log import logger
 from paddlenlp.utils.nested import nested_copy
 
-if is_safetensors_available():
-    if sys.platform.startswith("win"):
-        from safetensors.numpy import load_file
-    else:
-        from paddlenlp.utils.safetensors import fast_load_file as load_file
-
 from .async_handler import AsyncCheckpointHandler
 from .check_completion import check_unified_checkpoint, check_unified_optimizer
 from .load_dynamic import (
@@ -279,10 +272,6 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint):
 
         model_state_dict = get_expected_state_dict(model)
         struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}  # get optimizer param mappings
-        optimizer_state_dict = load_file(optimizer_path)
-        if has_master_weights:
-            master_weights = load_file(master_weights_path)
-
         optimizer_state_dict = load_state_dict(optimizer_path, None, None, device="expected")
         if has_master_weights:
             master_weights = load_state_dict(master_weights_path, None, None, device="expected")
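
For context: before this fix, `load_non_merge_optimizer` read every optimizer file twice, once via the platform-dependent safetensors `load_file` and again via `load_state_dict`, with the first result immediately discarded. Removing the `load_file` calls (and the conditional import feeding them) leaves a single read per file. A minimal sketch of that redundancy, with stand-in loaders rather than the real PaddleNLP/safetensors functions:

import time

def load_file(path):
    """Stand-in for the removed safetensors load_file; simulates one full read."""
    time.sleep(0.1)
    return {"moment1": "tensor..."}

def load_state_dict(path, *args, device="cpu"):
    """Stand-in for paddlenlp's load_state_dict helper; also one full read."""
    time.sleep(0.1)
    return {"moment1": "tensor..."}

optimizer_path = "optimizer.safetensors"  # placeholder path

# Old code path: two full reads, the first result immediately overwritten.
start = time.time()
state = load_file(optimizer_path)
state = load_state_dict(optimizer_path, None, None, device="expected")
print(f"before fix: {time.time() - start:.2f}s")  # ~0.20s

# Fixed code path: a single read.
start = time.time()
state = load_state_dict(optimizer_path, None, None, device="expected")
print(f"after fix:  {time.time() - start:.2f}s")  # ~0.10s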
