[BUG] Issue with using rank_pattern and alpha_pattern together in LoraConfig #2194

Closed
sirluk opened this issue Nov 1, 2024 · 2 comments

Comments

@sirluk
Contributor

sirluk commented Nov 1, 2024

System Info

peft==0.13.2

Who can help?

@BenjaminBossan

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

There is an issue in the _create_and_replace method of LoraModel when rank_pattern and alpha_pattern are both set in LoraConfig.

The issue relates to this line, which merges the keys from rank_pattern and alpha_pattern and then uses the single resulting target_name_key to look up values in both dicts.

  # Regexp matching - Find key which matches current target_name in patterns provided
  pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys()))
  target_name_key = next(filter(lambda key: re.match(rf".*\.{key}$", current_key), pattern_keys), current_key)
  r = lora_config.rank_pattern.get(target_name_key, lora_config.r)
  alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha)

If, for example, rank_pattern is defined with a more general substring (i.e. one matching more layers) than alpha_pattern, the rank_pattern key can be the first match, so the lookup in alpha_pattern misses and falls back to the default lora_alpha instead of the appropriate value. I assume the same happens the other way around as well.
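
To make the failure mode concrete, here is a standalone trace of that lookup using the patterns from the reproduction below (current_key is just an illustrative module path, not taken from PEFT internals):

import re
from itertools import chain

# Patterns as in the reproduction below; current_key is an illustrative module path.
rank_pattern = {"c_attn": 8}
alpha_pattern = {"h.8.attn.c_attn": 2}
current_key = "transformer.h.8.attn.c_attn"

pattern_keys = list(chain(rank_pattern.keys(), alpha_pattern.keys()))
target_name_key = next(filter(lambda key: re.match(rf".*\.{key}$", current_key), pattern_keys), current_key)
print(target_name_key)                     # "c_attn" -- the rank_pattern key matches first
print(alpha_pattern.get(target_name_key))  # None -- the specific alpha_pattern entry is never used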

You can find a minimal example to reproduce the issue here:

from peft import get_peft_model, LoraConfig
from transformers import AutoModelForCausalLM


model_name = "openai-community/gpt2"
target_modules = ["c_attn"]

model = AutoModelForCausalLM.from_pretrained(model_name)
peft_config = LoraConfig(
    r=8,
    lora_alpha=1,
    target_modules=target_modules,
    alpha_pattern={"h.8.attn.c_attn": 2},
    rank_pattern={"c_attn": 8},
)
peft_model = get_peft_model(model, peft_config)
print("scaling (expected alpha/r=0.25):", peft_model.model.transformer.h[8].attn.c_attn.scaling)

model = AutoModelForCausalLM.from_pretrained(model_name)
peft_config = LoraConfig(
    r=8,
    lora_alpha=1,
    target_modules=target_modules,
    alpha_pattern={"h.8.attn.c_attn": 2},
)
peft_model = get_peft_model(model, peft_config)
print("scaling (expected alpha/r=0.25):", peft_model.model.transformer.h[8].attn.c_attn.scaling)

Expected behavior

I would expect both print statements to print the same value, 2/8=0.25. However, we see that when rank_pattern is also defined, layer "h.8.attn.c_attn" is not assigned the correct scaling value (it falls back to the default lora_alpha, giving 1/8=0.125). I have two suggestions to fix this:

  1. Just run next(filter(lambda key: re.match(rf".*\.{key}$", current_key), pattern_keys), current_key) twice, once for the rank_pattern keys and once for the alpha_pattern keys (see the sketch after this list).
  2. Harmonize both dicts in the __post_init__ method of LoraConfig to ensure their keys are defined at the same granularity.
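
For option 1, a minimal sketch of what I have in mind (not the actual implementation; get_pattern_key is a hypothetical helper, the other names are taken from the snippet above):

import re

def get_pattern_key(pattern_keys, key_to_match):
    # Return the pattern key that matches the current module key, or the module key itself.
    return next(filter(lambda key: re.match(rf".*\.{key}$", key_to_match), pattern_keys), key_to_match)

# Resolve the key separately for each dict so that a specific alpha_pattern entry
# is still found when rank_pattern only contains a more general key.
r = lora_config.rank_pattern.get(get_pattern_key(lora_config.rank_pattern.keys(), current_key), lora_config.r)
alpha = lora_config.alpha_pattern.get(get_pattern_key(lora_config.alpha_pattern.keys(), current_key), lora_config.lora_alpha)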
@BenjaminBossan
Member

Thanks a lot for investigating this bug, providing a reproducer, and even 2 suggestions to solve this. Option 1 reads like the simpler solution to me. LMK if you are interested in providing a PR to fix this.

@sirluk
Contributor Author

sirluk commented Nov 1, 2024

Thanks for the feedback! Sounds good, I will create a PR for option 1.
