Fixed LoraConfig alpha modification on add_weighted_adapter #654

Merged 3 commits on Jul 1, 2023
src/peft/tuners/lora.py (7 changes: 4 additions & 3 deletions)
@@ -15,7 +15,7 @@
 import math
 import re
 import warnings
-from dataclasses import asdict, dataclass, field
+from dataclasses import asdict, dataclass, field, replace
 from enum import Enum
 from typing import List, Optional, Tuple, Union
@@ -450,8 +450,9 @@ def merge_and_unload(self):
     def add_weighted_adapter(self, adapters, weights, adapter_name):
         if len({self.peft_config[adapter].r for adapter in adapters}) != 1:
             raise ValueError("All adapters must have the same r value")
-        self.peft_config[adapter_name] = self.peft_config[adapters[0]]
-        self.peft_config[adapter_name].lora_alpha = self.peft_config[adapters[0]].r
+        self.peft_config[adapter_name] = replace(
+            self.peft_config[adapters[0]], lora_alpha=self.peft_config[adapters[0]].r
+        )
         self._find_and_replace(adapter_name)
         mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
         _freeze_adapter(self.model, adapter_name)
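
The root cause is visible in the hunk above: the old assignment did not copy the config, it bound a second dictionary key to the same LoraConfig instance, so the follow-up lora_alpha assignment silently rewrote the source adapter's alpha as well. dataclasses.replace instead builds a fresh instance with the one field overridden. Here is a minimal, self-contained sketch of the difference, with a hypothetical ToyLoraConfig standing in for the real LoraConfig:

from dataclasses import dataclass, replace

@dataclass
class ToyLoraConfig:
    # Simplified stand-in for LoraConfig; the real class has many more fields.
    r: int = 8
    lora_alpha: int = 32

peft_config = {"default": ToyLoraConfig()}

# Before the fix: plain assignment aliases the same instance, so setting
# lora_alpha on the "new" adapter mutates the base adapter's config too.
peft_config["combined"] = peft_config["default"]
peft_config["combined"].lora_alpha = peft_config["default"].r
assert peft_config["default"].lora_alpha == 8  # base config corrupted

# After the fix: replace() returns a new instance with the override applied,
# leaving the base adapter's config untouched.
peft_config["default"] = ToyLoraConfig()  # reset for the demo
peft_config["combined"] = replace(peft_config["default"], lora_alpha=peft_config["default"].r)
assert peft_config["default"].lora_alpha == 32  # unchanged
assert peft_config["combined"].lora_alpha == 8
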
tests/test_stablediffusion.py (69 changes: 52 additions & 17 deletions)
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from dataclasses import asdict, replace
 from unittest import TestCase
 
 import numpy as np
@@ -30,14 +31,14 @@
     {
         "text_encoder": {
             "r": 8,
-            "lora_alpha": 8,
+            "lora_alpha": 32,
             "target_modules": ["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"],
             "lora_dropout": 0.0,
             "bias": "none",
         },
         "unet": {
             "r": 8,
-            "lora_alpha": 8,
+            "lora_alpha": 32,
             "target_modules": ["proj_in", "proj_out", "to_k", "to_q", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"],
             "lora_dropout": 0.0,
             "bias": "none",
@@ -59,21 +60,7 @@ class StableDiffusionModelTester(TestCase, PeftCommonTester):
     """
     transformers_class = StableDiffusionPipeline
 
-    def prepare_inputs_for_testing(self):
-        return {
-            "prompt": "a high quality digital photo of a cute corgi",
-            "num_inference_steps": 20,
-        }
-
-    @parameterized.expand(
-        PeftStableDiffusionTestConfigManager.get_grid_parameters(
-            {
-                "model_ids": PEFT_DIFFUSERS_SD_MODELS_TO_TEST,
-                "lora_kwargs": {"init_lora_weights": [False]},
-            },
-        )
-    )
-    def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
+    def instantiate_sd_peft(self, model_id, config_cls, config_kwargs):
         # Instantiate StableDiffusionPipeline
         model = self.transformers_class.from_pretrained(model_id)
 
@@ -92,6 +79,26 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
         # Move model to device
         model = model.to(self.torch_device)
 
+        return model
+
+    def prepare_inputs_for_testing(self):
+        return {
+            "prompt": "a high quality digital photo of a cute corgi",
+            "num_inference_steps": 20,
+        }
+
+    @parameterized.expand(
+        PeftStableDiffusionTestConfigManager.get_grid_parameters(
+            {
+                "model_ids": PEFT_DIFFUSERS_SD_MODELS_TO_TEST,
+                "lora_kwargs": {"init_lora_weights": [False]},
+            },
+        )
+    )
+    def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
+        # Instantiate model & adapters
+        model = self.instantiate_sd_peft(model_id, config_cls, config_kwargs)
+
         # Generate output for peft modified StableDiffusion
         dummy_input = self.prepare_inputs_for_testing()
         with temp_seed(seed=42):
@@ -107,3 +114,31 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
 
         # Images are in uint8 drange, so use large atol
         self.assertTrue(np.allclose(peft_output, merged_output, atol=1.0))
+
+    @parameterized.expand(
+        PeftStableDiffusionTestConfigManager.get_grid_parameters(
+            {
+                "model_ids": PEFT_DIFFUSERS_SD_MODELS_TO_TEST,
+                "lora_kwargs": {"init_lora_weights": [False]},
+            },
+        )
+    )
+    def test_add_weighted_adapter_base_unchanged(self, test_name, model_id, config_cls, config_kwargs):
+        # Instantiate model & adapters
+        model = self.instantiate_sd_peft(model_id, config_cls, config_kwargs)
+
+        # Snapshot the current adapter config for each sub-model
+        text_encoder_adapter_name = next(iter(model.text_encoder.peft_config.keys()))
+        unet_adapter_name = next(iter(model.unet.peft_config.keys()))
+        text_encoder_adapter_config = replace(model.text_encoder.peft_config[text_encoder_adapter_name])
+        unet_adapter_config = replace(model.unet.peft_config[unet_adapter_name])
+
+        # Create weighted adapters
+        model.text_encoder.add_weighted_adapter([text_encoder_adapter_name], [0.5], "weighted_adapter_test")
+        model.unet.add_weighted_adapter([unet_adapter_name], [0.5], "weighted_adapter_test")
+
+        # Assert that the base adapters' configs did not change
+        self.assertTrue(
+            asdict(text_encoder_adapter_config) == asdict(model.text_encoder.peft_config[text_encoder_adapter_name])
+        )
+        self.assertTrue(asdict(unet_adapter_config) == asdict(model.unet.peft_config[unet_adapter_name]))
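
A side note on the snapshot idiom in this test: replace(cfg) with no field overrides produces a field-for-field (shallow) copy of a dataclass, and asdict converts it recursively to a plain dict, so the final assertions compare field values rather than object identity. A small sketch with a toy dataclass, not the real LoraConfig:

from dataclasses import asdict, dataclass, field, replace

@dataclass
class ToyConfig:
    r: int = 8
    target_modules: list = field(default_factory=lambda: ["q_proj", "v_proj"])

cfg = ToyConfig()
snapshot = replace(cfg)                 # no overrides -> shallow copy
assert snapshot is not cfg              # a distinct object...
assert asdict(snapshot) == asdict(cfg)  # ...with equal field values

cfg.r = 16                              # a later in-place mutation
assert asdict(snapshot) != asdict(cfg)  # the snapshot exposes the change
# Caveat: the copy is shallow, so mutating cfg.target_modules in place
# would leak into the snapshot as well.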