From 77a97605bd6665fbc3bd9047fcf9325c78bd3dc7 Mon Sep 17 00:00:00 2001
From: VeryLazyBoy
Date: Sat, 30 Sep 2023 11:46:13 +0800
Subject: [PATCH 1/2] Skip saving frozen parameters when using a PEFT model
 with DeepSpeed

---
 src/transformers/integrations/deepspeed.py | 21 +++++++++++++--
 src/transformers/trainer.py                | 31 +++++++++++++++++++---
 2 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index fb9c022b0f28f4..c44e2c73df62ec 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -19,9 +19,10 @@
 import importlib.util
 import weakref
 from functools import partialmethod
+from packaging import version
 
 from ..dependency_versions_check import dep_version_check
-from ..utils import is_accelerate_available, is_torch_available, logging
+from ..utils import is_accelerate_available, is_torch_available, is_peft_available, logging
 
 
 if is_torch_available():
@@ -29,6 +30,11 @@
     from ..optimization import get_scheduler
 
 
+
+if is_peft_available():
+    from peft import PeftModel
+
+
 logger = logging.get_logger(__name__)
 
 
@@ -45,6 +51,9 @@ def is_deepspeed_available():
             return False
 
 
+if is_deepspeed_available():
+    from deepspeed import __version__ as deepspeed_version
+
 if is_accelerate_available() and is_deepspeed_available():
     from accelerate.utils.deepspeed import HfDeepSpeedConfig as DeepSpeedConfig
 else:
@@ -398,9 +407,17 @@ def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path):
 
     if len(deepspeed_checkpoint_dirs) > 0:
         logger.info(f"Attempting to resume from {checkpoint_path}")
+
+        load_module_strict = True
+        if version.parse(deepspeed_version) > version.parse("0.10.0"):
+            if is_peft_available() and isinstance(deepspeed_engine.module, PeftModel):
+                load_module_strict = False
         # this magically updates self.optimizer and self.lr_scheduler
         load_path, _ = deepspeed_engine.load_checkpoint(
-            checkpoint_path, load_optimizer_states=True, load_lr_scheduler_states=True
+            checkpoint_path,
+            load_optimizer_states=True,
+            load_lr_scheduler_states=True,
+            load_module_strict=load_module_strict,
         )
         if load_path is None:
             raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index b9e10376157457..c61bb11d6923f7 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -214,6 +214,7 @@
 )
 
 if is_deepspeed_available():
+    from deepspeed import __version__ as deepspeed_version
     from accelerate.utils import DeepSpeedSchedulerWrapper
 
 
@@ -2380,6 +2381,23 @@ def _load_rng_state(self, checkpoint):
         if is_torch_tpu_available():
             xm.set_rng_state(checkpoint_rng_state["xla"])
 
+    def _save_deepspeed_optim_and_model_states(self, output_dir):
+        # save both optimizer and model states.
+
+        # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
+        # config `stage3_gather_16bit_weights_on_model_save` is True
+        if version.parse(deepspeed_version) > version.parse("0.10.0") and is_peft_available():
+            # skip saving deepspeed frozen parameters if possible
+            self.model_wrapped.save_checkpoint(
+                output_dir, exclude_frozen_parameters=isinstance(self.model_wrapped.module, PeftModel)
+            )
+        else:
+            self.model_wrapped.save_checkpoint(output_dir)
+            if is_peft_available() and isinstance(self.model_wrapped.module, PeftModel):
+                logger.warning(
+                    "Frozen model weights are also saved. If you want to skip saving them, please upgrade your DeepSpeed to at least 0.10.1."
+                )
+
     def _save_checkpoint(self, model, trial, metrics=None):
         # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
         # want to save except FullyShardedDDP.
@@ -2395,9 +2413,13 @@ def _save_checkpoint(self, model, trial, metrics=None):
         output_dir = os.path.join(run_dir, checkpoint_folder)
         self.save_model(output_dir, _internal_call=True)
         if self.is_deepspeed_enabled:
-            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
-            # config `stage3_gather_16bit_weights_on_model_save` is True
-            self.model_wrapped.save_checkpoint(output_dir)
+            not_stage3 = self.accelerator.deepspeed_config["zero_optimization"]["stage"] != 3
+            gather_16bit_weights = self.model_wrapped.zero_gather_16bit_weights_on_model_save()
+            if not_stage3 or gather_16bit_weights:
+                # save_model() above has already written the full DeepSpeed checkpoint when running ZeRO stage3
+                # with 'zero_gather_16bit_weights_on_model_save' disabled. In every other case (stage1/2, or
+                # stage3 with weight gathering enabled), the checkpoint still needs to be saved here.
+                self._save_deepspeed_optim_and_model_states(output_dir)
 
         # Save optimizer and scheduler
         if self.sharded_ddp == ShardedDDPOption.SIMPLE:
@@ -2895,7 +2917,8 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
                 self._save(output_dir, state_dict={})
                 # remove the dummy state_dict
                 remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
-                self.model_wrapped.save_checkpoint(output_dir)
+                # both optimizer and model states are needed to restore the model
+                self._save_deepspeed_optim_and_model_states(output_dir)
 
         elif self.args.should_save:
             self._save(output_dir)

From 95d084166c8d9391198238937117c8dea681b32b Mon Sep 17 00:00:00 2001
From: VeryLazyBoy
Date: Sat, 30 Sep 2023 12:14:00 +0800
Subject: [PATCH 2/2] Fix ruff errors

---
 src/transformers/integrations/deepspeed.py | 3 ++-
 src/transformers/trainer.py                | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index c44e2c73df62ec..b6ea4e3e3a6525 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -19,10 +19,11 @@
 import importlib.util
 import weakref
 from functools import partialmethod
+
 from packaging import version
 
 from ..dependency_versions_check import dep_version_check
-from ..utils import is_accelerate_available, is_torch_available, is_peft_available, logging
+from ..utils import is_accelerate_available, is_peft_available, is_torch_available, logging
 
 
 if is_torch_available():
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index c61bb11d6923f7..becadbec883449 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -214,8 +214,8 @@
 )
 
 if is_deepspeed_available():
-    from deepspeed import __version__ as deepspeed_version
     from accelerate.utils import DeepSpeedSchedulerWrapper
+    from deepspeed import __version__ as deepspeed_version
 
 
 if TYPE_CHECKING:
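
A minimal usage sketch of the scenario this patch targets (illustrative only, not part of the diff): it assumes a small LoRA-wrapped causal LM ("facebook/opt-125m"), toy in-memory data, and a local "ds_config.json" ZeRO config — all of these are placeholder choices, not requirements of the change. With deepspeed at least 0.10.1 installed, the checkpoints Trainer writes below are expected to omit the frozen base-model weights, and resuming is expected to reload them with load_module_strict=False.

    # Illustrative sketch; model name, data, file paths and hyper-parameters are placeholders.
    from datasets import Dataset
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

    # Wrap a small base model with LoRA so most parameters are frozen (the result is a PeftModel).
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"))

    # Tiny fixed-length toy dataset, just enough to trigger a couple of checkpoint saves.
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    enc = tokenizer(["hello world"] * 16, padding="max_length", max_length=16, truncation=True)
    train_dataset = Dataset.from_dict({**enc, "labels": enc["input_ids"]})

    args = TrainingArguments(
        output_dir="lora-ds-out",
        per_device_train_batch_size=2,
        num_train_epochs=1,
        save_steps=4,
        deepspeed="ds_config.json",  # e.g. ZeRO-3 with stage3_gather_16bit_weights_on_model_save: false
    )

    # Run under the DeepSpeed launcher, e.g.: deepspeed --num_gpus=1 train_lora_ds.py
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    trainer.train()
    # With deepspeed > 0.10.0, the DeepSpeed checkpoints under lora-ds-out/checkpoint-*/ should skip the
    # frozen base-model weights (exclude_frozen_parameters=True), and trainer.train(resume_from_checkpoint=True)
    # reloads them with load_module_strict=False.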