Merge lora module to 8bit model #875
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Thanks a lot for keeping up with your good work and delivering the 8bit merging feature.
As you can see, there is a merge conflict now after we moved the tuners to sub-packages in #807. Don't worry, it is easy to fix:
Your change that starts with elif is_bnb_available() and ... should be moved to https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/model.py.
Your change that starts with def merge(self): should be moved to https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/bnb.py.
The test can stay the same.
On top of that, after merging #851, we found a small bug in your previous PR. The explanation is contained here. Please take a look so that we can ensure that the same thing does not happen with the 8bit layer. As you can see, we also worked on the test to make it a little more precise by checking probabilities and not tokens. I think this should work for this PR too.
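For illustration, a probability-based check might look roughly like this (a minimal sketch; the helper name, the model handles, and the tolerance are placeholders, not the actual test code):

```python
import torch

def assert_probas_close(model_before, model_after, inputs, atol=1e-2):
    """Compare full probability distributions rather than argmax tokens:
    the predicted tokens can still agree even when the distributions have
    drifted noticeably after merging."""
    with torch.no_grad():
        probs_before = torch.softmax(model_before(**inputs).logits, dim=-1)
        probs_after = torch.softmax(model_after(**inputs).logits, dim=-1)
    assert torch.allclose(probs_before, probs_after, atol=atol)
```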
Very interesting approach!
Thanks for your help. I have fixed the 8bit linear forward to support merge and disable adapter, but I have a question about the 4bit linear forward, see this. I wonder why the result needs to add …
Thanks for reporting, this is indeed a bug introduced -- by me :-/ -- recently. #878 should provide a fix.
I did some scatter plots again (this time probas) and compared the newly added 8bit merge with the previously added 4bit merge. Now, the outputs are much closer for 8bit than for 4bit, contrary to what we saw earlier and in accordance with expectations. So for this example, the 8bit merge looks really good. It looks so good, in fact, that the tests pass locally for me even with …
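Such a scatter plot can be produced along these lines (purely illustrative; probs_before and probs_after are assumed to be probability tensors collected as in the helper sketched earlier):

```python
import matplotlib.pyplot as plt

def scatter_probas(probs_before, probs_after):
    """Plot merged vs. unmerged probabilities; points hugging the diagonal
    mean the merged model closely reproduces the unmerged one."""
    x = probs_before.flatten().float().cpu().numpy()
    y = probs_after.flatten().float().cpu().numpy()
    plt.scatter(x, y, s=1, alpha=0.3)
    plt.plot([0, 1], [0, 1], linewidth=1)  # identity line for reference
    plt.xlabel("probabilities without merge")
    plt.ylabel("probabilities with merge")
    plt.show()
```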
Thanks for your work. Do @younesbelkada and @pacman100 have any comments? I would like to hear your opinions.
Hi @HuggingFaceDocBuilderDev @pacman100 @younesbelkada. I hope I can get your opinion about this PR. If nothing needs to be changed, could we merge this PR? Thanks!
@jiqing-feng Sorry for the delay. There is some further feedback we're waiting for regarding your PR; this should hopefully arrive by the end of this week or the start of next week.
@jiqing-feng Sorry for the delay. We think the changes are good and can be merged. Could you please fix the merge conflict? Thanks a lot for your patience.
Done.
Thanks a lot for updating the PR. I gave this a final review and found two small issues (sorry for not noticing them earlier). Could you please take a look?
src/peft/tuners/lora/bnb.py (Outdated)

if self.state.SCB is None:
    self.state.SCB = self.weight.SCB
# Dequantize the result of identify matrix and int8 weight because bitsandbytes only have this method.
Suggested change:
- # Dequantize the result of identify matrix and int8 weight because bitsandbytes only have this method.
+ # Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8
+ # dequantization directly
self.state.reset_grads()
self.merged = True

def unmerge(self):
Could you please extend the test, or add a separate one, to also test unmerging?
I think test_8bit_merge_and_disable_lora in test_common_gpu.py is for testing unmerge, since model.disable_adapter() will call unmerge() in the forward. Do I misunderstand it?
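For reference, the pattern under discussion looks roughly like this (a sketch with placeholder names for the model, inputs, and reference logits, not the actual test):

```python
import torch

def check_disable_adapter(peft_model, inputs, base_logits, atol=1e-2):
    """disable_adapter() is a context manager; inside it the merged LoRA weights
    are unmerged / bypassed, so the output should match the base model again."""
    with torch.no_grad(), peft_model.disable_adapter():
        disabled_logits = peft_model(**inputs).logits
    assert torch.allclose(disabled_logits, base_logits, atol=atol)
```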
Yes, you are right, sorry I missed that.
np. I have fixed the annotation and removed "default" in merge_and_unload() in the tests. Would you please help me review these? Thx!
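For context, the call in question then simply becomes something like this (a usage sketch with a placeholder model name, not the exact test code):

```python
# merge the LoRA weights into the 8-bit base model and drop the PEFT wrappers;
# per the comment above, the explicit "default" argument was removed from this call
merged_model = peft_model.merge_and_unload()
```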
Thank you @jiqing-feng for all the work on this, LGTM! 🚀
Fantastic addition, big thanks.
Hi @younesbelkada @pacman100 @BenjaminBossan @TimDettmers.
Related to #851. I found a way to merge the LoRA module into the 8bit model.
bitsandbytes can only dequantize the result of an int8 matmul. Therefore, I multiply the 8-bit weight by an identity matrix, so the result is equal to the original weight after dequantization.
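To make the idea concrete, here is a minimal sketch (this is not the PR's actual code or the attached test script; the helper names are made up, and the LoRA delta is assumed to be supplied by the caller as lora_B @ lora_A * scaling):

```python
import torch
import bitsandbytes as bnb

def dequantize_8bit_weight(layer: bnb.nn.Linear8bitLt) -> torch.Tensor:
    """Recover an fp16 copy of the int8 weight by multiplying it with an identity
    matrix: feeding the identity through the 8-bit linear computes I @ W^T (+ bias),
    and bitsandbytes dequantizes the matmul result internally, so stripping the
    bias and transposing gives back (approximately) the original fp16 weight."""
    in_features = layer.weight.shape[1]
    eye = torch.eye(in_features, dtype=torch.float16, device=layer.weight.device)
    with torch.no_grad():
        out = layer(eye)  # I @ W^T (+ bias); needs a CUDA device for the int8 matmul
        if layer.bias is not None:
            out = out - layer.bias
    return out.t()

def merge_lora_into_8bit(layer: bnb.nn.Linear8bitLt, lora_delta: torch.Tensor) -> None:
    """Add the LoRA delta in fp16, then re-quantize the merged weight to int8."""
    device = layer.weight.device
    w = dequantize_8bit_weight(layer) + lora_delta.to(torch.float16).to(device)
    layer.weight = bnb.nn.Int8Params(
        w.to("cpu"), requires_grad=False, has_fp16_weights=False
    ).to(device)
    # a real implementation also needs to reset the layer's quantization state
    # (SCB etc.), as the diff above does with self.state.reset_grads()
```

The actual change works directly on the bitsandbytes quantization state inside the LoRA layer (see the SCB and reset_grads() lines quoted in the review above); the sketch only illustrates the dequantize, add, and re-quantize idea.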
The test script is also attached.
The result should be:
We can see that the difference between the two logits is really small.
Would you please help me review it? Thx!