Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for AdaLoRA, fix a few bugs #734

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/peft/tuners/adalora.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __post_init__(self):
class AdaLoraModel(LoraModel):
"""
Creates AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. Paper:
https://openreview.net/pdf?id=lq62uWRJjiY
https://openreview.net/forum?id=lq62uWRJjiY

Args:
model ([`transformers.PreTrainedModel`]): The model to be adapted.
Expand Down Expand Up @@ -149,7 +149,7 @@ def _find_and_replace(self, adapter_name):
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
parent, target, target_name = _get_submodules(self.model, key)
bias = target.bias is not None
bias = hasattr(target, "bias") and target.bias is not None
if isinstance(target, LoraLayer):
target.update_layer(
adapter_name,
Expand Down Expand Up @@ -183,6 +183,9 @@ def _find_and_replace(self, adapter_name):
new_module = SVDLinear4bit(
adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs
)
elif isinstance(target, (nn.ModuleList, nn.ModuleDict)):
# it's not applicable to replace whole module lists or module dicts
continue
else:
if isinstance(target, torch.nn.Linear):
in_features, out_features = target.in_features, target.out_features
Expand Down Expand Up @@ -352,11 +355,11 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
# Actual trainable parameters
# Right singular vectors
self.lora_A.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(r, self.in_features))}))
self.lora_A.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(r, self.in_features))}))
# Singular values
self.lora_E.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(r, 1))}))
self.lora_E.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(r, 1))}))
# Left singular vectors
self.lora_B.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(self.out_features, r))}))
self.lora_B.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(self.out_features, r))}))
# The current rank
self.ranknum.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(1), requires_grad=False)}))
self.ranknum[adapter_name].data.fill_(float(r))
Expand Down
16 changes: 9 additions & 7 deletions src/peft/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,18 +358,20 @@ def transpose(weight, fan_in_fan_out):
"t5": ["q", "k", "v", "o", "wi", "wo"],
"mt5": ["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
"bart": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
# "gpt2": ["c_attn"],
# "bloom": ["query_key_value"],
"gpt2": ["c_attn"],
"bloom": ["query_key_value"],
"opt": ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
# "gptj": ["q_proj", "v_proj"],
# "gpt_neox": ["query_key_value"],
# "gpt_neo": ["q_proj", "v_proj"],
# "bert": ["query", "value"],
"gptj": ["q_proj", "v_proj"],
"gpt_neox": ["query_key_value"],
"gpt_neo": ["q_proj", "v_proj"],
"llama": ["q_proj", "v_proj"],
"bert": ["query", "value"],
"roberta": ["query", "key", "value", "dense"],
# "xlm-roberta": ["query", "value"],
# "electra": ["query", "value"],
"deberta-v2": ["query_proj", "key_proj", "value_proj", "dense"],
# "deberta": ["in_proj"],
"gpt_bigcode": ["c_attn"],
"deberta": ["in_proj"],
# "layoutlm": ["query", "value"],
}

Expand Down
14 changes: 14 additions & 0 deletions tests/test_decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from parameterized import parameterized
from transformers import AutoModelForCausalLM

from peft import AdaLoraConfig

from .testing_common import PeftCommonTester, PeftTestConfigManager


Expand Down Expand Up @@ -143,6 +145,7 @@ def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
{
"model_ids": PEFT_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"adalora_kwargs": {"init_lora_weights": [False]},
"task_type": "CAUSAL_LM",
},
)
Expand Down Expand Up @@ -172,10 +175,21 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c
"model_ids": PEFT_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"adalora_kwargs": {"init_lora_weights": [False]},
"task_type": "CAUSAL_LM",
},
filter_params_func=skip_non_pt_mqa,
)
)
def test_disable_adapter(self, test_name, model_id, config_cls, config_kwargs):
self._test_disable_adapter(model_id, config_cls, config_kwargs)

def test_generate_adalora_no_dropout(self):
# test for issue #730
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
config_kwargs = {
"target_modules": None,
"task_type": "CAUSAL_LM",
"lora_dropout": 0.0,
}
self._test_generate(model_id, AdaLoraConfig, config_kwargs)
2 changes: 2 additions & 0 deletions tests/test_encoder_decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
{
"model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"adalora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"task_type": "SEQ_2_SEQ_LM",
},
Expand Down Expand Up @@ -159,6 +160,7 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c
{
"model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"adalora_kwargs": {"init_lora_weights": [False]},
"ia3_kwargs": {"init_ia3_weights": [False]},
"task_type": "SEQ_2_SEQ_LM",
},
Expand Down
1 change: 1 addition & 0 deletions tests/test_feature_extraction_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
{
"model_ids": PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST,
"lora_kwargs": {"init_lora_weights": [False]},
"adalora_kwargs": {"init_lora_weights": [False]},
"task_type": "FEATURE_EXTRACTION",
},
)
Expand Down
25 changes: 24 additions & 1 deletion tests/testing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from diffusers import StableDiffusionPipeline

from peft import (
AdaLoraConfig,
IA3Config,
LoraConfig,
PeftModel,
Expand All @@ -45,27 +46,36 @@
PromptTuningConfig,
)
CONFIG_TESTING_KWARGS = (
# IA³
{
"target_modules": None,
"feedforward_modules": None,
},
# LoRA
{
"r": 8,
"lora_alpha": 32,
"target_modules": None,
"lora_dropout": 0.05,
"bias": "none",
},
# prefix tuning
{
"num_virtual_tokens": 10,
},
# prompt encoder
{
"num_virtual_tokens": 10,
"encoder_hidden_size": 32,
},
# prompt tuning
{
"num_virtual_tokens": 10,
},
# AdaLoRA
{
"target_modules": None,
},
)

CLASSES_MAPPING = {
Expand All @@ -74,6 +84,7 @@
"prefix_tuning": (PrefixTuningConfig, CONFIG_TESTING_KWARGS[2]),
"prompt_encoder": (PromptEncoderConfig, CONFIG_TESTING_KWARGS[3]),
"prompt_tuning": (PromptTuningConfig, CONFIG_TESTING_KWARGS[4]),
"adalora": (AdaLoraConfig, CONFIG_TESTING_KWARGS[5]),
}


Expand Down Expand Up @@ -269,6 +280,10 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs):
self.assertFalse(os.path.exists(os.path.join(tmp_dirname, "config.json")))

def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs):
if issubclass(config_cls, AdaLoraConfig):
# AdaLora does not support adding more than 1 adapter
return

model = self.transformers_class.from_pretrained(model_id)
config = config_cls(
base_model_name_or_path=model_id,
Expand Down Expand Up @@ -640,6 +655,10 @@ def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwar
self.assertIsNotNone(param.grad)

def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
if issubclass(config_cls, AdaLoraConfig):
# AdaLora does not support adding more than 1 adapter
return

model = self.transformers_class.from_pretrained(model_id)
config = config_cls(
base_model_name_or_path=model_id,
Expand Down Expand Up @@ -682,7 +701,7 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
model = get_peft_model(model, config)
model = model.to(self.torch_device)

if config.peft_type not in ("LORA"):
if config.peft_type not in ("LORA", "ADALORA"):
with self.assertRaises(AttributeError):
model = model.unload()
else:
Expand All @@ -700,6 +719,10 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
self.assertTrue(torch.allclose(logits_transformers, logits_unload, atol=1e-4, rtol=1e-4))

def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs):
if issubclass(config_cls, AdaLoraConfig):
# AdaLora does not support adding more than 1 adapter
return

adapter_list = ["adapter1", "adapter_2", "adapter_3"]
weight_list = [0.5, 1.5, 1.5]
model = self.transformers_class.from_pretrained(model_id)
Expand Down