Skip saving frozen parameters if using peft model with deepspeed #26503

Closed
src/transformers/integrations/deepspeed.py (20 additions & 2 deletions)
@@ -20,15 +20,22 @@
 import weakref
 from functools import partialmethod

+from packaging import version
+
 from ..dependency_versions_check import dep_version_check
-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_accelerate_available, is_peft_available, is_torch_available, logging


 if is_torch_available():
     import torch

     from ..optimization import get_scheduler


+if is_peft_available():
+    from peft import PeftModel
+
+
 logger = logging.get_logger(__name__)

@@ -45,6 +52,9 @@ def is_deepspeed_available():
             return False


+if is_deepspeed_available():
+    from deepspeed import __version__ as deepspeed_version
+
 if is_accelerate_available() and is_deepspeed_available():
     from accelerate.utils.deepspeed import HfDeepSpeedConfig as DeepSpeedConfig
 else:
@@ -398,9 +408,17 @@ def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path):

     if len(deepspeed_checkpoint_dirs) > 0:
         logger.info(f"Attempting to resume from {checkpoint_path}")
+
+        load_module_strict = True
+        if version.parse(deepspeed_version) > version.parse("0.10.0"):
+            if is_peft_available() and isinstance(deepspeed_engine.module, PeftModel):
+                load_module_strict = False
         # this magically updates self.optimizer and self.lr_scheduler
         load_path, _ = deepspeed_engine.load_checkpoint(
-            checkpoint_path, load_optimizer_states=True, load_lr_scheduler_states=True
+            checkpoint_path,
+            load_optimizer_states=True,
+            load_lr_scheduler_states=True,
+            load_module_strict=load_module_strict,
         )
         if load_path is None:
             raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
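For illustration, a minimal sketch of the resume side outside of `Trainer` (the `engine` and `checkpoint_dir` names here are hypothetical): a checkpoint written without the frozen base-model weights can only be loaded back into the engine when strict module loading is turned off, which is what the `load_module_strict` flag above controls.

    from packaging import version

    import deepspeed
    from peft import PeftModel


    def resume_from_checkpoint(engine, checkpoint_dir):
        # PEFT checkpoints may have been written with exclude_frozen_parameters=True,
        # i.e. they only contain the trainable adapter weights, so loading the full
        # module strictly would fail on the missing (frozen) parameters.
        load_module_strict = True
        if version.parse(deepspeed.__version__) > version.parse("0.10.0") and isinstance(
            engine.module, PeftModel
        ):
            load_module_strict = False

        load_path, client_state = engine.load_checkpoint(
            checkpoint_dir,
            load_optimizer_states=True,
            load_lr_scheduler_states=True,
            load_module_strict=load_module_strict,
        )
        if load_path is None:
            raise ValueError(f"failed to resume from checkpoint {checkpoint_dir}")
        return client_state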
src/transformers/trainer.py (27 additions & 4 deletions)
@@ -215,6 +215,7 @@

 if is_deepspeed_available():
     from accelerate.utils import DeepSpeedSchedulerWrapper
+    from deepspeed import __version__ as deepspeed_version


 if TYPE_CHECKING:
@@ -2380,6 +2381,23 @@ def _load_rng_state(self, checkpoint):
         if is_torch_tpu_available():
             xm.set_rng_state(checkpoint_rng_state["xla"])

+    def _save_deepspeed_optim_and_model_states(self, output_dir):
+        # save both optimizer and model states.
+
+        # under zero3 the model file itself doesn't get saved since it's bogus, unless the deepspeed
+        # config `stage3_gather_16bit_weights_on_model_save` is True
+        if version.parse(deepspeed_version) > version.parse("0.10.0") and is_peft_available():
+            # skip saving deepspeed frozen parameters if possible
+            self.model_wrapped.save_checkpoint(
+                output_dir, exclude_frozen_parameters=isinstance(self.model_wrapped.module, PeftModel)
+            )
+        else:
+            self.model_wrapped.save_checkpoint(output_dir)
+            if is_peft_available() and isinstance(self.model_wrapped.module, PeftModel):
+                logger.warning(
+                    "Frozen model weights are also saved. If you want to skip saving them, please upgrade your deepspeed to at least 0.10.1"
+                )
+
     def _save_checkpoint(self, model, trial, metrics=None):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
         # want to save except FullyShardedDDP.
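As a rough standalone sketch of what the new helper does (hypothetical `engine`, `output_dir`, and `logger`, not the `Trainer` API itself; `exclude_frozen_parameters` is only accepted by DeepSpeed releases newer than 0.10.0):

    import logging

    from packaging import version

    import deepspeed
    from peft import PeftModel

    logger = logging.getLogger(__name__)


    def save_deepspeed_checkpoint(engine, output_dir):
        if version.parse(deepspeed.__version__) > version.parse("0.10.0"):
            # Newer DeepSpeed can drop the frozen base-model weights from the checkpoint,
            # which keeps PEFT checkpoints small (only the adapters are trainable).
            engine.save_checkpoint(
                output_dir, exclude_frozen_parameters=isinstance(engine.module, PeftModel)
            )
        else:
            engine.save_checkpoint(output_dir)
            if isinstance(engine.module, PeftModel):
                logger.warning(
                    "Frozen model weights are also saved. Upgrade deepspeed to at least "
                    "0.10.1 to skip saving them."
                )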
@@ -2395,9 +2413,13 @@ def _save_checkpoint(self, model, trial, metrics=None):
         output_dir = os.path.join(run_dir, checkpoint_folder)
         self.save_model(output_dir, _internal_call=True)
         if self.is_deepspeed_enabled:
-            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
-            # config `stage3_gather_16bit_weights_on_model_save` is True
-            self.model_wrapped.save_checkpoint(output_dir)
+            not_stage3 = self.accelerator.deepspeed_config["zero_optimization"]["stage"] != 3
+            gather_16bit_weights = self.model_wrapped.zero_gather_16bit_weights_on_model_save()
+            if not_stage3 or gather_16bit_weights:
+                # save_model() above has already written the deepspeed checkpoint when
+                # 'zero_gather_16bit_weights_on_model_save' is False under stage3; in every other case
+                # (other stages, or stage3 with the option enabled) it still has to be saved here.
+                self._save_deepspeed_optim_and_model_states(output_dir)

         # Save optimizer and scheduler
         if self.sharded_ddp == ShardedDDPOption.SIMPLE:
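As a hedged illustration of the new condition (a hypothetical standalone helper, not part of the `Trainer` API), the DeepSpeed checkpoint is only written again at this point when `save_model` above did not already produce it:

    def needs_deepspeed_checkpoint(ds_config: dict, gather_16bit_weights: bool) -> bool:
        # save_model() already wrote the DeepSpeed checkpoint in the ZeRO stage-3 /
        # no-16bit-gather case; every other combination still needs it written here.
        not_stage3 = ds_config["zero_optimization"]["stage"] != 3
        return not_stage3 or gather_16bit_weights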
@@ -2895,7 +2917,8 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
                     self._save(output_dir, state_dict={})
                 # remove the dummy state_dict
                 remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
-                self.model_wrapped.save_checkpoint(output_dir)
+                # both optimizer and model states are needed to restore the model
+                self._save_deepspeed_optim_and_model_states(output_dir)

         elif self.args.should_save:
             self._save(output_dir)