
Support merging LoRA modules for 4-bit and 8-bit linear layers #851

Merged · 13 commits · Aug 28, 2023

Conversation


@jiqing-feng jiqing-feng commented Aug 23, 2023

Hi @younesbelkada @pacman100 @TimDettmers

This PR enables merging the LoRA module into 4-bit and 8-bit linear layers from bitsandbytes.
It dequantizes the k-bit weight to a float tensor, adds the LoRA weights to the float weight, and then re-quantizes the result.
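For illustration, here is a minimal sketch of that dequantize / add / re-quantize step for a single 4-bit weight (not the exact code in this PR; the LoRA delta and the original parameter's quantization settings are assumed to come from the surrounding layer, and the bitsandbytes API may differ between versions):

import torch
import bitsandbytes as bnb

def merge_lora_into_4bit(weight: bnb.nn.Params4bit, lora_delta: torch.Tensor) -> bnb.nn.Params4bit:
    """Sketch: merge a LoRA delta into a 4-bit bitsandbytes parameter."""
    # 1) dequantize the 4-bit weight back to a float tensor
    w = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
    # 2) add the LoRA delta (lora_B @ lora_A * scaling) in float precision
    w = w + lora_delta.to(w.dtype)
    # 3) re-quantize: constructing a Params4bit on CPU and moving it back to the
    #    GPU triggers 4-bit quantization; the PR additionally copies the original
    #    parameter's settings (blocksize, compress_statistics, ...) when doing so
    return bnb.nn.Params4bit(
        w.to("cpu"), requires_grad=False, quant_type=weight.quant_type
    ).to(weight.device)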

Fixes #638


@younesbelkada younesbelkada left a comment


This looks great! Thanks for your great work on this. Before merging, can you add 2 tests in this file: https://github.com/huggingface/peft/blob/main/tests/test_common_gpu.py, one for 8bit models and another for 4bit models? For testing purposes you can use a small model such as facebook/opt-350m.

@jiqing-feng

> This looks great! Thanks for your great work on this. Before merging, can you add 2 tests in this file: https://github.com/huggingface/peft/blob/main/tests/test_common_gpu.py, one for 8bit models and another for 4bit models? For testing purposes you can use a small model such as facebook/opt-350m.

Hi @younesbelkada, I have added the tests for merging the LoRA module into 8bit and 4bit models.

Would you please help review it? Thanks!


@younesbelkada younesbelkada left a comment


Thanks! I added a comment; can you add a logits or generation check inside the tests you have designed? 🙏 Apart from that it looks quite nice!

self.assertTrue(
isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear8bitLt)
)
self.assertTrue(
Contributor


Can you add a logits or generation test for that? You can check the logits of the model before and after merging and make sure they are the same. Make sure to add init_lora_weights=False when creating the LoraConfig to avoid initializing the B matrix with all zeros: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora.py#L98
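For concreteness, such a config could look like the following (a sketch; r=8 matches the test that was eventually added in this PR):

from peft import LoraConfig

# init_lora_weights=False gives lora_B non-zero random values, so merging
# actually changes the weights and a before/after comparison is meaningful.
config = LoraConfig(
    r=8,
    init_lora_weights=False,
)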


@jiqing-feng jiqing-feng Aug 25, 2023


Hi @younesbelkada
It is hard to keep the forward logits or the generated text exactly the same, because both dequantization and quantization introduce precision loss. You can try the following script:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

torch.manual_seed(1024)
model_origin = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    torch_dtype=torch.float32,
)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float32,
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=bnb_config,
    torch_dtype=torch.float32,
)
random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
print(model_origin(random_input.clone().to(model_origin.device)).logits)
print(model(random_input).logits)

The output should be

tensor([[[-3.9464, -3.9443,  3.2428,  ..., -3.9583, -3.9531, -4.0695],
         [ 2.1592,  2.1657,  3.4937,  ...,  2.1609,  2.2064,  1.8032],
         [ 1.7407,  1.7366,  3.3675,  ...,  1.7371,  1.7456,  1.5533],
         [ 1.7655,  1.7583,  3.1950,  ...,  1.7668,  1.7512,  1.5889],
         [ 1.8861,  1.8778,  3.0723,  ...,  1.8935,  1.8632,  1.7100],
         [ 2.0078,  1.9973,  2.9639,  ...,  2.0184,  1.9726,  1.8311]]],
       grad_fn=<UnsafeViewBackward0>)
tensor([[[-5.2579, -5.2552,  2.2640,  ..., -5.2631, -5.2835, -5.2650],
         [ 1.0482,  1.0549,  3.4282,  ...,  1.0684,  1.1276,  0.7290],
         [ 0.8263,  0.8238,  3.5503,  ...,  0.8330,  0.8313,  0.5981],
         [ 0.8692,  0.8614,  3.1840,  ...,  0.8806,  0.8476,  0.6761],
         [ 0.9691,  0.9599,  3.1258,  ...,  0.9846,  0.9382,  0.7884],
         [ 1.0963,  1.0847,  3.1242,  ...,  1.1106,  1.0566,  0.9201]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

So the quantization error is already large. Since we dequantize the weight, add the LoRA weight, and then re-quantize the new weight, the error after merging will be even larger than in this example.

Contributor


I see, thanks for explaining. I quickly played with your script. Even though the logits are slightly different, I think the argmax should stay the same, and I believe this would be a sufficient test.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

torch.manual_seed(1024)
model_origin = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    torch_dtype=torch.float32,
)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.float32,
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=bnb_config,
    torch_dtype=torch.float32,
)
random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)

print(model_origin.generate(random_input.clone().to(model_origin.device), max_new_tokens=1))
print(model.generate(random_input.clone().to(model.device), max_new_tokens=1))

This should give:

tensor([[    1,     0,     1,     0,     1,     0, 50118]])
tensor([[    1,     0,     1,     0,     1,     0, 50118]])

Could you add a test that checks that the first generated token is the same between the merged and the non-merged model? 🙏 Thanks again for your work on this!

Member


I ran the 8bit test locally on my GPU and see large differences in the logits (l0 are the original logits, l1 the merged ones):

>>> torch.testing.assert_close(l0, l1)
*** AssertionError: Tensor-likes are not close!
Mismatched elements: 301245 / 301632 (99.9%)
Greatest absolute difference: 9.671875 at index (0, 5, 43954) (up to 1e-05 allowed)
Greatest relative difference: 29409.54607899256 at index (0, 2, 1238) (up to 0.001 allowed)

>>> torch.abs(l0-l1).mean() / torch.abs(l0).mean()
tensor(0.6348, device='cuda:0', dtype=torch.float16)

>>> l0.mean()
tensor(-2.6367, device='cuda:0', dtype=torch.float16)
>>> l1.mean()
tensor(-3.7637, device='cuda:0', dtype=torch.float16)

I also plotted a sample of the logits against one another:

(Screenshot from 2023-08-25 13-54-44: scatter plot of original vs. merged logits, 8bit)

In contrast, this is what I find for 4bit:

>>> torch.testing.assert_close(l0, l1)
*** AssertionError: Tensor-likes are not close!

Mismatched elements: 301627 / 301632 (100.0%)
Greatest absolute difference: 1.9874346256256104 at index (0, 4, 5985) (up to 1e-05 allowed)
Greatest relative difference: 4337.804713804714 at index (0, 3, 3491) (up to 1.3e-06 allowed)

>>> torch.abs(l0-l1).mean() / torch.abs(l0).mean()
tensor(0.1003, device='cuda:0')
>>> l0.mean()
tensor(-2.9331, device='cuda:0')
>>> l1.mean()
tensor(-3.0172, device='cuda:0')

And the scatter plot:

(Screenshot from 2023-08-25 14-00-38: scatter plot of original vs. merged logits, 4bit)

The differences can still be quite big, but it's much better than 8bit. Most notably, 8bit has much higher deviation and bias.

As is, I think I wouldn't feel comfortable exposing 8bit merging to users, whereas 4bit looks fine.

Member


Small update: I changed the code to set the LoRA weights to 0 and still found deviations of the same magnitude. So it looks like it's just the back-and-forth conversion that introduces the error, not the addition of the LoRA weights.
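A quick way to reproduce this observation is to measure the pure quantize/dequantize round-trip error with no LoRA involved at all. The sketch below uses the 4-bit routines for illustration; the 8-bit path discussed here relies on vector-wise Int8 quantization instead, which turned out to be lossier:

import torch
import bitsandbytes as bnb

# Quantize a random weight and dequantize it again; any difference is pure
# round-trip error, independent of the LoRA weights.
w = torch.randn(2048, 2048, dtype=torch.float16, device="cuda")
q, state = bnb.functional.quantize_4bit(w)
w_rec = bnb.functional.dequantize_4bit(q, state)

rel_err = (w - w_rec).abs().mean() / w.abs().mean()
print(f"mean relative round-trip error: {rel_err.item():.4f}")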

Contributor


Great insights, Benjamin!


@younesbelkada younesbelkada left a comment


Thanks for iterating @jiqing-feng! You have some merge conflicts in the PR with the main branch; would you be happy to resolve them? Otherwise I'm happy to help you!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@younesbelkada

Hi @jiqing-feng, can you try to run

make style && make quality

@jiqing-feng

> Hi @jiqing-feng, can you try to run
>
> make style && make quality

Hi @younesbelkada,
Thanks for your help. I have fixed all the code format issues; could you please review and merge it?

@younesbelkada

Sure @jiqing-feng, give me a moment and I will get back to you.


@younesbelkada younesbelkada left a comment


Hi @jiqing-feng
Thanks very much for iterating!
I ran the tests and the 8bit test fails.

    def test_8bit_merge_lora(self):
        torch.manual_seed(1000)
        model = AutoModelForCausalLM.from_pretrained(
            "facebook/opt-125m",
            load_in_8bit=True,
        )
        config = LoraConfig(
            r=8,
            init_lora_weights=False,
        )
        model = get_peft_model(model, config)
    
        random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
        with torch.inference_mode():
            out_before_merge = model.generate(random_input, max_new_tokens=1)
    
        model.merge_and_unload("default")
        with torch.inference_mode():
            out_after_merge = model.generate(random_input, max_new_tokens=1)
    
>       self.assertTrue(torch.equal(out_before_merge, out_after_merge))
E       AssertionError: False is not true

Do you have an idea why the test is failing for 8bit? It is quite surprising that the differences are so high for the 8bit model but not for the 4bit model.

warnings.warn("Already merged. Nothing to do.")
return
if self.r[self.active_adapter] > 0:
lora_data = self.get_delta_weight(self.active_adapter)
Contributor


See my comment above; I think we need to warn users that merging quantized models can lead to different generations due to rounding errors.

@jiqing-feng

jiqing-feng commented Aug 25, 2023

Hi @younesbelkada .

The test was successful on my node, but it may fail if I change the seed.

As you said, there is precision loss when we apply quantization. It is hard to ensure the result is exactly the same, so adding a warning is the best choice. Users can then decide for themselves whether they want to keep precision or get higher performance.

Besides, we set bnb_4bit_use_double_quant=False and the 4-bit model uses block-wise quantization, so the 4-bit model keeps more precision in this case. The 8-bit model uses vector-wise quantization, which has lower precision than block-wise quantization.

new_module = bnb.nn.Linear8bitLt(
target.in_features,
target.out_features,
bias,
Member


Just a small thing, but could you please use keyword arguments starting from the bias argument, i.e.

                    new_module = bnb.nn.Linear8bitLt(
                        target.in_features,
                        target.out_features,
                        bias=bias,
                        # etc.

Same change below for 4bit. This should be more future-proof and conforms with the use of keyword arguments in bnb.

Contributor Author


fixed

@jiqing-feng

jiqing-feng commented Aug 26, 2023

Hi @younesbelkada @BenjaminBossan

You are right. We cannot expose 8bit merging to users, because the int8 weight data format is changed during the forward pass, so we cannot dequantize the 8bit parameters once they have been transformed. The original 8bit parameter format has been deleted, see this.

So I will remove the merge API for the 8bit model and keep the 4bit merge.

I will figure out how to dequantize the 8bit parameters and enable merging for 8bit models in the future.


@pacman100 pacman100 left a comment


Thank you @jiqing-feng for adding support for merging of 4bit layers, very impactful! I left a couple of comments.

Comment on lines 576 to 587
elif is_bnb_available() and isinstance(target, bnb.nn.Linear8bitLt):
bias = target.bias is not None
new_module = bnb.nn.Linear8bitLt(
target.in_features,
target.out_features,
bias=bias,
has_fp16_weights=target.state.has_fp16_weights,
memory_efficient_backward=target.state.memory_efficient_backward,
threshold=target.state.threshold,
index=target.index,
device=target.weight.device,
)
Contributor


These should be removed if support for 8-bit merging isn't being enabled.

Contributor Author


fixed

@@ -1193,8 +1216,48 @@ def __init__(
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name

def merge(self):
Contributor


A reference for merging of 4-bit weights that was shared on Twitter by Tim Dettmers: https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930


@younesbelkada younesbelkada left a comment


Looking great, I left one comment! Can you also add a reference to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930 in the docstring of merge?
Thanks for all your work on this!

)
kwargs = self.weight.__dict__
lora_data = self.get_delta_weight(self.active_adapter)
w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + lora_data
Contributor


I am not sure, but I think that you need to specify the quant_type here and below:

Suggested change
w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + lora_data
w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state, quant_type=self.weight.quant_type) + lora_data

Contributor Author


Thanks for your comment. Actually, there is no need to pass quant_type, because it is already stored in quant_state.

I have added a comment that refers to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930 before the 4bit merge.
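A small illustration of that point (a sketch with arbitrary shapes):

import torch
import bitsandbytes as bnb

w = torch.randn(256, 256, dtype=torch.float16, device="cuda")
q, quant_state = bnb.functional.quantize_4bit(w, quant_type="nf4")

# quant_state records the quantization settings, including the quant type,
# so dequantize_4bit does not need an explicit quant_type argument here.
w_back = bnb.functional.dequantize_4bit(q, quant_state)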

Contributor


Thanks a lot!


@younesbelkada younesbelkada left a comment


Thanks!

@flozi00

flozi00 commented Aug 28, 2023

Amazing, this enables a quantized ReLoRA implementation using PEFT :-)
Thanks for this feature!


@BenjaminBossan BenjaminBossan left a comment


This now looks good, thanks for all the work. I'm sure there's going to be a solution for 8bit eventually.

@pacman100 pacman100 merged commit 140a69b into huggingface:main Aug 28, 2023
@jiqing-feng jiqing-feng deleted the merge_lora_kbit branch August 29, 2023 04:38
BenjaminBossan added a commit that referenced this pull request Aug 29, 2023
For each tuner, created a sub-module that contains at least:

- config.py for config stuff
- model.py for the actual model/encoder/embedding
- __init__.py so that imports are preserved

Then, when there was a need, further files were created, like layer.py
or utils.py.

Imports were changed to absolute imports everywhere, except for the
sub-packages within a tuner directory, as these packages will always 
stay together in the same place.

For some existing modules, the license comment at the top of the file
was missing; it has been added.

There was a bug in the forward method of 4bit linear LoRA layers introduced
in #851, for the case where the model is merged AND adapters are disabled.
For that scenario, we need to unmerge first before generating the output,
the same as we do for the vanilla Linear layer. This step was previously
missing from the code and is now implemented correctly. Tests were adjusted
to catch that error.
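In rough pseudocode, the corrected control flow for the 4-bit LoRA layer's forward looks like this (a sketch, not the exact peft code; base_layer_forward stands in for the underlying bnb Linear4bit forward):

def forward(self, x):
    adapter = self.active_adapter
    if self.disable_adapters:
        # if the LoRA weights were previously merged, undo the merge first so
        # that disabling adapters really returns the base model's output
        if self.r[adapter] > 0 and self.merged:
            self.unmerge()
        return base_layer_forward(x)
    elif self.merged:
        # weights already contain the LoRA delta, nothing extra to add
        return base_layer_forward(x)
    else:
        # regular path: base output plus the LoRA contribution
        result = base_layer_forward(x)
        lora_out = self.lora_B[adapter](self.lora_A[adapter](self.lora_dropout[adapter](x)))
        return result + lora_out * self.scaling[adapter]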