diff --git a/docs/source/developer_guides/low_level_api.md b/docs/source/developer_guides/low_level_api.md
index 7490262abf3..ba827949383 100644
--- a/docs/source/developer_guides/low_level_api.md
+++ b/docs/source/developer_guides/low_level_api.md
@@ -25,6 +25,8 @@ Check the table below to see when you should inject adapters.
 | the model is modified inplace, keeping all the original attributes and methods | manually write the `from_pretrained` and `save_pretrained` utility functions from Hugging Face to save and load adapters |
 | works for any `torch` module and modality | doesn't work with any of the utility methods provided by `PeftModel` such as disabling and merging adapters |
 
+## Creating a new PEFT model
+
 To perform the adapter injection, use the [`inject_adapter_in_model`] method. This method takes 3 arguments: the PEFT config, the model, and an optional adapter name. You can also attach multiple adapters to the model if you call [`inject_adapter_in_model`] multiple times with different adapter names.
 
 For example, to inject LoRA adapters into the `linear` submodule of the `DummyModel` module:
@@ -85,6 +87,8 @@ DummyModel(
 )
 ```
 
+## Saving the model
+
 To only save the adapter, use the [`get_peft_model_state_dict`] function:
 
 ```python
@@ -95,3 +99,28 @@ print(peft_state_dict)
 ```
 
 Otherwise, `model.state_dict()` returns the full state dict of the model.
+
+## Loading the model
+
+After loading the saved `state_dict`, it can be applied to the model with the [`set_peft_model_state_dict`] function:
+
+```python
+from peft import set_peft_model_state_dict
+
+model = DummyModel()
+model = inject_adapter_in_model(lora_config, model)
+outcome = set_peft_model_state_dict(model, peft_state_dict)
+# check that there are no unexpected keys
+print(outcome.unexpected_keys)
+```
+
+If injecting the adapter is slow, or if you need to load a large number of adapters, you can use an optimization that creates an "empty" adapter on the meta device and only fills in the real weights when [`set_peft_model_state_dict`] is called. To do this, pass `low_cpu_mem_usage=True` to both [`inject_adapter_in_model`] and [`set_peft_model_state_dict`]:
+
+```python
+model = DummyModel()
+model = inject_adapter_in_model(lora_config, model, low_cpu_mem_usage=True)
+
+print(model.linear.lora_A["default"].weight.device.type == "meta")  # should be True
+set_peft_model_state_dict(model, peft_state_dict, low_cpu_mem_usage=True)
+print(model.linear.lora_A["default"].weight.device.type == "cpu")  # should be True
+```
diff --git a/docs/source/developer_guides/troubleshooting.md b/docs/source/developer_guides/troubleshooting.md
index 2baf9b6baac..c24ba2d57ea 100644
--- a/docs/source/developer_guides/troubleshooting.md
+++ b/docs/source/developer_guides/troubleshooting.md
@@ -250,6 +250,19 @@ TunerModelStatus(
 )
 ```
 
+## Speed
+
+### Loading adapter weights is slow
+
+Loading adapters like LoRA weights should generally be fast compared to loading the base model. However, there can be use cases where the adapter weights are quite large, or where users need to load a large number of adapters -- the loading time can add up in this case. The reason for this is that the adapter weights are first initialized and then overridden by the loaded weights, which is wasteful. To speed up the loading time, you can pass the `low_cpu_mem_usage=True` argument to [`~PeftModel.from_pretrained`] and [`~PeftModel.load_adapter`].
+
+If this option works well across different use cases, it may become the default for adapter loading in the future.
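+
+For illustration, a minimal sketch of how the option is passed (the base model and adapter paths below are placeholders):
+
+```python
+from transformers import AutoModelForCausalLM
+from peft import PeftModel
+
+base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
+# skip initializing the adapter weights; they are overwritten by the checkpoint anyway
+peft_model = PeftModel.from_pretrained(base_model, "path/to/lora-adapter", low_cpu_mem_usage=True)
+# the same argument works when loading additional adapters
+peft_model.load_adapter("path/to/other-lora-adapter", adapter_name="other", low_cpu_mem_usage=True)
+```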
+ + + + ## Reproducibility ### Models using batch norm diff --git a/src/peft/mapping.py b/src/peft/mapping.py index 5eb19dcb5b0..9c0ab95986c 100644 --- a/src/peft/mapping.py +++ b/src/peft/mapping.py @@ -196,7 +196,7 @@ def get_peft_model( def inject_adapter_in_model( - peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default" + peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default", low_cpu_mem_usage: bool = False ) -> torch.nn.Module: r""" A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning @@ -210,6 +210,8 @@ def inject_adapter_in_model( The input model where the adapter will be injected. adapter_name (`str`, `optional`, defaults to `"default"`): The name of the adapter to be injected, if not provided, the default adapter name is used ("default"). + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. """ if peft_config.is_prompt_learning or peft_config.is_adaption_prompt: raise ValueError("`create_and_replace` does not support prompt learning and adaption prompt yet.") @@ -222,6 +224,6 @@ def inject_adapter_in_model( tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type] # By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules. - peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name) + peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) return peft_model.model diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py index b0324b4dbfc..d9e231e018f 100644 --- a/src/peft/mixed_model.py +++ b/src/peft/mixed_model.py @@ -112,6 +112,8 @@ class PeftMixedModel(PushToHubMixin, torch.nn.Module): The config of the model to be tuned. The adapter type must be compatible. adapter_name (`str`, `optional`, defaults to `"default"`): The name of the first adapter. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. """ def __init__(self, model: nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: @@ -219,12 +221,38 @@ def disable_adapter(self): finally: self.base_model.enable_adapter_layers() - def add_adapter(self, adapter_name: str, peft_config: PeftConfig): + def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None: + """ + Add an adapter to the model based on the passed configuration. + + This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`]. + + The name for the new adapter should be unique. + + The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active + adapter. + + Args: + adapter_name (`str`): + The name of the adapter to be added. + peft_config ([`PeftConfig`]): + The configuration of the adapter to be added. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the process when loading saved + adapters. + + + + Don't use `low_cpu_mem_usage=True` when creating a new PEFT adapter for training (training is untested + and discouraged for PeftMixedModel in general). 
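+
+        Example (an illustrative sketch; the config and adapter name below are placeholders):
+
+        ```py
+        >>> from peft import LoHaConfig
+
+        >>> # `model` is assumed to be an existing PeftMixedModel with a first adapter
+        >>> model.add_adapter("other", LoHaConfig(target_modules=["q_proj"]))
+        >>> # pass low_cpu_mem_usage=True only when the new adapter's weights are loaded right afterwards
+        ```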
+ + + """ _check_config_compatible(peft_config) try: self.peft_config[adapter_name] = peft_config - self.base_model.inject_adapter(self, adapter_name) + self.base_model.inject_adapter(self, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) except Exception: # something went wrong, roll back if adapter_name in self.peft_config: del self.peft_config[adapter_name] @@ -323,6 +351,37 @@ def _split_kwargs(cls, kwargs: dict[str, Any]): return PeftModel._split_kwargs(kwargs) def load_adapter(self, model_id: str, adapter_name: str, *args: Any, **kwargs: Any): + """ + Load a trained adapter into the model. + + The name for the new adapter should be unique. + + The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active + adapter. + + Args: + adapter_name (`str`): + The name of the adapter to be added. + peft_config ([`PeftConfig`]): + The configuration of the adapter to be added. + is_trainable (`bool`, *optional*, defaults to `False`): + Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be + used for inference. + torch_device (`str`, *optional*, defaults to None): + The device to load the adapter on. If `None`, the device will be inferred. + autocast_adapter_dtype (`bool`, *optional*, defaults to `True`): + Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter + weights using float16 and bfloat16 to float32, as this is typically required for stable training, and + only affect select PEFT tuners. + ephemeral_gpu_offload (`bool`, *optional*, defaults to `False`): + Whether to use ephemeral GPU offloading for partially loaded modules. Defaults to `False`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device before loading the saved weights. Useful to speed up the + process. + kwargs: (`optional`): + Additional arguments to modify the way the adapter is loaded, e.g. the token for Hugging Face Hub. + """ + # the low_cpu_mem_usage option is handled through kwargs output = PeftModel.load_adapter(self, model_id, adapter_name, *args, **kwargs) # TODO: not quite clear why this is necessary but tests fail without it self.set_adapter(self.active_adapters) @@ -373,6 +432,9 @@ def from_pretrained( The configuration object to use instead of an automatically loaded configuration. This configuration object is mutually exclusive with `model_id` and `kwargs`. This is useful when configuration is already loaded before calling `from_pretrained`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device before loading the saved weights. Useful to speed up the + process. kwargs: (`optional`): Additional keyword arguments passed along to the specific PEFT configuration class. 
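+
+        Example (an illustrative sketch; the model and adapter IDs below are placeholders):
+
+        ```py
+        >>> from transformers import AutoModelForCausalLM
+        >>> from peft import PeftMixedModel
+
+        >>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
+        >>> model = PeftMixedModel.from_pretrained(base_model, "path/to/adapter", low_cpu_mem_usage=True)
+        ```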
""" @@ -412,5 +474,6 @@ def from_pretrained( # note: this is different from PeftModel.from_pretrained, we always return a PeftMixedModel model = cls(model, config, adapter_name) + # the low_cpu_mem_usage option is handled through kwargs model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs) return model diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 3eaff59529f..da269aad1ac 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -19,7 +19,7 @@ import inspect import os import warnings -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from copy import deepcopy from dataclasses import dataclass from typing import Any, Literal, Optional, Union @@ -27,7 +27,7 @@ import packaging.version import torch import transformers -from accelerate import dispatch_model, infer_auto_device_map +from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules from accelerate.utils import get_balanced_memory, named_module_tensors from huggingface_hub import HfFileSystem, ModelCard, ModelCardData, hf_hub_download @@ -119,6 +119,14 @@ class PeftModel(PushToHubMixin, torch.nn.Module): Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights using float16 and bfloat16 to float32, as this is typically required for stable training, and only affect select PEFT tuners. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading loading process. + + + + Don't use `low_cpu_mem_usage=True` when creating a new PEFT adapter for training. + + **Attributes**: - **base_model** ([`torch.nn.Module`]) -- The base transformer model used for Peft. @@ -141,6 +149,7 @@ def __init__( peft_config: PeftConfig, adapter_name: str = "default", autocast_adapter_dtype: bool = True, + low_cpu_mem_usage: bool = False, ) -> None: super().__init__() self.modules_to_save = None @@ -154,11 +163,13 @@ def __init__( if self._is_prompt_learning: self._peft_config = {adapter_name: peft_config} self.base_model = model - self.add_adapter(adapter_name, peft_config) + self.add_adapter(adapter_name, peft_config, low_cpu_mem_usage=low_cpu_mem_usage) else: self._peft_config = None cls = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type] - self.base_model = cls(model, {adapter_name: peft_config}, adapter_name) + ctx = init_empty_weights if low_cpu_mem_usage else nullcontext + with ctx(): + self.base_model = cls(model, {adapter_name: peft_config}, adapter_name) self.set_additional_trainable_modules(peft_config, adapter_name) if hasattr(self.base_model, "_cast_adapter_dtype"): @@ -423,6 +434,7 @@ def from_pretrained( config: Optional[PeftConfig] = None, autocast_adapter_dtype: bool = True, ephemeral_gpu_offload: bool = False, + low_cpu_mem_usage: bool = False, **kwargs: Any, ) -> PeftModel: r""" @@ -457,6 +469,9 @@ def from_pretrained( are needed. Rather than perform expensive operations on small data, the data is transferred to the GPU on-demand, the operation(s) performed, and the results moved back to CPU memory. This brings a slight momentary VRAM overhead but gives orders of magnitude speedup in certain cases. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device before loading the saved weights. Useful to speed up the + process. 
torch_device (`str`, *optional*, defaults to None): The device to load the adapter on. If `None`, the device will be inferred. kwargs: (`optional`): @@ -553,14 +568,29 @@ def from_pretrained( raise ValueError("If model_id is a local path, then `adapters` must be passed in kwargs.") if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys(): - model = cls(model, config, adapter_name, autocast_adapter_dtype=autocast_adapter_dtype) + model = cls( + model, + config, + adapter_name, + autocast_adapter_dtype=autocast_adapter_dtype, + low_cpu_mem_usage=low_cpu_mem_usage, + ) else: model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type]( - model, config, adapter_name, autocast_adapter_dtype=autocast_adapter_dtype + model, + config, + adapter_name, + autocast_adapter_dtype=autocast_adapter_dtype, + low_cpu_mem_usage=low_cpu_mem_usage, ) model.load_adapter( - model_id, adapter_name, is_trainable=is_trainable, autocast_adapter_dtype=autocast_adapter_dtype, **kwargs + model_id, + adapter_name, + is_trainable=is_trainable, + autocast_adapter_dtype=autocast_adapter_dtype, + low_cpu_mem_usage=low_cpu_mem_usage, + **kwargs, ) return model @@ -853,7 +883,7 @@ def get_base_model(self) -> torch.nn.Module: else self.base_model.model ) - def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None: + def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None: """ Add an adapter to the model based on the passed configuration. @@ -869,6 +899,10 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None: The name of the adapter to be added. peft_config ([`PeftConfig`]): The configuration of the adapter to be added. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the process when loading saved + adapters. Don't use this option when creating a new PEFT adapter for training. + """ if peft_config.peft_type != self.peft_type: raise ValueError( @@ -890,7 +924,9 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None: self.base_model.add_adapter(adapter_name, peft_config) else: self.peft_config[adapter_name] = peft_config - self.base_model.inject_adapter(self.base_model.model, adapter_name) + self.base_model.inject_adapter( + self.base_model.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage + ) except Exception: # something went wrong, roll back if adapter_name in self.peft_config: del self.peft_config[adapter_name] @@ -1078,6 +1114,7 @@ def load_adapter( torch_device: Optional[str] = None, autocast_adapter_dtype: bool = True, ephemeral_gpu_offload: bool = False, + low_cpu_mem_usage: bool = False, **kwargs: Any, ): """ @@ -1104,6 +1141,9 @@ def load_adapter( only affect select PEFT tuners. ephemeral_gpu_offload (`bool`, *optional*, defaults to `False`): Whether to use ephemeral GPU offloading for partially loaded modules. Defaults to `False`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device before loading the saved weights. Useful to speed up the + process. kwargs: (`optional`): Additional arguments to modify the way the adapter is loaded, e.g. the token for Hugging Face Hub. 
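+
+        Example (an illustrative sketch; the adapter path and name below are placeholders):
+
+        ```py
+        >>> # `peft_model` is assumed to be an existing PeftModel instance
+        >>> peft_model.load_adapter("path/to/second-adapter", adapter_name="second", low_cpu_mem_usage=True)
+        >>> peft_model.set_adapter("second")
+        ```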
""" @@ -1129,14 +1169,18 @@ def load_adapter( raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.") else: peft_config.inference_mode = not is_trainable - self.add_adapter(adapter_name, peft_config) + self.add_adapter(adapter_name, peft_config, low_cpu_mem_usage=low_cpu_mem_usage) adapters_weights = load_peft_weights(model_id, device=torch_device, **hf_hub_download_kwargs) # load the weights into the model ignore_mismatched_sizes = kwargs.get("ignore_mismatched_sizes", False) load_result = set_peft_model_state_dict( - self, adapters_weights, adapter_name=adapter_name, ignore_mismatched_sizes=ignore_mismatched_sizes + self, + adapters_weights, + adapter_name=adapter_name, + ignore_mismatched_sizes=ignore_mismatched_sizes, + low_cpu_mem_usage=low_cpu_mem_usage, ) missing_keys, unexpected_keys = load_result.missing_keys, load_result.unexpected_keys tuner = self.peft_config[adapter_name].peft_type diff --git a/src/peft/tuners/adalora/model.py b/src/peft/tuners/adalora/model.py index 41262c95fa9..d85f4b8cdb9 100644 --- a/src/peft/tuners/adalora/model.py +++ b/src/peft/tuners/adalora/model.py @@ -42,6 +42,8 @@ class AdaLoraModel(LoraModel): model ([`transformers.PreTrainedModel`]): The model to be adapted. config ([`AdaLoraConfig`]): The configuration of the AdaLora model. adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. Returns: `torch.nn.Module`: The AdaLora model. diff --git a/src/peft/tuners/boft/model.py b/src/peft/tuners/boft/model.py index 11bd4c3ad21..4a4bb8158fb 100644 --- a/src/peft/tuners/boft/model.py +++ b/src/peft/tuners/boft/model.py @@ -49,6 +49,8 @@ class BOFTModel(BaseTuner): model ([`transformers.PreTrainedModel`]): The model to be adapted. config ([`BOFTConfig`]): The configuration of the BOFT model. adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. Returns: `torch.nn.Module`: The BOFT model. @@ -72,8 +74,8 @@ class BOFTModel(BaseTuner): prefix: str = "boft_" - def __init__(self, model, config, adapter_name) -> None: - super().__init__(model, config, adapter_name) + def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None: + super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) def _check_new_adapter_config(self, config: BOFTConfig) -> None: """ @@ -156,10 +158,12 @@ def _replace_module(self, parent, child_name, new_module, child): new_module.state = child.state new_module.to(child.weight.device) + meta = torch.device("meta") # dispatch to correct device for name, module in new_module.named_modules(): if self.prefix in name: - module.to(child.weight.device) + if not any(p.device == meta for p in module.parameters()): + module.to(child.weight.device) def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: for n, p in model.named_parameters(): diff --git a/src/peft/tuners/fourierft/model.py b/src/peft/tuners/fourierft/model.py index 969f6583a23..ce818d96a04 100644 --- a/src/peft/tuners/fourierft/model.py +++ b/src/peft/tuners/fourierft/model.py @@ -45,6 +45,8 @@ class FourierFTModel(BaseTuner): model ([`torch.nn.Module`]): The model to be adapted. config ([`FourierFTConfig`]): The configuration of the FourierFT model. 
adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. Returns: `torch.nn.Module`: The FourierFT model. @@ -56,8 +58,8 @@ class FourierFTModel(BaseTuner): prefix: str = "fourierft_" - def __init__(self, model, config, adapter_name) -> None: - super().__init__(model, config, adapter_name) + def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None: + super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) def _check_new_adapter_config(self, config: FourierFTConfig) -> None: """ @@ -142,10 +144,12 @@ def _replace_module(self, parent, child_name, new_module, child): new_module.state = child.state new_module.to(child.weight.device) + meta = torch.device("meta") # dispatch to correct device for name, module in new_module.named_modules(): if "fourierft_" in name: - module.to(child.weight.device) + if not any(p.device == meta for p in module.parameters()): + module.to(child.weight.device) def _mark_only_adapters_as_trainable(self, model: torch.nn.Module) -> None: for n, p in model.named_parameters(): diff --git a/src/peft/tuners/hra/model.py b/src/peft/tuners/hra/model.py index 64ad71d0745..1bfa0f12060 100644 --- a/src/peft/tuners/hra/model.py +++ b/src/peft/tuners/hra/model.py @@ -41,6 +41,8 @@ class HRAModel(BaseTuner): model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached. config ([`HRAConfig`]): The configuration of the HRA model. adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. Returns: `torch.nn.Module`: The HRA model. @@ -158,10 +160,12 @@ def _replace_module(self, parent, child_name, new_module, child): new_module.state = child.state new_module.to(child.weight.device) + meta = torch.device("meta") # dispatch to correct device for name, module in new_module.named_modules(): if self.prefix in name: - module.to(child.weight.device) + if not any(p.device == meta for p in module.parameters()): + module.to(child.weight.device) def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: for n, p in model.named_parameters(): diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index 226408934d6..a58d1360dfb 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -45,6 +45,8 @@ class IA3Model(BaseTuner): model ([`~transformers.PreTrainedModel`]): The model to be adapted. config ([`IA3Config`]): The configuration of the (IA)^3 model. adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. Returns: `torch.nn.Module`: The (IA)^3 model. 
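As context for the per-tuner changes in this file, here is a hedged sketch (not part of the patch) of how the new flag reaches a tuner such as (IA)^3 through the public API; the base model and target modules are assumptions:

```python
from transformers import AutoModelForCausalLM
from peft import IA3Config, inject_adapter_in_model

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = IA3Config(target_modules=["k_proj", "v_proj", "fc2"], feedforward_modules=["fc2"])
# the IA3 layers are created under init_empty_weights, so they start out on the meta device
model = inject_adapter_in_model(config, model, low_cpu_mem_usage=True)
print({p.device.type for n, p in model.named_parameters() if "ia3_" in n})  # {'meta'}
# real values are filled in later, e.g. via set_peft_model_state_dict(..., low_cpu_mem_usage=True)
```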
@@ -73,8 +75,8 @@ class IA3Model(BaseTuner): prefix: str = "ia3_" - def __init__(self, model, config, adapter_name): - super().__init__(model, config, adapter_name) + def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False): + super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) @staticmethod def _create_new_module(ia3_config, adapter_name, target, **kwargs): @@ -217,10 +219,12 @@ def _replace_module(self, parent, child_name, new_module, child): new_module.state = child.state new_module.to(child.weight.device) + meta = torch.device("meta") # dispatch to correct device for name, module in new_module.named_modules(): if self.prefix in name: - module.to(child.weight.device) + if not any(p.device == meta for p in module.parameters()): + module.to(child.weight.device) def __getattr__(self, name: str): """Forward missing attributes to the wrapped module.""" diff --git a/src/peft/tuners/ln_tuning/model.py b/src/peft/tuners/ln_tuning/model.py index 3028d7abcc6..3e16a7fad49 100644 --- a/src/peft/tuners/ln_tuning/model.py +++ b/src/peft/tuners/ln_tuning/model.py @@ -37,6 +37,8 @@ class LNTuningModel(BaseTuner): model ([`torch.nn.Module`]): The model to be adapted. config ([`LNTuningConfig`]): The configuration of the Lora model. adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + This option has no effect on LN tuning but exists for consistency with other PEFT methods. Returns: 'torch.nn.Module': The adapted model with LayerNorm tuned on. @@ -63,9 +65,9 @@ class LNTuningModel(BaseTuner): prefix: str = "ln_tuning_" - def __init__(self, model, config, adapter_name) -> None: + def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None: # self.adapter_name = adapter_name - super().__init__(model, config, adapter_name) + super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) def __getattr__(self, name: str): """Forward missing attributes to the wrapped module.""" diff --git a/src/peft/tuners/loha/model.py b/src/peft/tuners/loha/model.py index 6f1aaac9d59..e1cbc50b344 100644 --- a/src/peft/tuners/loha/model.py +++ b/src/peft/tuners/loha/model.py @@ -34,6 +34,8 @@ class LoHaModel(LycorisTuner): model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached. config ([`LoHaConfig`]): The configuration of the LoHa model. adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. Returns: `torch.nn.Module`: The LoHa model. diff --git a/src/peft/tuners/lokr/model.py b/src/peft/tuners/lokr/model.py index eecad8dd13d..e8195a7a02b 100644 --- a/src/peft/tuners/lokr/model.py +++ b/src/peft/tuners/lokr/model.py @@ -35,6 +35,8 @@ class LoKrModel(LycorisTuner): model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached. config ([`LoKrConfig`]): The configuration of the LoKr model. adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. Returns: `torch.nn.Module`: The LoKr model. 
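The constructor change is identical across the tuner classes touched by this diff: accept `low_cpu_mem_usage` and forward it to `BaseTuner`, which wraps layer creation in `init_empty_weights`. A minimal sketch of the pattern for a hypothetical tuner follows; a real subclass must also implement the abstract hooks such as `_create_and_replace`, which are omitted here:

```python
from peft.tuners.tuners_utils import BaseTuner


class MyTunerModel(BaseTuner):
    prefix: str = "mytuner_"

    def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
        # forwarding the flag is all a subclass needs to do; BaseTuner.inject_adapter then
        # creates the adapter layers under accelerate's init_empty_weights context manager
        super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
```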
diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 107a2593c5d..749f78051cc 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -72,6 +72,8 @@ class LoraModel(BaseTuner): model ([`torch.nn.Module`]): The model to be adapted. config ([`LoraConfig`]): The configuration of the Lora model. adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. Returns: `torch.nn.Module`: The Lora model. @@ -135,8 +137,8 @@ class LoraModel(BaseTuner): prefix: str = "lora_" - def __init__(self, model, config, adapter_name) -> None: - super().__init__(model, config, adapter_name) + def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None: + super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) def _check_new_adapter_config(self, config: LoraConfig) -> None: """ @@ -251,6 +253,7 @@ def _replace_module(self, parent, child_name, new_module, child): new_module.state = child.state new_module.to(child.weight.device) + meta = torch.device("meta") # dispatch to correct device for name, module in new_module.named_modules(): if (self.prefix in name) or ("ranknum" in name): @@ -263,7 +266,8 @@ def _replace_module(self, parent, child_name, new_module, child): if hasattr(child, "weight") else next(child.parameters()) ) - module.to(weight.device) + if not any(p.device == meta for p in module.parameters()): + module.to(weight.device) def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: for n, p in model.named_parameters(): diff --git a/src/peft/tuners/lycoris_utils.py b/src/peft/tuners/lycoris_utils.py index 01e7cc6b9a4..ed7c1a50f92 100644 --- a/src/peft/tuners/lycoris_utils.py +++ b/src/peft/tuners/lycoris_utils.py @@ -187,13 +187,21 @@ def update_layer(self, adapter_name: str, r: int, alpha: float, **kwargs): ... class LycorisTuner(BaseTuner): r""" A base tuner for LyCORIS like adapters + + Args: + model ([`torch.nn.Module`]): The model to be adapted. + config ([`LoraConfig`]): The configuration of the Lora model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. 
+ """ prefix: str layers_mapping: dict[type[torch.nn.Module], type[LycorisLayer]] - def __init__(self, model, config, adapter_name): - super().__init__(model, config, adapter_name) + def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False): + super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) def __getattr__(self, name: str): """Forward missing attributes to the wrapped module.""" @@ -290,10 +298,12 @@ def _replace_module(self, parent, child_name, new_module, child): new_module.state = child.state new_module.to(child.weight.device) + meta = torch.device("meta") # dispatch to correct device for name, module in new_module.named_modules(): if self.prefix in name: - module.to(child.weight.device) + if not any(p.device == meta for p in module.parameters()): + module.to(child.weight.device) def _set_adapter_layers(self, enabled=True): for module in self.model.modules(): diff --git a/src/peft/tuners/oft/model.py b/src/peft/tuners/oft/model.py index fd96325c6f0..d2530295b65 100644 --- a/src/peft/tuners/oft/model.py +++ b/src/peft/tuners/oft/model.py @@ -32,6 +32,8 @@ class OFTModel(LycorisTuner): model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached. config ([`OFTConfig`]): The configuration of the OFT model. adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. Returns: `torch.nn.Module`: The OFT model. diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index f1852362873..2f195b2cbef 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -20,10 +20,11 @@ import textwrap import warnings from abc import ABC, abstractmethod -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from typing import Any, Optional, Union import torch +from accelerate import init_empty_weights from accelerate.hooks import AlignDevicesHook from accelerate.utils import named_module_tensors, offload_state_dict from torch import nn @@ -155,6 +156,7 @@ def __init__( model, peft_config: Union[PeftConfig, dict[str, PeftConfig]], adapter_name: str, + low_cpu_mem_usage: bool = False, ) -> None: super().__init__() @@ -179,7 +181,7 @@ def __init__( self.active_adapter: str | list[str] = adapter_name self._pre_injection_hook(self.model, self.peft_config[adapter_name], adapter_name) if peft_config != PeftType.XLORA or peft_config[adapter_name] != PeftType.XLORA: - self.inject_adapter(self.model, adapter_name) + self.inject_adapter(self.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) # Copy the peft_config in the injected model. self.model.peft_config = self.peft_config @@ -399,7 +401,9 @@ def _check_merge_allowed(self): + example_code ) - def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True) -> None: + def inject_adapter( + self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True, low_cpu_mem_usage: bool = False + ) -> None: r""" Creates adapter layers and replaces the target modules with the adapter layers. This method is called under the hood by `peft.mapping.get_peft_model` if a non-prompt tuning adapter class is passed. @@ -413,6 +417,9 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d The adapter name. 
autocast_adapter_dtype (`bool`, *optional*): Whether to autocast the adapter dtype. Defaults to `True`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. + """ peft_config = self.peft_config[adapter_name] # Note: If possible, all checks should be performed *at the start of this method*. @@ -482,7 +489,9 @@ def inject_adapter(self, model: nn.Module, adapter_name: str, autocast_adapter_d self.targeted_module_names.append(key) is_target_modules_in_base_model = True parent, target, target_name = _get_submodules(model, key) - self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key) + ctx = init_empty_weights if low_cpu_mem_usage else nullcontext + with ctx(): + self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key) tied_target_modules = self._get_tied_target_modules(model=model) if tied_target_modules: @@ -792,6 +801,8 @@ def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optio # no break encountered: could not determine the device return + meta = torch.device("meta") + # loop through all potential adapter layers and move them to the device of the base layer; be careful to only # move this specific adapter to the device, as the other adapters could be on different devices # see #1639 @@ -801,6 +812,9 @@ def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optio continue if adapter_name not in adapter_layer: continue + if any(p.device == meta for p in adapter_layer.parameters()): + continue + if weight.dtype.is_floating_point or weight.dtype.is_complex: adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device, dtype=dtype) else: diff --git a/src/peft/tuners/vblora/model.py b/src/peft/tuners/vblora/model.py index 5376b8563f5..9460549d392 100644 --- a/src/peft/tuners/vblora/model.py +++ b/src/peft/tuners/vblora/model.py @@ -40,6 +40,8 @@ class VBLoRAModel(BaseTuner): model ([`~transformers.PreTrainedModel`]): The model to be adapted. config ([`VBLoRAConfig`]): The configuration of the VBLoRA model. adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. Returns: `torch.nn.Module`: The VBLoRA model. 
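To make the mechanism behind these changes concrete, here is a small sketch of the behavior of `accelerate.init_empty_weights`, which `BaseTuner.inject_adapter` uses when `low_cpu_mem_usage=True`: modules created under the context manager have their parameters on the meta device, so they must not be moved with `.to()` until real values are assigned. The layer shape below is an arbitrary example.

```python
from torch import nn
from accelerate import init_empty_weights

with init_empty_weights():
    lora_A = nn.Linear(768, 8, bias=False)

print(lora_A.weight.device)   # meta
print(lora_A.weight.is_meta)  # True
# meta tensors hold no data, which is why _replace_module and
# _move_adapter_to_device_of_base_layer skip the usual `.to(device)` call for them;
# the real weights are attached later via load_state_dict(..., assign=True)
```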
@@ -69,8 +71,8 @@ class VBLoRAModel(BaseTuner): prefix: str = "vblora_" - def __init__(self, model, config, adapter_name) -> None: - super().__init__(model, config, adapter_name) + def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None: + super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) def _init_vblora_vector_bank(self, config: VBLoRAConfig, adapter_name: str) -> None: vblora_vector_bank = torch.zeros(config.num_vectors, config.vector_length) @@ -166,10 +168,12 @@ def _replace_module(parent, child_name, new_module, child): new_module.state = child.state new_module.to(child.weight.device) + meta = torch.device("meta") # dispatch to correct device for name, module in new_module.named_modules(): if "vblora_" in name: - module.to(child.weight.device) + if not any(p.device == meta for p in module.parameters()): + module.to(child.weight.device) def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: for n, p in model.named_parameters(): diff --git a/src/peft/tuners/vera/model.py b/src/peft/tuners/vera/model.py index 8ef35067384..d268d2ae0a1 100644 --- a/src/peft/tuners/vera/model.py +++ b/src/peft/tuners/vera/model.py @@ -76,6 +76,8 @@ class VeraModel(BaseTuner): model ([`~transformers.PreTrainedModel`]): The model to be adapted. config ([`VeraConfig`]): The configuration of the Vera model. adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. Returns: `torch.nn.Module`: The Vera model. @@ -98,8 +100,8 @@ class VeraModel(BaseTuner): prefix: str = "vera_lambda" - def __init__(self, model, config, adapter_name) -> None: - super().__init__(model, config, adapter_name) + def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None: + super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) def _find_dim(self, config) -> tuple[int, int]: """ @@ -255,10 +257,12 @@ def _replace_module(parent, child_name, new_module, child): new_module.state = child.state new_module.to(child.weight.device) + meta = torch.device("meta") # dispatch to correct device for name, module in new_module.named_modules(): if "vera_" in name: - module.to(child.weight.device) + if not any(p.device == meta for p in module.parameters()): + module.to(child.weight.device) def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: for n, p in model.named_parameters(): diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py index 0f575df9b59..5b40b4314c6 100644 --- a/src/peft/utils/save_and_load.py +++ b/src/peft/utils/save_and_load.py @@ -306,7 +306,11 @@ def _find_mismatched_keys( def set_peft_model_state_dict( - model, peft_model_state_dict, adapter_name="default", ignore_mismatched_sizes: bool = False + model, + peft_model_state_dict, + adapter_name="default", + ignore_mismatched_sizes: bool = False, + low_cpu_mem_usage: bool = False, ): """ Set the state dict of the Peft model. @@ -320,6 +324,10 @@ def set_peft_model_state_dict( The name of the adapter whose state dict should be set. ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether to ignore mismatched in the state dict. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + This argument must be `True` if the `model` was loaded with adapter weights on the meta device, e.g. 
after + calling `inject_adapter_in_model` with `low_cpu_mem_usage=True`. Otherwise, leave it as `False`. + """ config = model.peft_config[adapter_name] state_dict = {} @@ -433,7 +441,11 @@ def renamed_dora_weights(k): peft_model_state_dict, mismatched_keys = _find_mismatched_keys( model, peft_model_state_dict, ignore_mismatched_sizes=ignore_mismatched_sizes ) - load_result = model.load_state_dict(peft_model_state_dict, strict=False) + if low_cpu_mem_usage: + load_result = model.load_state_dict(peft_model_state_dict, strict=False, assign=True) + else: + load_result = model.load_state_dict(peft_model_state_dict, strict=False) + if config.is_prompt_learning: model.prompt_encoder[adapter_name].embedding.load_state_dict( {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 2fe5f0c99f0..3f3a97304f4 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -853,6 +853,10 @@ def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs): def test_save_pretrained_pickle(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained(model_id, config_cls, config_kwargs, safe_serialization=False) + @parameterized.expand(TEST_CASES) + def test_load_model_low_cpu_mem_usage(self, test_name, model_id, config_cls, config_kwargs): + self._test_load_model_low_cpu_mem_usage(model_id, config_cls, config_kwargs) + @parameterized.expand(TEST_CASES) def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs): self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs) diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 340682322ef..cc54003350c 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -20,6 +20,7 @@ import pytest import torch from huggingface_hub.utils import reset_sessions +from safetensors.torch import load_file from scipy import stats from torch import nn from transformers import AutoModelForCausalLM @@ -39,6 +40,8 @@ VBLoRAConfig, VeraConfig, get_peft_model, + inject_adapter_in_model, + set_peft_model_state_dict, ) from peft.utils import infer_device @@ -1336,3 +1339,176 @@ def test_warning_name_custom_model_with_custom_name(self, custom_module, recwarn get_peft_model(custom_module, LoraConfig(target_modules=["lin"], base_model_name_or_path=custom_name)) msg = f"was renamed from '{custom_name}' to 'foobar'" assert any(msg in str(warning.message) for warning in recwarn.list) + + +class TestLowCpuMemUsage: + """Test for the low CPU memory usage option for loading PEFT models. + + Note that we have `test_load_model_low_cpu_mem_usage` in the custom model and stable diffusion tests. Those are + broad tests (i.e. testing all the supported PEFT methods) but not very deep (only testing if loading works and the + device is correctly set). The test class here goes deeper but only tests LoRA, as checking all PEFT methods would + be too much. 
+ + """ + + # test on CPU and optionally on accelerator device + devices = ["cpu"] + _device = infer_device() + if _device != "cpu": + devices.append(_device) + + model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" + + def get_model(self): + return AutoModelForCausalLM.from_pretrained(self.model_id) + + @pytest.fixture(scope="class") + def lora_config(self): + return LoraConfig(init_lora_weights=False, target_modules="all-linear") + + @pytest.fixture(scope="class") + def lora_path(self, tmp_path_factory, lora_config): + torch.manual_seed(0) + tmp_path = tmp_path_factory.mktemp("lora") + model = self.get_model() + model = get_peft_model(model, lora_config) + model.save_pretrained(tmp_path) + return tmp_path + + @pytest.fixture(scope="class") + def inputs(self): + return {"input_ids": torch.randint(0, 100, (1, 10)), "attention_mask": torch.ones(1, 10)} + + @pytest.mark.parametrize("device", devices) + def test_from_pretrained_low_cpu_mem_usage_works(self, device, inputs, lora_path): + model = self.get_model().to(device) + inputs = {k: v.to(device) for k, v in inputs.items()} + model = PeftModel.from_pretrained(model, lora_path, torch_device=device).eval() + device_set_not_low_cpu_mem = {p.device.type for p in model.parameters()} + logits_not_low_cpu_mem = model(**inputs).logits + + del model + + model = self.get_model().to(device) + model = PeftModel.from_pretrained(model, lora_path, low_cpu_mem_usage=True, torch_device=device).eval() + device_set_low_cpu_mem = {p.device.type for p in model.parameters()} + logits_low_cpu_mem = model(**inputs).logits + + assert device_set_low_cpu_mem == device_set_not_low_cpu_mem + assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem) + + @pytest.mark.parametrize("device", devices) + def test_load_adapter_low_cpu_mem_usage_works(self, device, inputs, lora_path, lora_config): + model = self.get_model().to(device) + inputs = {k: v.to(device) for k, v in inputs.items()} + + torch.manual_seed(0) + model = get_peft_model(model, lora_config) + model.load_adapter(lora_path, adapter_name="other", torch_device=device) + model.set_adapter("other") + model.eval() + device_set_not_low_cpu_mem = {p.device.type for p in model.parameters()} + logits_not_low_cpu_mem = model(**inputs).logits + + del model + + model = self.get_model().to(device) + torch.manual_seed(0) + model = get_peft_model(model, lora_config) + model.load_adapter(lora_path, adapter_name="other", low_cpu_mem_usage=True, torch_device=device) + model.set_adapter("other") + model.eval() + device_set_low_cpu_mem = {p.device.type for p in model.parameters()} + logits_low_cpu_mem = model(**inputs).logits + + assert device_set_low_cpu_mem == device_set_not_low_cpu_mem + assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem) + + @pytest.mark.parametrize("device", devices) + def test_inject_adapter_low_cpu_mem_usage_works(self, device, inputs, lora_path, lora_config): + # external libs like transformers and diffusers use inject_adapter_in_model, let's check that this also works + model = self.get_model().to(device) + inputs = {k: v.to(device) for k, v in inputs.items()} + + torch.manual_seed(0) + model = get_peft_model(model, lora_config) + model.load_adapter(lora_path, adapter_name="other", torch_device=device) + model.set_adapter("other") + model.eval() + device_set_not_low_cpu_mem = {p.device.type for p in model.parameters()} + logits_not_low_cpu_mem = model(**inputs).logits + + del model + + torch.manual_seed(0) + model = self.get_model().to(device) + 
inject_adapter_in_model(lora_config, model, low_cpu_mem_usage=True) + device_set_before_loading = {p.device.type for p in model.parameters()} + # at this stage, lora weights are still on meta device + assert device_set_before_loading == {"meta", device} + + state_dict = load_file(lora_path / "adapter_model.safetensors") + remapped_dict = {} + prefix = "base_model.model." + for key, val in state_dict.items(): + new_key = key[len(prefix) :] + remapped_dict[new_key] = val.to(device) + errors = set_peft_model_state_dict(model, remapped_dict, low_cpu_mem_usage=True) + # sanity check: no unexpected keys + assert not errors.unexpected_keys + + model.eval() + device_set_low_cpu_mem = {p.device.type for p in model.parameters()} + logits_low_cpu_mem = model(**inputs).logits + + assert device_set_low_cpu_mem == device_set_not_low_cpu_mem + assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem) + + ############################ + # tests for PeftMixedModel # + ############################ + + @pytest.mark.parametrize("device", devices) + def test_mixed_model_from_pretrained_low_cpu_mem_usage_works(self, device, inputs, lora_path): + model = self.get_model().to(device) + inputs = {k: v.to(device) for k, v in inputs.items()} + model = PeftMixedModel.from_pretrained(model, lora_path, torch_device=device).eval() + device_set_not_low_cpu_mem = {p.device.type for p in model.parameters()} + logits_not_low_cpu_mem = model(**inputs).logits + + del model + + model = self.get_model().to(device) + model = PeftMixedModel.from_pretrained(model, lora_path, low_cpu_mem_usage=True, torch_device=device).eval() + device_set_low_cpu_mem = {p.device.type for p in model.parameters()} + logits_low_cpu_mem = model(**inputs).logits + + assert device_set_low_cpu_mem == device_set_not_low_cpu_mem + assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem) + + @pytest.mark.parametrize("device", devices) + def test_mixed_model_load_adapter_low_cpu_mem_usage_works(self, device, inputs, lora_path, lora_config): + model = self.get_model().to(device) + inputs = {k: v.to(device) for k, v in inputs.items()} + + torch.manual_seed(0) + model = PeftModel.from_pretrained(model, lora_path) + model.load_adapter(lora_path, adapter_name="other", torch_device=device) + model.set_adapter("other") + model.eval() + device_set_not_low_cpu_mem = {p.device.type for p in model.parameters()} + logits_not_low_cpu_mem = model(**inputs).logits + + del model + + model = self.get_model().to(device) + torch.manual_seed(0) + model = PeftModel.from_pretrained(model, lora_path) + model.load_adapter(lora_path, adapter_name="other", low_cpu_mem_usage=True, torch_device=device) + model.set_adapter("other") + model.eval() + device_set_low_cpu_mem = {p.device.type for p in model.parameters()} + logits_low_cpu_mem = model(**inputs).logits + + assert device_set_low_cpu_mem == device_set_not_low_cpu_mem + assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem) diff --git a/tests/test_stablediffusion.py b/tests/test_stablediffusion.py index f0217d670ac..99dbced4fde 100644 --- a/tests/test_stablediffusion.py +++ b/tests/test_stablediffusion.py @@ -19,7 +19,18 @@ from diffusers import StableDiffusionPipeline from parameterized import parameterized -from peft import BOFTConfig, HRAConfig, LoHaConfig, LoraConfig, OFTConfig, get_peft_model +from peft import ( + BOFTConfig, + HRAConfig, + LoHaConfig, + LoraConfig, + OFTConfig, + get_peft_model, + get_peft_model_state_dict, + inject_adapter_in_model, + set_peft_model_state_dict, +) +from 
peft.tuners.tuners_utils import BaseTunerLayer from .testing_common import ClassInstantier, PeftCommonTester from .testing_utils import temp_seed @@ -260,3 +271,47 @@ def test_add_weighted_adapter_base_unchanged(self, test_name, model_id, config_c ) def test_disable_adapter(self, test_name, model_id, config_cls, config_kwargs): self._test_disable_adapter(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftStableDiffusionTestConfigManager.get_grid_parameters( + {"model_ids": PEFT_DIFFUSERS_SD_MODELS_TO_TEST}, + ) + ) + def test_load_model_low_cpu_mem_usage(self, test_name, model_id, config_cls, config_kwargs): + # Instantiate model & adapters + pipe = self.instantiate_sd_peft(model_id, config_cls, config_kwargs) + + te_state_dict = get_peft_model_state_dict(pipe.text_encoder) + unet_state_dict = get_peft_model_state_dict(pipe.unet) + + del pipe + pipe = self.instantiate_sd_peft(model_id, config_cls, config_kwargs) + + config_kwargs = config_kwargs.copy() + text_encoder_kwargs = config_kwargs.pop("text_encoder") + unet_kwargs = config_kwargs.pop("unet") + # the remaining config kwargs should be applied to both configs + for key, val in config_kwargs.items(): + text_encoder_kwargs[key] = val + unet_kwargs[key] = val + + config_text_encoder = config_cls(**text_encoder_kwargs) + config_unet = config_cls(**unet_kwargs) + + # check text encoder + inject_adapter_in_model(config_text_encoder, pipe.text_encoder, low_cpu_mem_usage=True) + # sanity check that the adapter was applied: + assert any(isinstance(module, BaseTunerLayer) for module in pipe.text_encoder.modules()) + + assert "meta" in {p.device.type for p in pipe.text_encoder.parameters()} + set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True) + assert "meta" not in {p.device.type for p in pipe.text_encoder.parameters()} + + # check unet + inject_adapter_in_model(config_unet, pipe.unet, low_cpu_mem_usage=True) + # sanity check that the adapter was applied: + assert any(isinstance(module, BaseTunerLayer) for module in pipe.unet.modules()) + + assert "meta" in {p.device.type for p in pipe.unet.parameters()} + set_peft_model_state_dict(pipe.unet, unet_state_dict, low_cpu_mem_usage=True) + assert "meta" not in {p.device.type for p in pipe.unet.parameters()} diff --git a/tests/testing_common.py b/tests/testing_common.py index f564f4e5bd2..5f3ec890175 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -48,6 +48,7 @@ VeraConfig, get_peft_model, get_peft_model_state_dict, + inject_adapter_in_model, prepare_model_for_kbit_training, ) from peft.tuners.lora import LoraLayer @@ -304,6 +305,46 @@ def make_inputs_require_grad(module, input, output): assert dummy_output.requires_grad + def _test_load_model_low_cpu_mem_usage(self, model_id, config_cls, config_kwargs): + # Ensure that low_cpu_mem_usage=True works for from_pretrained and load_adapter and that the resulting model's + # parameters are on the correct device. 
+ model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + model = get_peft_model(model, config) + + # note: not using the context manager here because it fails on Windows CI for some reason + tmp_dirname = tempfile.mkdtemp() + try: + model.save_pretrained(tmp_dirname) + + model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + model = PeftModel.from_pretrained( + model, tmp_dirname, torch_device=self.torch_device, low_cpu_mem_usage=True + ) + assert {p.device.type for p in model.parameters()} == {self.torch_device} + + model.load_adapter(tmp_dirname, adapter_name="other", low_cpu_mem_usage=True) + assert {p.device.type for p in model.parameters()} == {self.torch_device} + finally: + try: + shutil.rmtree(tmp_dirname) + except PermissionError: + # windows error + pass + + # also test injecting directly + del model + model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + inject_adapter_in_model(config, model, low_cpu_mem_usage=True) # check that there is no error + + if not isinstance(config, LNTuningConfig): + # LN tuning does not add adapter layers that could be on meta device, it only changes the requires_grad. + # Therefore, there is no meta device for LN tuning. + assert "meta" in {p.device.type for p in model.parameters()} + def _test_save_pretrained(self, model_id, config_cls, config_kwargs, safe_serialization=True): # ensure that the weights are randomly initialized if issubclass(config_cls, LoraConfig):