
Merge lora module to 8bit model #875

Merged
merged 11 commits on Sep 7, 2023

Conversation

@jiqing-feng (Contributor) commented Aug 29, 2023

Hi @younesbelkada @pacman100 @BenjaminBossan @TimDettmers .

Related to #851. I found a way to merge the 8bit model.

bitsandbytes can only dequantize the result of an int8 matmul. Therefore, I multiply the 8-bit weight by an identity matrix, so the dequantized result is equal to the original weight.
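To illustrate the idea, here is a hedged sketch (not the code in this PR; the helper name dequantize_8bit_weight is made up, and it assumes a bnb.nn.Linear8bitLt layer that has already been moved to GPU):

import torch
import bitsandbytes as bnb

def dequantize_8bit_weight(layer: bnb.nn.Linear8bitLt) -> torch.Tensor:
    # Make sure the quantization statistics are available on the matmul state.
    if layer.state.SCB is None:
        layer.state.SCB = layer.weight.SCB
    # identity @ W^T is computed as an int8 matmul and dequantized by bitsandbytes;
    # transposing the result gives back a floating-point copy of W.
    identity = torch.eye(layer.in_features, dtype=torch.float16, device=layer.weight.device)
    dequantized_t = bnb.matmul(identity, layer.weight, state=layer.state)
    return dequantized_t.t()

The merge can then add the LoRA delta (lora_B @ lora_A * scaling) to this dequantized copy and re-quantize the result back into an 8-bit weight.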

The test script is also attached:

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

torch.manual_seed(1024)
model_origin = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    load_in_8bit=True,
)
model = prepare_model_for_kbit_training(model)
random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
print("original model outputs")
print(model_origin(random_input.clone().to(model_origin.device)).logits)

config = LoraConfig(
    r=8,
    init_lora_weights=False,
    target_modules=["k_proj", "v_proj", "q_proj", "out_proj", "fc1", "fc2"],
)
model = get_peft_model(model, config)
with torch.inference_mode():
    out_before_merge = model(random_input)
    print("out before merge")
    print(out_before_merge.logits)

model = model.merge_and_unload("default")
with torch.inference_mode():
    out_after_merge = model(random_input)
    print("out after merge")
    print(out_after_merge.logits)

The result should be:

original model outputs
tensor([[[-3.9464, -3.9443,  3.2428,  ..., -3.9583, -3.9531, -4.0695],
         [ 2.1592,  2.1657,  3.4937,  ...,  2.1609,  2.2064,  1.8032],
         [ 1.7407,  1.7366,  3.3675,  ...,  1.7371,  1.7456,  1.5533],
         [ 1.7655,  1.7583,  3.1950,  ...,  1.7668,  1.7512,  1.5889],
         [ 1.8861,  1.8778,  3.0723,  ...,  1.8935,  1.8632,  1.7100],
         [ 2.0078,  1.9973,  2.9639,  ...,  2.0184,  1.9726,  1.8311]]],
       grad_fn=<UnsafeViewBackward0>)

out before merge
tensor([[[-4.8259, -4.8183,  0.9898,  ..., -4.9119, -4.9100, -4.8111],
         [-3.5888, -3.5779,  0.9716,  ..., -3.7284, -3.6713, -3.7122],
         [-3.6172, -3.6078,  0.8687,  ..., -3.7570, -3.6960, -3.7456],
         [-3.5690, -3.5609,  0.7098,  ..., -3.7097, -3.6471, -3.6961],
         [-3.6113, -3.6041,  0.6201,  ..., -3.7497, -3.6936, -3.7288],
         [-3.7012, -3.6957,  0.5803,  ..., -3.8397, -3.7881, -3.8144]]],
       device='cuda:0')

out after merge
tensor([[[-4.7348, -4.7283,  1.1155,  ..., -4.8248, -4.8206, -4.7260],
         [-3.5906, -3.5802,  0.9311,  ..., -3.7292, -3.6703, -3.7129],
         [-3.6750, -3.6668,  0.8512,  ..., -3.8123, -3.7480, -3.7943],
         [-3.6069, -3.5996,  0.7566,  ..., -3.7426, -3.6845, -3.7268],
         [-3.6024, -3.5957,  0.7001,  ..., -3.7365, -3.6801, -3.7149],
         [-3.7372, -3.7322,  0.5553,  ..., -3.8741, -3.8224, -3.8415]]],
       device='cuda:0')

We can see that the difference between the logits before and after merging is very small.
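As a quick sanity check, the gap can also be quantified directly; this small sketch reuses the out_before_merge and out_after_merge variables from the script above:

max_abs_diff = (out_before_merge.logits - out_after_merge.logits).abs().max().item()
print(f"max abs difference between logits before/after merge: {max_abs_diff:.4f}")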

Would you please help me review it? Thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@BenjaminBossan (Member) left a comment

Thanks a lot for keeping up with your good work and delivering the 8bit merging feature.

As you can see, there is a merge conflict now after we moved the tuners to sub-packages in #807. Don't worry, it is easy to fix:

Your change that starts with elif is_bnb_available() and ... should be moved to https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/model.py.

Your change that starts with def merge(self): should be moved to https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/bnb.py.

The test can stay the same.

On top of that, after merging #851, we found a small bug in your previous PR. The explanation is contained here. Please take a look so that we can ensure that the same thing does not happen with the 8bit layer. As you can see, we also worked on the test to make it a little more precise by checking probabilities and not tokens. I think this should work for this PR too.
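For illustration, here is a minimal sketch of what comparing probabilities instead of tokens could look like; the helper name and tolerances are assumptions, not the actual peft test code:

import torch

def assert_probas_close(logits_a, logits_b, atol=1e-3, rtol=0.1):
    # Compare softmaxed probabilities rather than argmax tokens, so small numerical
    # drift introduced by merging is still caught by the tolerances.
    probas_a = torch.softmax(logits_a.float(), dim=-1)
    probas_b = torch.softmax(logits_b.float(), dim=-1)
    assert torch.allclose(probas_a, probas_b, atol=atol, rtol=rtol)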

@pacman100 (Contributor) left a comment

Very interesting approach!

@jiqing-feng (Contributor, Author)


Thanks for your help.

I have fixed the 8bit linear forward to support merge and disable_adapter, but I have a question about the 4bit linear forward (see this): why does the result need to add lora_B(lora_A(dropout(x))) * scaling twice?
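For context, here is a standalone sketch of the LoRA forward path being discussed; the names are illustrative and this is not the actual peft 4bit layer code:

def lora_forward(x, base_layer, lora_A, lora_B, dropout, scaling):
    result = base_layer(x)                        # frozen quantized matmul
    delta = lora_B(lora_A(dropout(x))) * scaling  # LoRA adapter contribution
    # Adding `delta` a second time here would double-count the adapter output,
    # which is the behavior being questioned above.
    return result + delta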

@BenjaminBossan (Member)

Why does the result need to add lora_B(lora_A(dropout(x))) * scaling twice?

Thanks for reporting, this is indeed a bug introduced -- by me :-/ -- recently. #878 should provide a fix.

@BenjaminBossan (Member)

I did some scatter plots again (this time of probabilities) and compared the newly added 8bit merge with the previously added 4bit merge:

[Two screenshots from 2023-08-29: scatter plots comparing probabilities for the 8bit merge and the 4bit merge]

Now, the outputs are much closer for 8bit than for 4bit, contrary to what we saw earlier and in accordance with expectations. So for this example, the 8bit merge looks really good. It looks so good, in fact, that the tests locally pass for me even with atol=0.001 and rtol=0.3.

@jiqing-feng (Contributor, Author)


Thanks for your work. Do @younesbelkada and @pacman100 have any comments? I would like to hear your opinions.

@jiqing-feng (Contributor, Author)

Hi @HuggingFaceDocBuilderDev @pacman100 @younesbelkada .

I hope I can get your opinions on this PR. If nothing needs to be changed, could we merge it?

Thanks!

@BenjaminBossan (Member)

@jiqing-feng Sorry for the delay. There is some further feedback we're waiting for regarding your PR; it should hopefully arrive by the end of this week or the start of next week.

@BenjaminBossan (Member)

@jiqing-feng Sorry for the delay. We think the changes are good and can be merged. Could you please fix the merge conflict? Thanks a lot for your patience.

@jiqing-feng (Contributor, Author)


Done.

@BenjaminBossan (Member) left a comment

Thanks a lot for updating the PR. I gave this a final review and found two small issues (sorry for not noticing them earlier); could you please take a look?


if self.state.SCB is None:
    self.state.SCB = self.weight.SCB
# Dequantize the result of identify matrix and int8 weight because bitsandbytes only have this method.
Reviewer (Member)

Suggested change:
- # Dequantize the result of identify matrix and int8 weight because bitsandbytes only have this method.
+ # Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8
+ # dequantization directly

        self.state.reset_grads()
        self.merged = True

    def unmerge(self):
Reviewer (Member)

Could you please extend the test, or add a separate one, to also test unmerging?

@jiqing-feng (Contributor, Author)

I think test_8bit_merge_and_disable_lora in test_common_gpu.py already covers unmerging, since model.disable_adapter() will call unmerge() in the forward pass.
Am I misunderstanding it?
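For illustration, here is a hedged sketch of how such a test exercises unmerging via disable_adapter(); it assumes model is the PeftModel and random_input the tensor from the earlier script, with the adapter first merged in place via merge_adapter():

import torch

model.merge_adapter()  # merge the LoRA weights into the 8bit base weights in place
with torch.inference_mode():
    out_merged = model(random_input).logits
    # disable_adapter() triggers unmerge() in the forward pass, so this output
    # should match the base 8bit model without the adapter.
    with model.disable_adapter():
        out_without_adapter = model(random_input).logits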

Reviewer (Member)

Yes, you are right, sorry I missed that.

@jiqing-feng (Contributor, Author)

np.

I have fixed the annotation and removed "default" from merge_and_unload() in the tests; would you please review these changes? Thanks!

@pacman100 (Contributor) left a comment

Thank you @jiqing-feng for all the work on this, LGTM! 🚀

@BenjaminBossan (Member) left a comment

Fantastic addition, big thanks.
