Merge LoRA Adapter with int8 base model. #638

Closed
jenkspt opened this issue Jun 26, 2023 · 15 comments · Fixed by #851

Comments

@jenkspt

jenkspt commented Jun 26, 2023

Feature request

Support merging LoRA adapters into the base model when the base model is loaded in int8.

Motivation

  • This is helpful when the goal is to merge adapter weights for faster 8-bit inference.
  • This is helpful in low-memory environments, where it may not be possible to load the model in half precision before merging.

Your contribution

Happy to create this PR. Any insights to avoid pitfalls are welcome.

@jenkspt
Author

jenkspt commented Jun 26, 2023

Relevant code: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora.py#L417-L418

@pacman100
Contributor

Hello @jenkspt, the bottleneck here is that the LoRA weights are in float32, while quantized models store their weights in int8/NF4. Supporting this would take considerable effort when a simple workaround already exists. Below is the workaround:

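import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM  # or the AutoModelForXXXXX class for your task

base_model_id = "..."  # placeholder: base model name or path
peft_model_id = "..."  # placeholder: LoRA adapter name or path

# Load the base model in half precision (not quantized) and attach the adapter
model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16)
model = PeftModel.from_pretrained(model, peft_model_id)

# Fold the LoRA weights into the base weights and remove the PEFT wrappers
model = model.merge_and_unload()
model.save_pretrained("merged_model")

# Reload the merged checkpoint in 8-bit
model = AutoModelForCausalLM.from_pretrained("merged_model", load_in_8bit=True)
# do inference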

cc @younesbelkada for adding more context

@younesbelkada
Contributor

Hi @jenkspt
Yes, I second what @pacman100 said: you need to first load the standalone fp16/bf16 model, merge the adapter into it, and then load the result back in 8-bit/4-bit.

@jenkspt
Author

jenkspt commented Jun 27, 2023

@pacman100, can you explain why this is difficult to implement?

With falcon-40b and LoRA adapters, the workaround takes an unreasonable amount of time on the CPU and doesn't work with limited GPU memory (i.e., only enough to fit the 8-bit model).

@younesbelkada
Contributor

@jenkspt
Technically it is doable: one needs to dequantize the int8 weights to fp16 on the fly for each 8-bit bnb layer. However, this introduces numerical differences, so the results may ultimately not match those obtained with the original fp16/bf16 weights; hence the recommendation.
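To illustrate why the int8 round trip is lossy, here is a minimal plain-Python sketch of row-wise absmax int8 quantization (one scale per output row, the same scheme that the CB/SCB dequantization formula later in this thread inverts). The helper names are hypothetical, not bitsandbytes or PEFT APIs, and the sketch ignores the outlier handling that LLM.int8() performs.

```python
def quantize_rowwise(weight):
    """Absmax int8 quantization with one scale per row.

    Returns (codes, scales): codes are ints in [-127, 127], scales are the
    per-row absolute maxima (hypothetical helper, for illustration only).
    """
    codes, scales = [], []
    for row in weight:
        absmax = max(abs(x) for x in row) or 1.0  # avoid division by zero on all-zero rows
        scales.append(absmax)
        codes.append([round(x * 127 / absmax) for x in row])
    return codes, scales


def dequantize_rowwise(codes, scales):
    """Approximate inverse of quantize_rowwise: w ~= code * scale / 127."""
    return [[c * s / 127 for c in row] for row, s in zip(codes, scales)]


weight = [[0.12, -0.53, 0.031], [1.7, 0.004, -0.9]]
codes, scales = quantize_rowwise(weight)
restored = dequantize_rowwise(codes, scales)

# Largest absolute difference between the original and round-tripped weights
max_err = max(
    abs(w - r) for w_row, r_row in zip(weight, restored) for w, r in zip(w_row, r_row)
)
print(codes)
print(f"max round-trip error: {max_err:.5f}")
```

Each element can be off by up to half a quantization step, i.e. absmax / 254 for its row, so small weights in a row with a large absmax lose the most relative precision.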

@jenkspt
Author

jenkspt commented Jun 28, 2023

Thanks for the info @younesbelkada. Since I've been fine-tuning with int8 + LoRA, there shouldn't be any numerical differences. Can you point me to any functions or resources for dequantization? Happy to create a PR for this.

@younesbelkada
Contributor

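Sure @jenkspt, you can check out this specific section of the bnb integration blogpost: https://huggingface.co/blog/hf-bitsandbytes-integration#usage
TL;DR: you need to loop over all Linear8bitLt layers and apply the following operation to each layer's weight:

(layer.weight.CB * layer.weight.SCB.unsqueeze(1)) / 127

Here CB holds the int8 weights (shape out_features × in_features) and SCB the per-row absmax scales (shape out_features), so SCB is unsqueezed to scale each row of CB by its own absmax.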

It would be great if you could open a PR for that! Let us know if you need any help.
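Following the recipe above, merging an adapter into an 8-bit layer boils down to: dequantize each row, add the scaled LoRA update, and re-quantize. The plain-Python sketch below shows only that arithmetic, assuming the standard LoRA update W' = W + scaling * (B @ A) (with scaling = lora_alpha / r) and row-wise absmax scales like SCB. All function names are hypothetical; this is not PEFT's implementation and it ignores LLM.int8() outlier handling. Note that the final re-quantization step is itself lossy, which is the numerical-difference caveat raised earlier in the thread.

```python
def matmul(a, b):
    """Plain-Python matrix product of a (m x k) and b (k x n)."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def merge_lora_into_int8(codes, scales, lora_a, lora_b, scaling):
    """Merge a LoRA update into a row-wise int8 weight (hypothetical helper).

    codes:  int8 codes, shape (out, in)   (plays the role of CB)
    scales: per-row absmax, length out    (plays the role of SCB)
    lora_a: shape (r, in); lora_b: shape (out, r)
    Returns new (codes, scales) for W' = dequant(W) + scaling * (B @ A).
    """
    delta = matmul(lora_b, lora_a)  # LoRA update, shape (out, in)
    # Dequantize each row with its own scale, then add the scaled update
    merged = [
        [c * s / 127 + scaling * d for c, d in zip(code_row, delta_row)]
        for code_row, s, delta_row in zip(codes, scales, delta)
    ]
    # Re-quantize each row with its new absmax scale
    new_codes, new_scales = [], []
    for row in merged:
        absmax = max(abs(x) for x in row) or 1.0  # guard against all-zero rows
        new_scales.append(absmax)
        new_codes.append([round(x * 127 / absmax) for x in row])
    return new_codes, new_scales


codes = [[127, -64], [10, 127]]
scales = [0.5, 2.0]
lora_a = [[1.0, 0.0]]    # r = 1, in = 2
lora_b = [[0.1], [0.0]]  # out = 2, r = 1
new_codes, new_scales = merge_lora_into_int8(codes, scales, lora_a, lora_b, scaling=2.0)
print(new_codes, new_scales)
```

Only the first row receives a non-zero update here, so only its scale and codes change; the second row round-trips unchanged.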

@jtrechot

@jenkspt, if there is a PR for this issue, could you link it here as well?

Thank you! :D

@jenkspt
Author

jenkspt commented Jul 10, 2023

@jtrechot haven't gotten to it yet.

@github-actions

github-actions bot commented Aug 4, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@jenkspt
Author

jenkspt commented Aug 9, 2023

I think this is still a desired feature -- I simply don't have the time to implement it right now.

@younesbelkada
Contributor

With #851 you can now call merge_and_unload on 4-bit models. I'm keeping the issue open for 8-bit models; please refer to #851 (comment).

@BenjaminBossan
Member

#875 might provide the solution for 8bit.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@younesbelkada
Contributor

It is now supported! Please install PEFT from source. Closing the issue.

Labels
None yet
Projects
None yet

5 participants