diff --git a/src/peft/tuners/adalora.py b/src/peft/tuners/adalora.py
index 6aa803f628..f31c33fc30 100644
--- a/src/peft/tuners/adalora.py
+++ b/src/peft/tuners/adalora.py
@@ -64,7 +64,7 @@ def __post_init__(self):
 class AdaLoraModel(LoraModel):
     """
     Creates AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. Paper:
-    https://openreview.net/pdf?id=lq62uWRJjiY
+    https://openreview.net/forum?id=lq62uWRJjiY

     Args:
         model ([`transformers.PreTrainedModel`]): The model to be adapted.
@@ -149,7 +149,7 @@ def _find_and_replace(self, adapter_name):
                 if not is_target_modules_in_base_model:
                     is_target_modules_in_base_model = True
                 parent, target, target_name = _get_submodules(self.model, key)
-                bias = target.bias is not None
+                bias = hasattr(target, "bias") and target.bias is not None
                 if isinstance(target, LoraLayer):
                     target.update_layer(
                         adapter_name,
@@ -183,6 +183,9 @@ def _find_and_replace(self, adapter_name):
                     new_module = SVDLinear4bit(
                         adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs
                     )
+                elif isinstance(target, (nn.ModuleList, nn.ModuleDict)):
+                    # it's not applicable to replace whole module lists or module dicts
+                    continue
                 else:
                     if isinstance(target, torch.nn.Linear):
                         in_features, out_features = target.in_features, target.out_features
@@ -352,11 +355,11 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
         self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
         # Actual trainable parameters
         # Right singular vectors
-        self.lora_A.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(r, self.in_features))}))
+        self.lora_A.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(r, self.in_features))}))
         # Singular values
-        self.lora_E.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(r, 1))}))
+        self.lora_E.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(r, 1))}))
         # Left singular vectors
-        self.lora_B.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(self.out_features, r))}))
+        self.lora_B.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(self.out_features, r))}))
         # The current rank
         self.ranknum.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(1), requires_grad=False)}))
         self.ranknum[adapter_name].data.fill_(float(r))
diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index 2d9644889d..0ff47bcfad 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -358,18 +358,20 @@ def transpose(weight, fan_in_fan_out):
     "t5": ["q", "k", "v", "o", "wi", "wo"],
     "mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
     "bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
-    # "gpt2": ["c_attn"],
-    # "bloom": ["query_key_value"],
+    "gpt2": ["c_attn"],
+    "bloom": ["query_key_value"],
     "opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
-    # "gptj": ["q_proj", "v_proj"],
-    # "gpt_neox": ["query_key_value"],
-    # "gpt_neo": ["q_proj", "v_proj"],
-    # "bert": ["query", "value"],
+    "gptj": ["q_proj", "v_proj"],
+    "gpt_neox": ["query_key_value"],
+    "gpt_neo": ["q_proj", "v_proj"],
+    "llama": ["q_proj", "v_proj"],
+    "bert": ["query", "value"],
     "roberta": ["query", "key", "value", "dense"],
     # "xlm-roberta": ["query", "value"],
     # "electra": ["query", "value"],
     "deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"],
-    # "deberta": ["in_proj"],
+    "gpt_bigcode": ["c_attn"],
+    "deberta": ["in_proj"],
     # "layoutlm": ["query", "value"],
 }
diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py
index e2a108720b..311fa363a7 100644
--- a/tests/test_decoder_models.py
+++ b/tests/test_decoder_models.py
@@ -18,6 +18,8 @@
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM

+from peft import AdaLoraConfig
+
 from .testing_common import PeftCommonTester, PeftTestConfigManager


@@ -45,6 +47,10 @@ def skip_non_pt_mqa(test_list):
     return [test for test in test_list if not ("prefix_tuning" in test[0] and "GPTBigCodeForCausalLM" in test[0])]


+def skip_adalora_and_gpt2(test_list):
+    return [test for test in test_list if not (("GPT2LMHeadModel" in test[1]) and (test[2] == AdaLoraConfig))]
+
+
 class PeftDecoderModelTester(unittest.TestCase, PeftCommonTester):
     r"""
     Test if the PeftModel behaves as expected. This includes:
@@ -143,8 +149,10 @@ def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
             {
                 "model_ids": PEFT_DECODER_MODELS_TO_TEST,
                 "lora_kwargs": {"init_lora_weights": [False]},
+                "adalora_kwargs": {"init_lora_weights": [False]},
                 "task_type": "CAUSAL_LM",
             },
+            filter_params_func=skip_adalora_and_gpt2,
         )
     )
     def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs):
@@ -172,6 +180,7 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c
                 "model_ids": PEFT_DECODER_MODELS_TO_TEST,
                 "lora_kwargs": {"init_lora_weights": [False]},
                 "ia3_kwargs": {"init_ia3_weights": [False]},
+                "adalora_kwargs": {"init_lora_weights": [False]},
                 "task_type": "CAUSAL_LM",
             },
             filter_params_func=skip_non_pt_mqa,
@@ -179,3 +188,13 @@
     )
     def test_disable_adapter(self, test_name, model_id, config_cls, config_kwargs):
         self._test_disable_adapter(model_id, config_cls, config_kwargs)
+
+    def test_generate_adalora_no_dropout(self):
+        # test for issue #730
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        config_kwargs = {
+            "target_modules": None,
+            "task_type": "CAUSAL_LM",
+            "lora_dropout": 0.0,
+        }
+        self._test_generate(model_id, AdaLoraConfig, config_kwargs)
diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py
index 25d095d791..1d0b36e3a4 100644
--- a/tests/test_encoder_decoder_models.py
+++ b/tests/test_encoder_decoder_models.py
@@ -130,6 +130,7 @@ def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
             {
                 "model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST,
                 "lora_kwargs": {"init_lora_weights": [False]},
+                "adalora_kwargs": {"init_lora_weights": [False]},
                 "ia3_kwargs": {"init_ia3_weights": [False]},
                 "task_type": "SEQ_2_SEQ_LM",
             },
@@ -159,6 +160,7 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c
             {
                 "model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST,
                 "lora_kwargs": {"init_lora_weights": [False]},
+                "adalora_kwargs": {"init_lora_weights": [False]},
                 "ia3_kwargs": {"init_ia3_weights": [False]},
                 "task_type": "SEQ_2_SEQ_LM",
             },
diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py
index c97be6e091..213063abca 100644
--- a/tests/test_feature_extraction_models.py
+++ b/tests/test_feature_extraction_models.py
@@ -142,6 +142,7 @@ def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
             {
                 "model_ids": PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST,
                 "lora_kwargs": {"init_lora_weights": [False]},
+                "adalora_kwargs": {"init_lora_weights": [False]},
                 "task_type": "FEATURE_EXTRACTION",
             },
         )
diff --git a/tests/testing_common.py b/tests/testing_common.py
index 2e27d10601..53c9bb56ad 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -22,6 +22,7 @@
 from diffusers import StableDiffusionPipeline

 from peft import (
+    AdaLoraConfig,
     IA3Config,
     LoraConfig,
     PeftModel,
@@ -45,10 +46,12 @@
     PromptTuningConfig,
 )

 CONFIG_TESTING_KWARGS = (
+    # IA³
     {
         "target_modules": None,
         "feedforward_modules": None,
     },
+    # LoRA
     {
         "r": 8,
         "lora_alpha": 32,
@@ -56,16 +59,23 @@
         "lora_dropout": 0.05,
         "bias": "none",
     },
+    # prefix tuning
     {
         "num_virtual_tokens": 10,
     },
+    # prompt encoder
     {
         "num_virtual_tokens": 10,
         "encoder_hidden_size": 32,
     },
+    # prompt tuning
     {
         "num_virtual_tokens": 10,
     },
+    # AdaLoRA
+    {
+        "target_modules": None,
+    },
 )

 CLASSES_MAPPING = {
@@ -74,6 +84,7 @@
     "prefix_tuning": (PrefixTuningConfig, CONFIG_TESTING_KWARGS[2]),
     "prompt_encoder": (PromptEncoderConfig, CONFIG_TESTING_KWARGS[3]),
     "prompt_tuning": (PromptTuningConfig, CONFIG_TESTING_KWARGS[4]),
+    "adalora": (AdaLoraConfig, CONFIG_TESTING_KWARGS[5]),
 }
@@ -269,6 +280,10 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs):
         self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))

     def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs):
+        if issubclass(config_cls, AdaLoraConfig):
+            # AdaLora does not support adding more than 1 adapter
+            return
+
         model = self.transformers_class.from_pretrained(model_id)
         config = config_cls(
             base_model_name_or_path=model_id,
@@ -640,6 +655,10 @@ def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwar
             self.assertIsNotNone(param.grad)

     def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
+        if issubclass(config_cls, AdaLoraConfig):
+            # AdaLora does not support adding more than 1 adapter
+            return
+
         model = self.transformers_class.from_pretrained(model_id)
         config = config_cls(
             base_model_name_or_path=model_id,
@@ -682,7 +701,7 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
         model = get_peft_model(model, config)
         model = model.to(self.torch_device)

-        if config.peft_type not in ("LORA"):
+        if config.peft_type not in ("LORA", "ADALORA"):
             with self.assertRaises(AttributeError):
                 model = model.unload()
         else:
@@ -700,6 +719,10 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
         self.assertTrue(torch.allclose(logits_transformers, logits_unload, atol=1e-4, rtol=1e-4))

     def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs):
+        if issubclass(config_cls, AdaLoraConfig):
+            # AdaLora does not support adding more than 1 adapter
+            return
+
         adapter_list = ["adapter1", "adapter_2", "adapter_3"]
         weight_list = [0.5, 1.5, 1.5]
         model = self.transformers_class.from_pretrained(model_id)
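
For context, here is a minimal usage sketch (not part of the patch) that exercises the AdaLoRA code paths touched above. It mirrors the new test_generate_adalora_no_dropout regression test for issue #730: the checkpoint name and the config kwargs are taken from that test, while the prompt and generation arguments are illustrative only.

# Minimal sketch, assuming transformers and a peft build that contains these changes.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import AdaLoraConfig, get_peft_model

# Same tiny checkpoint used by test_generate_adalora_no_dropout.
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModelForCausalLM.from_pretrained(model_id)

# target_modules=None falls back to the per-architecture AdaLoRA mapping extended
# in other.py above; lora_dropout=0.0 covers the no-dropout path from issue #730.
config = AdaLoraConfig(task_type="CAUSAL_LM", target_modules=None, lora_dropout=0.0)
peft_model = get_peft_model(base_model, config)

inputs = tokenizer("Hello, world", return_tensors="pt")
with torch.no_grad():
    output_ids = peft_model.generate(**inputs, max_new_tokens=5)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

# The updated _test_unload_adapter now expects unload() to succeed for "ADALORA" too,
# returning the plain base model instead of raising AttributeError.
restored = peft_model.unload()

Note that the updated tests pass init_lora_weights=False via adalora_kwargs, which is exactly where the zeros-to-randn change in update_layer becomes visible: with lora_A, lora_E, and lora_B all zero, the adapter's contribution is identically zero and cannot be distinguished from the base model in the unload and disable comparisons.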