diff --git a/src/peft/mapping.py b/src/peft/mapping.py
index fec9e28df8..f9004bad09 100644
--- a/src/peft/mapping.py
+++ b/src/peft/mapping.py
@@ -35,7 +35,7 @@
     PromptEncoderConfig,
     PromptTuningConfig,
 )
-from .utils import PromptLearningConfig
+from .utils import PromptLearningConfig, _prepare_prompt_learning_config
 
 
 if TYPE_CHECKING:
@@ -75,48 +75,6 @@ def get_peft_config(config_dict: Dict[str, Any]):
     return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)
 
 
-def _prepare_prompt_learning_config(peft_config: PeftConfig, model_config: Dict[str, Any]):
-    if peft_config.num_layers is None:
-        if "num_hidden_layers" in model_config:
-            num_layers = model_config["num_hidden_layers"]
-        elif "num_layers" in model_config:
-            num_layers = model_config["num_layers"]
-        elif "n_layer" in model_config:
-            num_layers = model_config["n_layer"]
-        else:
-            raise ValueError("Please specify `num_layers` in `peft_config`")
-        peft_config.num_layers = num_layers
-
-    if peft_config.token_dim is None:
-        if "hidden_size" in model_config:
-            token_dim = model_config["hidden_size"]
-        elif "n_embd" in model_config:
-            token_dim = model_config["n_embd"]
-        elif "d_model" in model_config:
-            token_dim = model_config["d_model"]
-        else:
-            raise ValueError("Please specify `token_dim` in `peft_config`")
-        peft_config.token_dim = token_dim
-
-    if peft_config.num_attention_heads is None:
-        if "num_attention_heads" in model_config:
-            num_attention_heads = model_config["num_attention_heads"]
-        elif "n_head" in model_config:
-            num_attention_heads = model_config["n_head"]
-        elif "num_heads" in model_config:
-            num_attention_heads = model_config["num_heads"]
-        elif "encoder_attention_heads" in model_config:
-            num_attention_heads = model_config["encoder_attention_heads"]
-        else:
-            raise ValueError("Please specify `num_attention_heads` in `peft_config`")
-        peft_config.num_attention_heads = num_attention_heads
-
-    if getattr(peft_config, "encoder_hidden_size", None) is None:
-        setattr(peft_config, "encoder_hidden_size", peft_config.token_dim)
-
-    return peft_config
-
-
 def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> PeftModel:
     """
     Returns a Peft model object from a model and a config.
diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 2eef2c141b..ce65d2214e 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -20,7 +20,7 @@
 import warnings
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from accelerate import dispatch_model, infer_auto_device_map
@@ -53,6 +53,7 @@
     PeftType,
     PromptLearningConfig,
     TaskType,
+    _prepare_prompt_learning_config,
     _set_adapter,
     _set_trainable,
     add_library_to_model_card,
@@ -118,7 +119,13 @@ def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name
         if getattr(model, "is_gradient_checkpointing", True):
             model = self._prepare_model_for_gradient_checkpointing(model)
 
-    def save_pretrained(self, save_directory: str, safe_serialization: bool = False, **kwargs: Any):
+    def save_pretrained(
+        self,
+        save_directory: str,
+        safe_serialization: bool = False,
+        selected_adapters: Optional[List[str]] = None,
+        **kwargs: Any,
+    ):
         r"""
         This function saves the adapter model and the adapter configuration files to a directory, so that it can be
         reloaded using the [`LoraModel.from_pretrained`] class method, and also used by the [`LoraModel.push_to_hub`]
@@ -133,10 +140,24 @@ def save_pretrained(self, save_directory: str, safe_serialization: bool = False,
         """
         if os.path.isfile(save_directory):
             raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        if selected_adapters is None:
+            selected_adapters = list(self.peft_config.keys())
+        else:
+            if any(
+                selected_adapter_name not in list(self.peft_config.keys())
+                for selected_adapter_name in selected_adapters
+            ):
+                raise ValueError(
+                    f"You passed an invalid `selected_adapters` arguments, current supported adapter names are"
+                    f" {list(self.peft_config.keys())} - got {selected_adapters}."
+                )
+
         os.makedirs(save_directory, exist_ok=True)
         self.create_or_update_model_card(save_directory)
 
-        for adapter_name, peft_config in self.peft_config.items():
+        for adapter_name in selected_adapters:
+            peft_config = self.peft_config[adapter_name]
             # save only the trainable weights
             output_state_dict = get_peft_model_state_dict(
                 self, state_dict=kwargs.get("state_dict", None), adapter_name=adapter_name
@@ -146,7 +167,9 @@ def save_pretrained(self, save_directory: str, safe_serialization: bool = False,
 
             if safe_serialization:
                 safe_save_file(
-                    output_state_dict, os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME), metadata={"format": "pt"}
+                    output_state_dict,
+                    os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME),
+                    metadata={"format": "pt"},
                 )
             else:
                 torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
@@ -234,8 +257,9 @@ def from_pretrained(
 
     def _setup_prompt_encoder(self, adapter_name: str):
         config = self.peft_config[adapter_name]
-        self.prompt_encoder = torch.nn.ModuleDict({})
-        self.prompt_tokens = {}
+        if not hasattr(self, "prompt_encoder"):
+            self.prompt_encoder = torch.nn.ModuleDict({})
+            self.prompt_tokens = {}
         transformer_backbone = None
         for name, module in self.base_model.named_children():
             for param in module.parameters():
@@ -412,6 +436,12 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig):
             )
         self.peft_config[adapter_name] = peft_config
         if isinstance(peft_config, PromptLearningConfig):
+            if hasattr(self.config, "to_dict"):
+                dict_config = self.config.to_dict()
+            else:
+                dict_config = self.config
+
+            peft_config = _prepare_prompt_learning_config(peft_config, dict_config)
             self._setup_prompt_encoder(adapter_name)
         else:
             self.base_model.add_adapter(adapter_name, peft_config)
diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py
index a45f6c9a76..ca39b60350 100644
--- a/src/peft/utils/__init__.py
+++ b/src/peft/utils/__init__.py
@@ -40,6 +40,7 @@
     _set_adapter,
     _freeze_adapter,
     ModulesToSaveWrapper,
+    _prepare_prompt_learning_config,
 )
 from .hub_utils import hub_file_exists
 from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict
diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index b5b1e70bce..dfe40caf73 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -178,6 +178,48 @@ def _set_adapter(model, adapter_name):
             module.active_adapter = adapter_name
 
 
+def _prepare_prompt_learning_config(peft_config, model_config):
+    if peft_config.num_layers is None:
+        if "num_hidden_layers" in model_config:
+            num_layers = model_config["num_hidden_layers"]
+        elif "num_layers" in model_config:
+            num_layers = model_config["num_layers"]
+        elif "n_layer" in model_config:
+            num_layers = model_config["n_layer"]
+        else:
+            raise ValueError("Please specify `num_layers` in `peft_config`")
+        peft_config.num_layers = num_layers
+
+    if peft_config.token_dim is None:
+        if "hidden_size" in model_config:
+            token_dim = model_config["hidden_size"]
+        elif "n_embd" in model_config:
+            token_dim = model_config["n_embd"]
+        elif "d_model" in model_config:
+            token_dim = model_config["d_model"]
+        else:
+            raise ValueError("Please specify `token_dim` in `peft_config`")
+        peft_config.token_dim = token_dim
+
+    if peft_config.num_attention_heads is None:
+        if "num_attention_heads" in model_config:
+            num_attention_heads = model_config["num_attention_heads"]
+        elif "n_head" in model_config:
+            num_attention_heads = model_config["n_head"]
+        elif "num_heads" in model_config:
+            num_attention_heads = model_config["num_heads"]
+        elif "encoder_attention_heads" in model_config:
+            num_attention_heads = model_config["encoder_attention_heads"]
+        else:
+            raise ValueError("Please specify `num_attention_heads` in `peft_config`")
+        peft_config.num_attention_heads = num_attention_heads
+
+    if getattr(peft_config, "encoder_hidden_size", None) is None:
+        setattr(peft_config, "encoder_hidden_size", peft_config.token_dim)
+
+    return peft_config
+
+
 def fsdp_auto_wrap_policy(model):
     import functools
     import os
diff --git a/tests/test_adaption_prompt.py b/tests/test_adaption_prompt.py
index b0488d66ed..1f666e51d2 100644
--- a/tests/test_adaption_prompt.py
+++ b/tests/test_adaption_prompt.py
@@ -160,6 +160,56 @@ def test_save_pretrained(self) -> None:
             # check if `config.json` is not present
             self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))
 
+    def test_save_pretrained_selected_adapters(self) -> None:
+        seed = 420
+        torch.manual_seed(seed)
+        model = LlamaForCausalLM(self._create_test_llama_config())
+        config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
+        model = get_peft_model(model, config)
+        model = model.to(self.torch_device)
+
+        new_adapter_config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
+        model.add_adapter("new_adapter", new_adapter_config)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model.save_pretrained(tmp_dirname)
+
+            torch.manual_seed(seed)
+            model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config())
+            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)
+
+            model_from_pretrained.load_adapter(tmp_dirname, "new_adapter")
+
+            # check if the state dicts are equal
+            state_dict = get_peft_model_state_dict(model)
+            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)
+
+            # check if same keys
+            self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys())
+
+            # Check that the number of saved parameters is 4 -- 2 layers of (tokens and gate).
+            self.assertEqual(len(list(state_dict.keys())), 4)
+
+            # check if tensors equal
+            for key in state_dict.keys():
+                self.assertTrue(
+                    torch.allclose(
+                        state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
+                    )
+                )
+
+            # check if `adapter_model.bin` is present
+            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))
+
+            # check if `adapter_config.json` is present
+            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))
+
+            # check if `pytorch_model.bin` is not present
+            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))
+
+            # check if `config.json` is not present
+            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))
+
     def test_generate(self) -> None:
         model = LlamaForCausalLM(self._create_test_llama_config())
         config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index cb795a1932..f44e355a9e 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -81,6 +81,10 @@ def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls
     def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
         self._test_save_pretrained(model_id, config_cls, config_kwargs)
 
+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+    def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs):
+        self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs)
+
     @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
     def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs):
         self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs)
diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py
index 59c3ae3d83..0b108d58d0 100644
--- a/tests/test_encoder_decoder_models.py
+++ b/tests/test_encoder_decoder_models.py
@@ -67,6 +67,10 @@ def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls
     def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
         self._test_save_pretrained(model_id, config_cls, config_kwargs)
 
+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+    def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs):
+        self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs)
+
     @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
     def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs):
         self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs)
diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py
index 6da6fcf972..2dee480568 100644
--- a/tests/test_feature_extraction_models.py
+++ b/tests/test_feature_extraction_models.py
@@ -84,6 +84,10 @@ def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls
     def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs):
         self._test_save_pretrained(model_id, config_cls, config_kwargs)
 
+    @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
+    def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs):
+        self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs)
+
     @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
     def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs):
         self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs)
diff --git a/tests/testing_common.py b/tests/testing_common.py
index a17766eecb..eeed94cf22 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -267,6 +267,66 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs):
             # check if `config.json` is not present
             self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))
 
+    def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs):
+        model = self.transformers_class.from_pretrained(model_id)
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        model = get_peft_model(model, config)
+        model = model.to(self.torch_device)
+
+        new_adapter_config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+
+        model.add_adapter("new_adapter", new_adapter_config)
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model.save_pretrained(tmp_dirname)
+
+            model_from_pretrained = self.transformers_class.from_pretrained(model_id)
+            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)
+
+            model_from_pretrained.load_adapter(tmp_dirname, "new_adapter")
+
+            # check if the state dicts are equal
+            state_dict = get_peft_model_state_dict(model)
+            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)
+
+            # check if same keys
+            self.assertEqual(state_dict.keys(), state_dict_from_pretrained.keys())
+
+            # check if tensors equal
+            for key in state_dict.keys():
+                self.assertTrue(
+                    torch.allclose(
+                        state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
+                    )
+                )
+
+            # check if `adapter_model.bin` is present
+            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin")))
+
+            # check if `adapter_config.json` is present
+            self.assertTrue(os.path.exists(os.path.join(tmp_dirname, "adapter_config.json")))
+
+            # check if `pytorch_model.bin` is not present
+            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin")))
+
+            # check if `config.json` is not present
+            self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model.save_pretrained(tmp_dirname, selected_adapters=["default"])
+
+            model_from_pretrained = self.transformers_class.from_pretrained(model_id)
+            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)
+
+            self.assertTrue("default" in model_from_pretrained.peft_config.keys())
+            self.assertTrue("new_adapter" not in model_from_pretrained.peft_config.keys())
+
     def _test_from_pretrained_config_construction(self, model_id, config_cls, config_kwargs):
         model = self.transformers_class.from_pretrained(model_id)
         config = config_cls(base_model_name_or_path=model_id, **config_kwargs)
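For reference, below is a minimal usage sketch of the new `selected_adapters` argument to `save_pretrained`, mirroring the flow exercised in `_test_save_pretrained_selected_adapters`. The base model id, LoRA settings, and directory/adapter names are illustrative assumptions, not part of the diff.

# Sketch only: model id, LoRA config, and directory names are assumed for illustration.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, PeftModel, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# The first adapter is registered under the name "default"; add a second one on top.
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))
model.add_adapter("new_adapter", LoraConfig(task_type="CAUSAL_LM"))

# Default behaviour is unchanged: with selected_adapters=None, every adapter is saved.
model.save_pretrained("all_adapters")

# New behaviour: save only a chosen subset; unknown adapter names raise a ValueError.
model.save_pretrained("default_only", selected_adapters=["default"])

# Reloading from the subset directory yields only the adapters that were saved there.
reloaded = PeftModel.from_pretrained(AutoModelForCausalLM.from_pretrained("facebook/opt-125m"), "default_only")
assert "new_adapter" not in reloaded.peft_config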