Support merge lora module for 4bit and 8bit linear #851
Conversation
This looks great! Thanks for your great work on this. Before merging, can you add 2 tests in this file: https://github.com/huggingface/peft/blob/main/tests/test_common_gpu.py, one for 8bit models and another for 4bit models? For testing purposes you can use a small model such as facebook/opt-350m.
Hi @younesbelkada, I have added the tests for merging the LoRA module into the 8bit and 4bit models. Would you please help review them? Thanks!
Thanks! Added a comment: can you add a logits or generation check inside the tests you have designed? 🙏 Apart from that it looks quite nice!
tests/test_common_gpu.py (outdated):
self.assertTrue(
    isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear8bitLt)
)
self.assertTrue(
Can you add a logits or generation test for that? You can check the logits of the model before and after merging and make sure they are the same. Make sure to add init_lora_weights=False when creating the LoraConfig to avoid initializing the B matrix with all zeros: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora.py#L98
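A minimal sketch of the kind of check being requested (the model choice is illustrative, and as discussed further down, exact logit equality turns out to be too strict for quantized layers):
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", load_in_4bit=True)
# init_lora_weights=False keeps lora_B random, so merging actually changes the weights
model = get_peft_model(model, LoraConfig(r=8, init_lora_weights=False))

random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
with torch.inference_mode():
    logits_before = model(random_input).logits
model = model.merge_and_unload()
with torch.inference_mode():
    logits_after = model(random_input).logits
# may fail due to quantization round-trip error (see the discussion below);
# comparing the argmax or the generated tokens is more robust
torch.testing.assert_close(logits_before, logits_after)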
Hi @younesbelkada,
It is hard to keep the forward logits or generated text exactly the same because both dequantization and quantization have precision loss. You can try the following script:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
torch.manual_seed(1024)
model_origin = AutoModelForCausalLM.from_pretrained(
"opt-125m",
torch_dtype=torch.float32,
)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.float32
)
model = AutoModelForCausalLM.from_pretrained(
"opt-125m",
quantization_config=bnb_config,
torch_dtype=torch.float32,
)
random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
print(model_origin(random_input.clone().to(model_origin.device)).logits)
print(model(random_input).logits)
The output should be
tensor([[[-3.9464, -3.9443, 3.2428, ..., -3.9583, -3.9531, -4.0695],
[ 2.1592, 2.1657, 3.4937, ..., 2.1609, 2.2064, 1.8032],
[ 1.7407, 1.7366, 3.3675, ..., 1.7371, 1.7456, 1.5533],
[ 1.7655, 1.7583, 3.1950, ..., 1.7668, 1.7512, 1.5889],
[ 1.8861, 1.8778, 3.0723, ..., 1.8935, 1.8632, 1.7100],
[ 2.0078, 1.9973, 2.9639, ..., 2.0184, 1.9726, 1.8311]]],
grad_fn=<UnsafeViewBackward0>)
tensor([[[-5.2579, -5.2552, 2.2640, ..., -5.2631, -5.2835, -5.2650],
[ 1.0482, 1.0549, 3.4282, ..., 1.0684, 1.1276, 0.7290],
[ 0.8263, 0.8238, 3.5503, ..., 0.8330, 0.8313, 0.5981],
[ 0.8692, 0.8614, 3.1840, ..., 0.8806, 0.8476, 0.6761],
[ 0.9691, 0.9599, 3.1258, ..., 0.9846, 0.9382, 0.7884],
[ 1.0963, 1.0847, 3.1242, ..., 1.1106, 1.0566, 0.9201]]],
device='cuda:0', grad_fn=<UnsafeViewBackward0>)
So the quantization error is large. Since we dequantize the weight, add the LoRA weight, and then quantize the new weight, the error will be even larger than in the example above.
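The round-trip loss can also be seen in isolation with a quick check like the one below (a sketch using the bitsandbytes functional API; the tensor and its size are arbitrary, and a CUDA device is assumed):
import torch
import bitsandbytes as bnb

w = torch.randn(768, 768, dtype=torch.float16, device="cuda")
w_4bit, quant_state = bnb.functional.quantize_4bit(w)          # quantize to 4bit blocks
w_back = bnb.functional.dequantize_4bit(w_4bit, quant_state)   # dequantize back to fp16
rel_err = (w - w_back).abs().mean() / w.abs().mean()
print(rel_err)  # non-zero: every quantize/dequantize round trip loses precision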
I see, thanks for explaining. I have quickly played with your script. I think that even though the logits are slightly different, the argmax should stay the same; I believe this would be a sufficient test.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
torch.manual_seed(1024)
model_origin = AutoModelForCausalLM.from_pretrained(
"facebook/opt-125m",
torch_dtype=torch.float32,
)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_compute_dtype=torch.float32
)
model = AutoModelForCausalLM.from_pretrained(
"facebook/opt-125m",
quantization_config=bnb_config,
torch_dtype=torch.float32,
)
random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
print(model_origin.generate(random_input.clone().to(model_origin.device), max_new_tokens=1))
print(model.generate(random_input, max_new_tokens=1))
This should give:
tensor([[ 1, 0, 1, 0, 1, 0, 50118]])
tensor([[ 1, 0, 1, 0, 1, 0, 50118]])
Could you add a test that checks that the first generated token is the same between the merged and non-merged model? 🙏 Thanks again for your work on this!
I ran the 8bit test locally on my GPU and see large differences in the logits (l0 are the original logits, l1 the merged ones):
>>> torch.testing.assert_close(l0, l1)
*** AssertionError: Tensor-likes are not close!
Mismatched elements: 301245 / 301632 (99.9%)
Greatest absolute difference: 9.671875 at index (0, 5, 43954) (up to 1e-05 allowed)
Greatest relative difference: 29409.54607899256 at index (0, 2, 1238) (up to 0.001 allowed)
>>> torch.abs(l0-l1).mean() / torch.abs(l0).mean()
tensor(0.6348, device='cuda:0', dtype=torch.float16)
>>> l0.mean()
tensor(-2.6367, device='cuda:0', dtype=torch.float16)
>>> l1.mean()
tensor(-3.7637, device='cuda:0', dtype=torch.float16)
I also plotted a sample of the logits against one another (scatter plot not reproduced here).
In contrast, this is what I find for 4bit:
>>> torch.testing.assert_close(l0, l1)
*** AssertionError: Tensor-likes are not close!
Mismatched elements: 301627 / 301632 (100.0%)
Greatest absolute difference: 1.9874346256256104 at index (0, 4, 5985) (up to 1e-05 allowed)
Greatest relative difference: 4337.804713804714 at index (0, 3, 3491) (up to 1.3e-06 allowed)
>>> torch.abs(l0-l1).mean() / torch.abs(l0).mean()
tensor(0.1003, device='cuda:0')
>>> l0.mean()
tensor(-2.9331, device='cuda:0')
>>> l1.mean()
tensor(-3.0172, device='cuda:0')
And the corresponding scatter plot (not reproduced here).
The differences can also be quite big but it's much better than 8bit. Most notably, 8bit has much higher deviation and bias.
As is, I think I wouldn't feel comfortable exposing 8bit merging to users, whereas 4bit looks fine.
Small update, I changed the code to set the lora weights to 0 and still found the deviations to be of the same magnitude. So it looks like it's just the back and forth conversion that introduces the error, not the addition of lora weights.
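The zeroing mentioned above can be done roughly like this (a sketch; `model` is assumed to be the peft-wrapped model, and the module names follow peft's LoRA layer naming). With B set to zero the merge is mathematically a no-op, so any remaining difference must come from the quantize/dequantize round trip:
import torch

for name, module in model.named_modules():
    if "lora_B" in name and hasattr(module, "weight"):
        torch.nn.init.zeros_(module.weight)  # makes the LoRA delta B @ A exactly zero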
Great insights Benjamin!
Thanks for iterating @jiqing-feng, you have some merge conflicts in the PR with the main branch; would you be happy to resolve them? Otherwise I'm happy to help you!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Hi @jiqing-feng, can you try to run make style && make quality?
Hi @younesbelkada,
Sure @jiqing-feng, give me a moment and I will get back to you.
Hi @jiqing-feng
Thanks very much for iterating!
I ran the tests and the 8bit test fails:
def test_8bit_merge_lora(self):
    torch.manual_seed(1000)
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m",
        load_in_8bit=True,
    )
    config = LoraConfig(
        r=8,
        init_lora_weights=False,
    )
    model = get_peft_model(model, config)
    random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
    with torch.inference_mode():
        out_before_merge = model.generate(random_input, max_new_tokens=1)
    model.merge_and_unload("default")
    with torch.inference_mode():
        out_after_merge = model.generate(random_input, max_new_tokens=1)
>   self.assertTrue(torch.equal(out_before_merge, out_after_merge))
E   AssertionError: False is not true
Do you have an idea why the test is failing for 8bit? It is quite surprising that the differences are quite high for the 8bit model but not for the 4bit model.
src/peft/tuners/lora.py (outdated):
    warnings.warn("Already merged. Nothing to do.")
    return
if self.r[self.active_adapter] > 0:
    lora_data = self.get_delta_weight(self.active_adapter)
See my comment above; I think that we need to warn users that merging quantized models can lead to different generations due to rounding errors.
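A possible form for such a warning (the wording is only a suggestion, not the text that ended up in the PR):
import warnings

warnings.warn(
    "Merging LoRA weights into a 4-bit base layer requires dequantizing and "
    "re-quantizing the base weights, which introduces rounding errors; the merged "
    "model may therefore produce slightly different outputs than the unmerged model."
)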
Hi @younesbelkada. The test was successful on my node, but it may fail if I change the seed. As you said, there is precision loss when we apply quantization, and it is hard to ensure the result is exactly the same, so adding a warning is the best choice. Users can choose for themselves whether they want to keep precision or get higher performance. Besides, we set
src/peft/tuners/lora.py (outdated):
new_module = bnb.nn.Linear8bitLt(
    target.in_features,
    target.out_features,
    bias,
Just a small thing, but could you please use keyword arguments starting from the bias argument, i.e.
new_module = bnb.nn.Linear8bitLt(
    target.in_features,
    target.out_features,
    bias=bias,
    # etc.
Same change below for 4bit. This should be more future-proof and conforms with the use of keyword arguments in bnb.
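For the 4bit path, the analogous keyword-argument call would look roughly like this (a sketch based on the bnb.nn.Linear4bit signature, using the same target/bias names as the snippet above; it is not copied from the final diff):
new_module = bnb.nn.Linear4bit(
    target.in_features,
    target.out_features,
    bias=bias,
    compute_dtype=target.compute_dtype,
    compress_statistics=target.weight.compress_statistics,
    quant_type=target.weight.quant_type,
    device=target.weight.device,
)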
fixed
Hi @younesbelkada @BenjaminBossan, you are right. We cannot expose 8bit merging to users because the int8 weight data format is changed during forward, so we cannot dequantize 8bit parameters once they have been transformed; the original 8bit parameter format has been deleted, see this. So I will remove the merge API for the 8bit model and keep the 4bit merge. I will figure out how to dequantize 8bit parameters and enable merging 8bit models in the future.
Thank you @jiqing-feng for adding support for merging of 4bit layers, very impactful! Left a couple comments.
src/peft/tuners/lora.py (outdated):
elif is_bnb_available() and isinstance(target, bnb.nn.Linear8bitLt):
    bias = target.bias is not None
    new_module = bnb.nn.Linear8bitLt(
        target.in_features,
        target.out_features,
        bias=bias,
        has_fp16_weights=target.state.has_fp16_weights,
        memory_efficient_backward=target.state.memory_efficient_backward,
        threshold=target.state.threshold,
        index=target.index,
        device=target.weight.device,
    )
These should be removed if the support for 8bit merging isn't being enabled.
fixed
@@ -1193,8 +1216,48 @@ def __init__(
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self.active_adapter = adapter_name

    def merge(self):
A reference for merging 4bit weights that was shared on Twitter by Tim Dettmers: https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
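The approach in that gist boils down to dequantize, add the LoRA delta, re-quantize; roughly like this (a sketch with illustrative names, following the same pattern as this PR's diff further below):
import bitsandbytes as bnb

def merge_lora_into_4bit(layer, lora_A, lora_B, scaling):
    # dequantize the stored 4-bit weight back to floating point
    w = bnb.functional.dequantize_4bit(layer.weight.data, layer.weight.quant_state)
    # add the LoRA update (B @ A, scaled)
    w = w + (lora_B.weight @ lora_A.weight) * scaling
    # re-quantize by rebuilding the Params4bit with the layer's existing settings
    kwargs = layer.weight.__dict__
    layer.weight = bnb.nn.Params4bit(w.to("cpu"), requires_grad=False, **kwargs).to(layer.weight.device)
    return layer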
Looking great, left one comment! Can you also add a reference to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930 in the docstring of merge?
Thanks for all your work on this!
)
kwargs = self.weight.__dict__
lora_data = self.get_delta_weight(self.active_adapter)
w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + lora_data
I am not sure, but I think that you need to specify the quant_type here and below:
Suggested change:
- w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + lora_data
+ w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state, quant_type=self.weight.quant_type) + lora_data
Thanks for your comment. Actually, there is no need to pass quant_type because it is already in quant_state.
I have added a comment that refers to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930 before the 4bit merge.
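For reference, a quick round trip shows the behaviour being described (a sketch; quantize_4bit bundles the quant_type into the returned quant_state, so dequantize_4bit does not need it passed again):
import torch
import bitsandbytes as bnb

w = torch.randn(256, 256, dtype=torch.float16, device="cuda")
w_nf4, quant_state = bnb.functional.quantize_4bit(w, quant_type="nf4")
w_back = bnb.functional.dequantize_4bit(w_nf4, quant_state)  # no quant_type argument needed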
Thanks a lot!
Thanks!
Amazing, this enables a quantized ReLoRA implementation using peft :-)
This now looks good, thanks for all the work. I'm sure there's going to be a solution for 8bit eventually.
For each tuner, created a sub-module that contains at least:
- config.py for config stuff
- model.py for the actual model/encoder/embedding
- __init__.py so that imports are preserved
Then, when there was a need, further files were created, like layer.py or utils.py. Imports were changed to absolute imports everywhere, except for the sub-packages within a tuner directory, as these packages will always stay together in the same place. For some existing modules, the license comment at the top of the file was missing; I always added it.
There was a bug in the forward method of 4bit linear lora layers introduced in #851, for the case that the model is merged AND adapters are disabled. For that scenario, we need to unmerge first before generating the output, same as we do for the vanilla Linear layer. This step was missing from the code previously and is now implemented correctly. Tests were adjusted to catch that error.
Hi @younesbelkada @pacman100 @TimDettmers,
This PR enables merging the LoRA module into 4-bit and 8-bit linear layers from bitsandbytes.
It dequantizes the k-bit parameter to a float type, adds the LoRA weights to the float parameter, and then quantizes the new parameter.
Fixes #638
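For context, an end-to-end sketch of what this enables (the model choice and hyperparameters are arbitrary):
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", quantization_config=bnb_config)
model = get_peft_model(model, LoraConfig(r=8, init_lora_weights=False))

# ... fine-tune or load trained adapter weights here ...

model = model.merge_and_unload()  # folds the LoRA weights into the 4-bit base layers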