Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for an optional initialization strategy OLoRA #1828

Merged
merged 11 commits into from
Jun 12, 2024

Conversation

tokenizer-decode
Copy link
Contributor

@tokenizer-decode tokenizer-decode commented Jun 5, 2024

In this paper we introduced OLoRA, an optional initialization strategy that converges significantly faster and ultimately achieves better performance than the standard LoRA implementation. It can be implemented with a single function. I closely followed other init methods so this implementation should be in PEFT style. If you'd lke this to be an optional method, I can go ahead and implement examples and tests.

image

@BenjaminBossan
Copy link
Member

Thanks for proposing this new method for initializing the LoRA weights. In general, I think we can for sure add it to PEFT. As you mentioned, having some tests and documentation would be appreciated.

Implementation-wise, I only had a quick look at the paper and code for now. One part could make things a bit more complicated:

weight_tensor.data -= scale_factor * self.lora_B[adapter_name].weight @ self.lora_A[adapter_name].weight

Here the base weights of the model are mutated, which can have some negative side effects. In fact, this is basically the same problem we also faced when adding PiSSA, so I'll just link to my comment there. If you could share your insights on this, that would be great.

Furthermore, as is, I think this would not work with QLoRA, right? But I think not much would be needed to make it work. Do you think this extra functionality could be added?

@tokenizer-decode
Copy link
Contributor Author

Regarding mutation, I wouldn't expect significant side effects if we're meticulous about checking the types of scale factor inputs, weights etc.. And given types are correct, mathematically speaking, translating A by QR decomp of A should not move A far away from its initial point. However, in the worst-case scenario, we can adopt a similar approach to PiSSA.

Regarding quantization, it's been a significant challenge. I'd really like this to be quantization friendly but firstly, Torch QR doesn't support anything lower than 16-bit. We could implement a custom QR decomposition to support lower bit depths, but I'm skeptical about its feasibility and potential performance. As it stands, it's not compatible with QLoRA.

Additionally, based on my experiments, I've found that this doesn't work with Torch.compile. When I enable compilation, I have to disable safe serialization. I'm unsure whether this is a problem with Torch compile or safe serialization, but I don't think it's an issue with OLoRA.

So I'll start with adding tests, checks for mutation issue, documentation, and we build from there?

@BenjaminBossan
Copy link
Member

Regarding mutation, I wouldn't expect significant side effects if we're meticulous about checking the types of scale factor inputs, weights etc.

This is not so much about the LoRA A and B weights, but about the weights of the base model, which are mutated in the cited line. Let's say I want to load to LoRA adapters, one normal one and one trained with OLoRA. As OLoRA changes the base weights, my normal LoRA adapter may suddenly stop working correctly, as it assumes that the base weights are not touched. This is what I was referring to in my comment.

Also, I wonder if we save and load an OLoRA model, do we have to ensure that the random seed is exactly the same so that the base weights are modified exactly identically? If yes, this could be an easy source of errors.

Did you run tests to check how much of a difference it makes if the base weights were not changed? If we could avoid it, it would make things a lot easier.

Regarding quantization, it's been a significant challenge. I'd really like this to be quantization friendly but firstly, Torch QR doesn't support anything lower than 16-bit. We could implement a custom QR decomposition to support lower bit depths, but I'm skeptical about its feasibility and potential performance. As it stands, it's not compatible with QLoRA.

I was wondering if we could temporarily dequantize the quantized weights, that way, all operations could be performed on floats. It's not ideal but it would be better than not offering this at all. Of course, this can still be problematic when the base weights have to be modified.

Additionally, based on my experiments, I've found that this doesn't work with Torch.compile.

This is fine, we don't support torch.compile across the board, as documented here.

When I enable compilation, I have to disable safe serialization.

This sounds strange, do you mean in general or just with OLoRA? If in general, please open an issue and show us the error.

So I'll start with adding tests, checks for mutation issue, documentation, and we build from there?

Exactly, thanks.

@tokenizer-decode
Copy link
Contributor Author

tokenizer-decode commented Jun 5, 2024

Mutating Weights Update
Before tackling the quantization issue, I addressed the mutation problem. Unfortunately, my tests show that disabling base weights mutation degrades performance.

I took the liberty of making some changes to create a uniform approach for OLoRA and PiSSA, which share similarities in updating base weights.

Rationale:

  • I avoided cluttering the save_pretrained signature with an additional argument (e.g., convert_olora_to_lora).
  • The subtraction method would be identical, resulting in duplicate methods with the same behavior.
  • This approach allows others to experiment with alternative decompositions by simply creating an X_init method.

I'll explore temporarily dequantizing the weights and replicate the safe serialization problem.

Before proceeding, I'd appreciate your thoughts on these changes, which affect the existing API. I'm open to creating separate options and methods, but I believe that would unnecessarily pollute the code.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates to the PR. I haven't done an in depth review yet, as I think we still have to flesh out some designs.

I agree that if the mechanism for updating the weights can be shared with PiSSA, it should be done so. However, we should not remove the convert_pissa_to_lora argument, as this would be a breaking change.

Instead, please leave that argument and add the new one also. Then add a check if convert_pissa_to_lora is not None, in which case a deprecation warning should be given with the instruction to use the new argument in the future, then set the new argument to the value of convert_pissa_to_lora. That way, existing code continues to work but users are nudged towards using the new argument.

Regarding OLoRA + quantized weights, if it's not so straightforward, just skip it for now. We can always come back to this in a future PR. I would just add an extra error message if the weights are quantized to explicitly say that OLoRA + quantization is not yet supported.

examples/olora_finetuning/olora_finetuning.py Show resolved Hide resolved
src/peft/peft_model.py Outdated Show resolved Hide resolved
src/peft/peft_model.py Outdated Show resolved Hide resolved
@tokenizer-decode
Copy link
Contributor Author

I believe it should be good to review now.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates, the code looks quite good and the PR is on a good path to be merged. I still saw a few issues, please check my comments.

examples/olora_finetuning/README.md Show resolved Hide resolved
examples/olora_finetuning/README.md Show resolved Hide resolved
src/peft/peft_model.py Outdated Show resolved Hide resolved
src/peft/peft_model.py Outdated Show resolved Hide resolved
src/peft/peft_model.py Show resolved Hide resolved
src/peft/tuners/lora/layer.py Outdated Show resolved Hide resolved
src/peft/tuners/lora/layer.py Outdated Show resolved Hide resolved
src/peft/tuners/lora/layer.py Outdated Show resolved Hide resolved
tests/test_initialization.py Outdated Show resolved Hide resolved
examples/olora_finetuning/olora_finetuning.py Outdated Show resolved Hide resolved
@tokenizer-decode
Copy link
Contributor Author

So okay, I improved the docs and examples. float and int8 is a good place to start to support quantization.

examples/olora_finetuning/README.md Outdated Show resolved Hide resolved
examples/olora_finetuning/README.md Outdated Show resolved Hide resolved
examples/olora_finetuning/README.md Outdated Show resolved Hide resolved
src/peft/tuners/lora/config.py Outdated Show resolved Hide resolved
src/peft/tuners/lora/layer.py Outdated Show resolved Hide resolved
tests/test_initialization.py Outdated Show resolved Hide resolved
src/peft/tuners/lora/layer.py Outdated Show resolved Hide resolved
Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the new additions and adjustments. This looks quite good, only a few smaller comments left. Please take a look.

@tokenizer-decode
Copy link
Contributor Author

Okay done. Additional quant and warning tests added. Doc improved. Quant warning fixed. Please take a look at the final commit.

@tokenizer-decode
Copy link
Contributor Author

tokenizer-decode commented Jun 7, 2024

Also something's just been itching me. We have an option reserved for OLoRA and PiSSA in save_pretrained method. For some reason (I'm also not sure why, I only have empirical evidence for this) mutating base weights results in pretty good performance increase compared to its non-mutating counterpart. Since we mutate the base weights we do have to perform some tricks to mitigate the negative effects that you mentioned. I tried to make this approach as uniform as possible, so that in the future if someone comes up with a new init strategy and wants to mutate the base weights as we do, he/she can use our method easily with relatively little hustle.

However, from a strictly design perspective, I think save_pretrained is too generic to have an option reserved strictly for something as niche as our fine-tuning strategy. First of all, I think very few people will use multiple adapters, most of them will use a single adapter and call it a day (or that's just my observation), but of course it's good that we're supporting this and we should do so. But there should be a cleaner way to do this without polluting save_pretrained method. I just don't know yet. Maybe this option can be carried to LoraConfig and since we know the user is using OLoRA, we can automatically do conversion if we're asked to do so. Just a rudimentary idea:

config = LoraConfig(
    init_lora_weights = "olora",
    wanna_convert = "path_to_dir"
    )

_convert_olora(base_model, wanna_convert) # We internally do this since we know `wanna_convert`
                                          # and the user doesn't see this
train()
save_pretrained(output_dir) # Save as you would normally

This also eliminates manual saving in the first place. This is just a sketch. Main idea is to not mess with save_pretrained.
I don't know. Even if we should do this, that's for another PR. Does it make sense to you? Wdyt?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The latest changes look quite good. CI is still failing because of line lengths, could you please run make style on the code?

tests/test_gpu_examples.py Outdated Show resolved Hide resolved
@tokenizer-decode
Copy link
Contributor Author

Weird that I don't get any errors from make style. It did change some irrelevant functions though. Please check the latest commit.

@BenjaminBossan
Copy link
Member

I tried to make this approach as uniform as possible, so that in the future if someone comes up with a new init strategy and wants to mutate the base weights as we do, he/she can use our method easily with relatively little hustle.

Yes, that is a great addition.

I think save_pretrained is too generic to have an option reserved strictly for something as niche as our fine-tuning strategy.

It is a bit specific, but I don't think the cost of adding this option there is too high.

First of all, I think very few people will use multiple adapters, most of them will use a single adapter and call it a day

Probably a single adapter is the most common case, but I would not underestimate the number of users who are using more advanced features. Of course I have no way of really knowing, but based on the issues that are being opened, a significant number of users uses multiple adapters.

Maybe this option can be carried to LoraConfig and since we know the user is using OLoRA, we can automatically do conversion if we're asked to do so.

This also eliminates manual saving in the first place. This is just a sketch. Main idea is to not mess with save_pretrained.
I don't know. Even if we should do this, that's for another PR. Does it make sense to you? Wdyt?

Thanks for thinking about and proposing an alternative design. If I understand your proposal correctly, what I think could is that it might create the impression that passing this to the config is all it takes to take care of this issue. But once training is over, we would still require the user to call save_pretrained and we would need to add logic there to perform the conversion, right? If the user just loads another LoRA adapter after training, it won't work properly.

We could try to call this conversion automatically after training, but this is also problematic. First of all, PEFT does not provide training code, so how would we know when training is over? Second, even if training is over, we don't know if the user may not want to continue training, in which case we should not convert.

Furthermore, if we automated this, it would mean that we need to save a copy of the model when we initialize the PEFT model. This could be quite surprising to users.

What I like about the current approach, even if it's cumbersome, is that is very explicit. Therefore, it is less likely that the user will be surprised by anything that happens.

What I could imagine that we could do is to factor out the conversion into its own function. That way, the user could explicitly call it once they're finished with training. The save_pretrained method could use this same function under the hood if the path_initial_model_for_weight_conversion is passed. But then we should ideally have a check to ensure that there is no double conversion.

@BenjaminBossan
Copy link
Member

Weird that I don't get any errors from make style. It did change some irrelevant functions though. Please check the latest commit.

Hmm, could it be that you have a different ruff version? The CI installs version 0.2.2.

@tokenizer-decode
Copy link
Contributor Author

My bad that I didn't have doc-builder. Now it should work I suppose.

Thanks a lot for your insights about save_pretrained issue. I might create a new PR in the future and we can talk about improving user-experience there.

Also you've been super helpful in the process. Thanks.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're almost good to go, I only found a few minor issues, please take a look.

examples/olora_finetuning/README.md Outdated Show resolved Hide resolved
examples/olora_finetuning/olora_finetuning.py Outdated Show resolved Hide resolved
examples/olora_finetuning/README.md Show resolved Hide resolved
@tokenizer-decode
Copy link
Contributor Author

tokenizer-decode commented Jun 10, 2024

Nice catch, thanks. I pushed 2-3 commits tonight, but each one was a fix for the previous one. So, I deleted them and pushed a final commit. As a result, the order of my comments might be mixed up. Please see my final commit for clarity.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks excellent, thanks for your contribution and ensuring that OLoRA fits well into the PEFT framework.

I was about to approve when I saw something I can't quite explain:

During the tests, when we want to save the initial model, we toggle init_lora_weights temporarily to True and then switch back to "olora":

https://github.com/huggingface/peft/pull/1828/files#diff-df9ecc7077bee932f56e76161ada47693d73acd3ed175a5b9a9158cfe03ec381R372-R374

I was about to comment that this should be added to the example that you provided. But out of curiosity, I removed the toggle from the test and found that it still passes. Now I wonder if those lines are unnecessary or if the test is not quite correct. Do you have an idea?

Btw. the same is true for the PiSSA tests, i.e. we can remove the toggle there and they also pass, so whatever this is is also affecting PiSSA.

@tokenizer-decode
Copy link
Contributor Author

tokenizer-decode commented Jun 11, 2024

Hmm, what part did you remove? Because removing the toggle i.e. peft_model.peft_config["default"].init_lora_weights = "olora" fails the test on my machine. If you removed the whole section, I mean this part:

        peft_model.peft_config["default"].init_lora_weights = True
        peft_model.save_pretrained(tmp_path / "init-model")
        peft_model.peft_config["default"].init_lora_weights = "olora"

That would break the test.

@BenjaminBossan
Copy link
Member

Sorry, I should have been more precise, I meant this:

-       peft_model.peft_config["default"].init_lora_weights = True
        peft_model.save_pretrained(tmp_path / "init-model")
-       peft_model.peft_config["default"].init_lora_weights = "olora"

@tokenizer-decode
Copy link
Contributor Author

tokenizer-decode commented Jun 11, 2024

Nice! That still passes because I internally set this to True with my last commit. If you remove my last commit from peft_model.py and turn off the toggle you'd see that the test fails. PiSSA authors needed that because they did not correctly set this field, and we adapted the tests from them. Now we don't need those lines. I think let's just keep those lines so that they can see in the future and they may have comments.

@BenjaminBossan
Copy link
Member

Ah I see, that's a nice addition. Let's remove the toggle from the OLoRA tests then, as they are just adding code noise. As for PiSSA, this should also be removed from the tests and from the docs, but that can be done in a later PR.

@tokenizer-decode
Copy link
Contributor Author

Done. @BenjaminBossan

@tokenizer-decode
Copy link
Contributor Author

This is a server issue right?

@BenjaminBossan
Copy link
Member

Yes, nothing to worry about, for some reasons we get many timeouts lately. I'll check in on occasion and restart :)

@tokenizer-decode
Copy link
Contributor Author

Tests look good. Anything left on my end? @BenjaminBossan

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the ping, I was distracted by other work and forgot to check back on this PR.

Everything looks super good. Thanks for adding the OLoRA initialization method and refactoring the existing code to make similar additions easier in the future. Also good discussions on different designs.

@BenjaminBossan BenjaminBossan merged commit 2f5360a into huggingface:main Jun 12, 2024
14 checks passed
@tokenizer-decode
Copy link
Contributor Author

Thanks a lot for helping me with the process. Great work.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants