
Fix: Multiple adapters with bnb layers #1243

Conversation

BenjaminBossan (Member)

Resolves #1239

Description

Fixes a bug that led to an error when loading multiple adapters into a
PEFT model that uses bnb (bitsandbytes) layers.
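
For illustration, a minimal sketch of the scenario this fixes, assuming a bitsandbytes-quantized base model and two locally saved LoRA adapters (the model id, paths, and adapter names are placeholders):

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    from peft import PeftModel

    # Placeholder model id and adapter paths, for illustration only.
    base = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-350m",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",
    )

    # Loading the first adapter works; before this fix, loading a second
    # adapter into the same model raised an error on the bnb layers.
    model = PeftModel.from_pretrained(base, "/path/to/adapter-a", adapter_name="adapter-a")
    model.load_adapter("/path/to/adapter-b", adapter_name="adapter-b")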

Notes

I tested this locally on GPU and tests pass for me.

Implementation

While working on this, I thought we could make a simplification to our
layer code. Right now, we have code like this inside of
_create_and_replace:

    if isinstance(target, Conv2d):
        target.update_layer_conv2d(...)
    elif isinstance(target, Embedding):
        target.update_layer_embedding(...)
    elif ...

If we move the update_layer* methods from LoraLayer to the subclasses,
i.e. Linear, Embedding, etc., and give them all the same name (i.e.
update_layer_embedding => update_layer), we could simplify the code
above to:

    if isinstance(target, LoraLayer):
        target.update_layer(...)

Additionally, it makes more sense to me that the code for
update_layer_embedding lives inside of Embedding, not the parent class.
If you agree, I could work on that in a separate PR.
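
For illustration, a rough sketch of what that refactoring could look like (method bodies elided; parameter names are indicative, not exact signatures):

    class LoraLayer:
        # Shared state and bookkeeping (r, lora_alpha, scaling, ...) stay here.
        ...

    class Linear(LoraLayer):
        def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
            ...  # what update_layer currently does for Linear

    class Embedding(LoraLayer):
        def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
            ...  # what update_layer_embedding currently does

    # _create_and_replace would then only need:
    if isinstance(target, LoraLayer):
        target.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)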

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@younesbelkada (Contributor) left a comment

Thank you !

@chris111110

While I tried to use this solution to load multiple LoRAs into a PEFT model, an error appeared:

    245 parent, target, target_name = _get_submodules(model, key)
    247 optional_kwargs = {
    248     "loaded_in_8bit": getattr(model, "is_loaded_in_8bit", False),
    249     "loaded_in_4bit": getattr(model, "is_loaded_in_4bit", False),
    250     "current_key": key,
    ...
    1672         buffers[name] = value
    1673     else:
    -> 1674     super().__setattr__(name, value)

    AttributeError: can't set attribute 'weight'

It seems to work if I modify the code in _create_new_module like this:

    elif AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear):
        new_module = QuantLinear(target, adapter_name, **kwargs)
        # xlc
        if target == None:
            target.weight = target_base_layer.qweight
        else:
            target = target_base_layer
            target.weight = target_base_layer.qweight

@BenjaminBossan (Member, Author)

@ChrisXULC Thanks for reporting. Do you have some code to reproduce the error?

@pacman100 (Contributor) left a comment

Thank you @BenjaminBossan for fixing the issue when using multiple adapters with bnb layers, nice tests! LGTM! 🚀

@chris111110

@ChrisXULC Thanks for reporting. Do you have some code to reproduce the error?

@BenjaminBossan It can be reproduced by trying to load multiple LoRAs into the Qwen-7B Int4 model. Before loading the LoRAs, I modified anaconda3/envs/lib/python3.10/site-packages/peft/tuners/lora/model.py as you did.

    model = AutoModelForCausalLM.from_pretrained(
        "/home/models/qwen/Qwen-7B-Chat-Int4",
        device_map="auto",
        trust_remote_code=True,
        # use_cache_quantization=True,
        # use_cache_kernel=True,
        # use_flash_attn=False
    ).eval()

    model = PeftModel.from_pretrained(
        model,
        '/home//models/quantized_model_lora',
        adapter_name="23",
        device_map="cuda",
        load_in_4bit=True,
    )
    model.load_adapter('/home/lora_all/quantized_model_vul', adapter_name='22')

@BenjaminBossan (Member, Author)

BenjaminBossan commented Dec 12, 2023

@ChrisXULC Thanks a lot. The issue should be fixed with the latest commit. I omitted the if target == None: guard from your suggestion; I don't think that case can happen, but LMK if I'm wrong.

(Note: tests are failing because of #1252)

@chris111110

@BenjaminBossan I have changed it to target.qweight = target_base_layer.qweight; I think it should be OK.
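
For reference, a rough sketch of how that branch of _create_new_module would then look, based on the snippets quoted above (not necessarily the exact merged code):

    elif AutoGPTQQuantLinear is not None and isinstance(target_base_layer, AutoGPTQQuantLinear):
        new_module = QuantLinear(target, adapter_name, **kwargs)
        # Point the wrapper at the quantized weight instead of assigning
        # .weight, which is what triggered the AttributeError above.
        target.qweight = target_base_layer.qweight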

@BenjaminBossan (Member, Author)

@ChrisXULC Thanks for checking. If you also have time to check PEFT from this branch, that would be fantastic.

@chris111110

@BenjaminBossan cool!

@BenjaminBossan (Member, Author)

Tests are green now after merging #1252; I'll merge this one too.

BenjaminBossan merged commit 971dd6e into huggingface:main on Dec 12, 2023
BenjaminBossan deleted the bugfix-multiple-adapters-with-bnb branch on December 12, 2023 at 14:34
BenjaminBossan mentioned this pull request on Dec 12, 2023