LoraConfig conflict when using layers_to_transform in LlamaModel #2155

Open
2 of 4 tasks
Evan02580 opened this issue Oct 17, 2024 · 4 comments

Comments

@Evan02580

System Info

peft: 0.13.2
transformers: 4.43.1

Who can help?

@BenjaminBossan @sayakpaul

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

I tried to use LoraConfig to apply LoRA only to the first and last layers, like this:

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    layers_to_transform=[0, 31],
    lora_dropout=0,
    bias="none",
)
model = LlamaModel.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16)
llama_model = get_peft_model(model, lora_config)

This raised the following error:

*** ValueError: Target modules ['q_proj', 'k_proj', 'v_proj', 'o_proj'] not found in the base model. Please check the target modules and try again.

A similar thing happens if I use layers_pattern instead of target_modules (though this may be my misunderstanding of layers_pattern):

lora_config = LoraConfig(
    ...
    layers_to_transform=1,
    layers_pattern=["q_proj", "k_proj", "v_proj", "o_proj"],
    ...
)
get_peft_model(model, lora_config)
*** ValueError: Target modules {'v_proj', 'q_proj'} not found in the base model. Please check the target modules and try again.

But this time the problem is probably caused by the default value of target_modules.

However, when I use model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16, trust_remote_code=True) instead, it works.
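
For reference, here is a minimal sketch of the variant that works for me (same config as above, only the model class and trust_remote_code change):

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    layers_to_transform=[0, 31],
    lora_dropout=0,
    bias="none",
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16, trust_remote_code=True
)
llama_model = get_peft_model(model, lora_config)  # no ValueError here
llama_model.print_trainable_parameters()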

Expected behavior

I'm not sure whether this is a problem with LlamaModel. I am also confused about the use of layers_pattern, since the LoRA docs mention:

  • layers_to_transform: List of layers to be transformed by LoRA. If not specified, all layers in target_modules are transformed.
  • layers_pattern: Pattern to match layer names in target_modules, if layers_to_transform is specified. By default PeftModel will look at common layer pattern (layers, h, blocks, etc.), use it for exotic and custom models.

It should work together with layers_to_transform; however, I didn't find a suitable way to use it. Maybe some examples could be added to class LoraConfig(PeftConfig)?

@Evan02580 Evan02580 changed the title LoraConfig conflict when using layers_to_transform LoraConfig conflict when using layers_to_transform in LlamaModel Oct 17, 2024
@BenjaminBossan
Member

Thanks for reporting the issue. Indeed, the usage of layers_to_transform and layers_pattern is a bit confusing and the error message is not helpful.

The idea here is that if we have an nn.ModuleList, in this case with 32 layers, layers_pattern should designate this nn.ModuleList: layers_pattern="layers". Therefore, this works for me:

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj"],
    layers_to_transform=[0, 31],
    layers_pattern="layers",
    lora_dropout=0,
    bias="none",
)

However, as you noted, using LlamaModel directly does not work. This is a result of how we specify a regex and I think we can amend it to work with LlamaModel too. So for now, please use AutoModelForCausalLM with the LoraConfig I showed and you should be good.
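
To see why this happens, it can help to print the module names. This is just an illustrative sketch (the exact names depend on the transformers version), but it shows the difference: with AutoModelForCausalLM the attention projections sit under a "model." prefix (e.g. model.layers.0.self_attn.q_proj), whereas with LlamaModel there is no prefix (layers.0.self_attn.q_proj), which is what trips up the current regex.

for name, _ in model.named_modules():
    if name.endswith("q_proj"):
        print(name)  # "model.layers.0.self_attn.q_proj" vs. "layers.0.self_attn.q_proj"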

The TODOs from this issue are:

  1. Improve the documentation of these arguments to clarify what users need to pass.
  2. Amend the regex to make the prefix before the layers_pattern optional.
  3. Adjust the error message for the case that users pass layers_to_transform and layers_pattern (right now, the error message assumes that users only pass target_modules).

For point 3, would you be interested in tackling this @JINO-ROHIT since you refactored that part in #2102?

@JINO-ROHIT
Contributor

@BenjaminBossan yeah, I'll be happy to work on this

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Oct 17, 2024
Addresses part of huggingface#2155.

Also fix type annotations where appropriate.
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this issue Oct 17, 2024
Addresses part of huggingface#2155.

Description

So far, the layers_pattern argument would only work if there was a
prefix to the pattern. As an example, if the module name is:

decoder.layer.0.attn.to_q

and we pass layers_pattern="layer", this would match. However, if the
module name was:

layer.0.attn.to_q

it would not work.

Usually, when we create a model with AutoModelForFoo.from_pretrained,
the "layer" part would never be first. However, if we load a model
directly, e.g. through LlamaModel.from_pretrained, there is actually no
prefix. As a consequence, we get no match there.

With this PR, the prefix is made optional, so that the second pattern
also matches.
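
As a rough illustration of the idea (a sketch with a hypothetical pattern, not the exact regex used in peft): making the leading prefix optional means both module names above yield a layer index.

import re

# "(?:.*\.)?" makes the prefix before the layers_pattern ("layer") optional.
pattern = re.compile(r"(?:.*\.)?layer\.(\d+)\.")

for key in ["decoder.layer.0.attn.to_q", "layer.0.attn.to_q"]:
    match = pattern.match(key)
    print(key, "->", match.group(1) if match else "no match")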

Status

I'm not sure yet if this should be merged, as it is technically
backwards incompatible. Users can still target the desired modules by
carefully crafting a regex for target_modules so that it only matches
the desired layer indices. However, this is tedious and layers_pattern
was introduced to avoid having to do this.
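
For anyone needing the workaround today, here is a sketch of such a target_modules regex, assuming the module names from this issue (layers.<idx>.self_attn.<proj> for LlamaModel, with a model. prefix when loading via AutoModelForCausalLM). When target_modules is a string, it is treated as a regex over the full module name:

from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    # Only match the q/k/v/o projections in layers 0 and 31, with or without a prefix.
    target_modules=r".*layers\.(0|31)\.self_attn\.(q_proj|k_proj|v_proj|o_proj)",
    lora_dropout=0,
    bias="none",
)
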
@BenjaminBossan
Member

@Evan02580 I created a PR to improve the docs in #2157 and another PR to adapt the regex in #2158. For the latter, I'm unsure if we should proceed though, as technically this is a backwards-incompatible change.

BenjaminBossan added a commit that referenced this issue Oct 18, 2024
Addresses part of #2155.

Also fix type annotations where appropriate.
sirluk pushed a commit to sirluk/peft that referenced this issue Oct 19, 2024
Addresses part of huggingface#2155.

Also fix type annotations where appropriate.
yaswanth19 pushed a commit to yaswanth19/peft that referenced this issue Oct 20, 2024
Addresses part of huggingface#2155.

Also fix type annotations where appropriate.

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
