Adding a brand new fine-tuning method, Bone. #2148
Conversation
Thanks for this PR to add the Bone method to PEFT. It's already looking fine, but I found a few issues; please check out my comments.
Apart from those, to get to the next stage, could you please:
- Run `make style`.
- Clarify the following: Is this implementation using row or column blocks? Would it make sense to add both as an option?
- Related to that, depending on the option and the choice of `r`, we probably need to add some sanity checks on the shape of the targeted weight matrix to make sure that we can evenly divide it into blocks, right? (See the sketch after this list.)
- Let's add a few test cases to get started. For this, check these lines and implement something similar for Bone. Then you can run the tests using `pytest tests/test_custom_models.py -k bone -v` to help identify potential issues.
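For illustration, such a check could look roughly like this (a sketch with hypothetical attribute names, assuming blocks of size `r`):

```python
# Hypothetical sanity check inside the Bone layer's update logic:
# the targeted weight matrix must divide evenly into r-sized blocks.
if self.out_features % r != 0:
    raise ValueError(
        f"out_features ({self.out_features}) must be divisible by r ({r})"
    )
if self.in_features % r != 0:
    raise ValueError(
        f"in_features ({self.in_features}) must be divisible by r ({r})"
    )
```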
src/peft/tuners/bone/layer.py (outdated)
```python
orig_weight = self.base_layer.weight.data.clone()
delta_weight = self.get_delta_weight(active_adapter, orig_weight)
orig_weight += delta_weight

result = F.linear(input=x, weight=orig_weight, bias=self.base_layer.bias)
```
Just an idea from my side:
Right now, you're calculating `W2 = W + dW`, then `result = W2 @ x + bias`. Instead, I wonder if we could re-arrange this somehow to use `result = self.base_layer(x, *args, **kwargs)`, then `result = result + bone_result`. The `bone_result` would ideally be calculated in a parameter-efficient way, which should hopefully alleviate the memory overhead that Bone has compared to LoRA. For this, I wonder if the Bone contribution can be calculated group-wise and added incrementally instead of instantiating a big `dW`, so basically trading more compute for less memory. Do you think that would be possible?
I don't think `result = result + bone_result` is directly feasible. The Bone repository has a flash-bone variant that fuses the linear and Bone computations, but I'm not very skilled at writing custom operators, so the speed of the current optimization is insufficient.
If we have `Y = (W + W@dW + dW)@X + b`, we could rewrite it as `Y = W@X + (W@dW + dW)@X + b`, right? And since `W@X + b` is the base result, i.e. `self.base_layer(x, *args, **kwargs)`, this should work?
Another question would be if we can re-define `delta_weight` in a way that would allow us to reduce the memory overhead.
My thought was that we could somehow calculate the activations per block and add them incrementally to the base result. Maybe it's not possible, or it would not help to reduce memory, but if we can trade off compute for less memory, it is usually a worthwhile trade. If it's simply not possible, feel free to ignore my comment :)
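To make the idea concrete, here is a minimal sketch of the group-wise computation, assuming (as in the snippets above) per-group parameters `bone_block` of shape `(out_features // r, r, r)` and a delta of the form `W_i @ B_i + B_i` per row group; the tiling and the names are assumptions, not the final implementation:

```python
import torch
import torch.nn.functional as F

def bone_forward_incremental(base_layer, x, bone_block):
    """base_layer(x) plus the Bone contribution, computed group by group so the
    full (out_features, in_features) delta weight is never materialized."""
    W = base_layer.weight                          # (out_f, in_f)
    out_f, in_f = W.shape
    g, r, _ = bone_block.shape                     # g = out_f // r row groups
    result = base_layer(x)                         # base path: W @ x + b
    xb = x.reshape(*x.shape[:-1], in_f // r, r)    # input split into r-sized blocks
    xs = xb.sum(dim=-2)                            # (..., r), reused for the "+ B_i" term
    deltas = []
    for i in range(g):
        B = bone_block[i]                          # (r, r)
        # (W_i @ B_i + B_i) @ x == W_i @ (B_i applied block-wise to x) + B_i @ (sum of x blocks)
        v = (xb @ B.T).reshape(*x.shape)           # (..., in_f), activation-sized
        rows = W[i * r : (i + 1) * r]              # (r, in_f) row group of W
        deltas.append(F.linear(v, rows) + xs @ B.T)
    return result + torch.cat(deltas, dim=-1)      # (..., out_f)
```

The peak extra memory here is activation-sized (`v`) rather than weight-sized, at the cost of a Python-level loop over the row groups, which is exactly the compute-for-memory trade described above.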
Is it possible to add checkpointing specifically for `bone_result`? The optimization here indeed requires careful consideration.
I would not know how. I think users should just use the existing checkpointing utilities if they want to make use of this.
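For reference, the generic utility could be applied from user code roughly like this (a sketch; `compute_bone_result` is a hypothetical stand-in for the Bone computation):

```python
from torch.utils.checkpoint import checkpoint

# Recompute the Bone contribution during the backward pass instead of
# storing its intermediate activations:
bone_result = checkpoint(compute_bone_result, x, use_reentrant=False)
result = result + bone_result
```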
I believe it can be merged now.
Thanks for the updates to the PR. We still have a few boxes to tick before we can merge though.
> Currently, it is possible to fine-tune models like LLaMA normally, but due to the checks mentioned in issue 3, some structures do not pass.

This needs to be addressed. Can we change the values so that the check passes? In the worst case, we can define a new module class where the shapes fit (see the sketch below). The goal is that all tests are green.
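For instance, a tiny test module along these lines (hypothetical, modeled on the MLP classes in tests/test_custom_models.py) would keep every Linear dimension divisible by common block sizes:

```python
import torch.nn as nn

class MLPDivisible(nn.Module):
    # All Linear dims are multiples of typical Bone block sizes (e.g. r = 2, 4, 8).
    def __init__(self, bias=True):
        super().__init__()
        self.lin0 = nn.Linear(16, 32, bias=bias)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(32, 16, bias=bias)

    def forward(self, x):
        return self.lin1(self.relu(self.lin0(x)))
```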
Also:
- Great additions to the tests, but we should also add tests for the other model types. For instance, check how HRA did it.
- There is still some commented-out code here and there.
- We need an entry in the docs; again, check HRA. It's important to mention the pros and cons of the method.
- Let's also add an example that users can use to get started. This will greatly help with adoption. Ideally, this would be a script similar to the experiments in your paper/original repo. That way, we can verify that the PEFT implementation is correct.
- And after you've finished, don't forget to run `make style`.
src/peft/tuners/bone/layer.py (outdated)
```python
raise TypeError(f"Bone is not implemented for base layers of type {type(base_layer).__name__}")

# Initialize weights
# if init_weights:
```
Did you consider this?
> Most of the tests have now passed, but due to Bone being initialized to 0, an error occurs when encountering […]

Thanks for working on those. Regarding that, I have a suggestion. Right now, we have:
```python
if init_weights:
    self.reset_bone_parameters(adapter_name, r)
...
def reset_bone_parameters(self, adapter_name: str, r):
    self.bone_block[adapter_name] = nn.Parameter(torch.zeros(self.out_features // r, r, r), requires_grad=True)
```
I would suggest that for `init_weights=False`, we initialize `bone_block` randomly instead of with zeros, so that the adapter is not an identity transform at initialization.
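A possible shape for that suggestion (a sketch; the `init_weights` plumbing and the choice of uniform random init are assumptions, not the final API):

```python
import torch
import torch.nn as nn

def reset_bone_parameters(self, adapter_name: str, r: int, init_weights: bool):
    if init_weights:
        # Zeros keep the adapter an identity transform at the start of training.
        block = torch.zeros(self.out_features // r, r, r)
    else:
        # Random init yields a non-identity transform, which several tests rely on.
        block = torch.rand(self.out_features // r, r, r)
    self.bone_block[adapter_name] = nn.Parameter(block, requires_grad=True)
```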
Thanks a lot for the updates. We're making good progress. I was checking the tests and still found quite a few failures. Btw, you can run them locally using `pytest tests/ -k "bone"`.
Currently, tests involving GPT2 are failing because Bone does not support `Conv1D` layers. You can skip these tests; check how other methods achieve that:
peft/tests/test_decoder_models.py, lines 70 to 83 in e8259ff:
```python
def skip_adalora_or_oft_or_hra_and_gpt2(test_list):
    return [
        test
        for test in test_list
        if not (
            ("GPT2LMHeadModel" in test[1])
            and (
                (test[2] == AdaLoraConfig)
                or (test[2] == BOFTConfig)
                or (test[2] == HRAConfig)
                or (test[2] == OFTConfig)
            )
        )
    ]
```
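A Bone analogue could follow the same pattern (a sketch; `BoneConfig` refers to the config class added in this PR):

```python
def skip_bone_and_gpt2(test_list):
    # Drop test cases that pair GPT2 (which uses Conv1D layers) with Bone.
    return [
        test
        for test in test_list
        if not (("GPT2LMHeadModel" in test[1]) and (test[2] == BoneConfig))
    ]
```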
We also have multiple tests failing that require Bone to be initialized as a non-identity transform. For those tests to pass, we only need to ensure that `init_weights=False` is passed to the config. Please check how this is achieved for other methods:
peft/tests/test_decoder_models.py, lines 212 to 263 in e8259ff:
```python
@parameterized.expand(
    PeftTestConfigManager.get_grid_parameters(
        {
            "model_ids": PEFT_DECODER_MODELS_TO_TEST,
            "lora_kwargs": {"init_lora_weights": [False]},
            "adalora_kwargs": {"init_lora_weights": [False]},
            "ia3_kwargs": {"init_ia3_weights": [False]},
            "boft_kwargs": {"init_weights": [False]},
            "oft_kwargs": {"init_weights": [False]},
            "vera_kwargs": {"init_weights": [False]},
            "fourierft_kwargs": {"init_weights": [False]},
            "hra_kwargs": {"init_weights": [False]},
            "task_type": "CAUSAL_LM",
        },
    )
)
def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs):
    self._test_merge_layers(model_id, config_cls, config_kwargs)

@parameterized.expand(
    PeftTestConfigManager.get_grid_parameters(
        {
            "model_ids": PEFT_DECODER_MODELS_TO_TEST,
            "lora_kwargs": {"init_lora_weights": [False]},
            "ia3_kwargs": {"init_ia3_weights": [False]},
            "boft_kwargs": {"init_weights": [False]},
            "oft_kwargs": {"init_weights": [False]},
            "vera_kwargs": {"init_weights": [False]},
            "fourierft_kwargs": {"init_weights": [False]},
            "hra_kwargs": {"init_weights": [False]},
            "task_type": "CAUSAL_LM",
        },
        filter_params_func=skip_oft_or_hra_and_gpt2,
    )
)
def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs):
    self._test_merge_layers_multi(model_id, config_cls, config_kwargs)

@parameterized.expand(
    PeftTestConfigManager.get_grid_parameters(
        {
            "model_ids": PEFT_DECODER_MODELS_TO_TEST,
            "lora_kwargs": {"init_lora_weights": [False]},
            "ia3_kwargs": {"init_ia3_weights": [False]},
            "boft_kwargs": {"init_weights": [False]},
            "oft_kwargs": {"init_weights": [False]},
            "task_type": "CAUSAL_LM",
        },
    )
)
def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs):
    self._test_merge_layers_nan(model_id, config_cls, config_kwargs)
```
peft/tests/test_encoder_decoder_models.py, lines 92 to 104 in e8259ff:
```python
@parameterized.expand(
    PeftTestConfigManager.get_grid_parameters(
        {
            "model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST,
            "lora_kwargs": {"init_lora_weights": [False]},
            "adalora_kwargs": {"init_lora_weights": [False]},
            "ia3_kwargs": {"init_ia3_weights": [False]},
            "vera_kwargs": {"init_weights": [False]},
            "hra_kwargs": {"init_weights": [False]},
            "task_type": "SEQ_2_SEQ_LM",
        },
    )
)
```
peft/tests/test_feature_extraction_models.py, lines 109 to 123 in e8259ff:
```python
@parameterized.expand(
    PeftTestConfigManager.get_grid_parameters(
        {
            "model_ids": PEFT_FEATURE_EXTRACTION_MODELS_TO_TEST,
            "lora_kwargs": {"init_lora_weights": [False]},
            "adalora_kwargs": {"init_lora_weights": [False]},
            "ia3_kwargs": {"init_ia3_weights": [False]},
            "boft_kwargs": {"init_weights": [False]},
            "oft_kwargs": {"init_weights": [False]},
            "vera_kwargs": {"init_weights": [False]},
            "hra_kwargs": {"init_weights": [False]},
            "task_type": "FEATURE_EXTRACTION",
        },
    )
)
```
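For Bone, the analogous addition to each of these grids would presumably be a single entry, assuming the kwarg name follows the convention of the other methods:

```python
"bone_kwargs": {"init_weights": [False]},
```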
(there are more instances in the same files, please ensure that all are covered)
Let's fix those issues and then hopefully all the tests should pass.
Thanks for further working on fixing the tests. One part is not working, but I added a suggestion that will fix the problem.
Also, since this PR is now getting close to the finish line, did you have time to work on these points I mentioned earlier?

> - We need an entry in the docs; again, check HRA. It's important to mention the pros and cons of the method.
> - Let's also add an example that users can use to get started. This will greatly help with adoption. Ideally, this would be a script similar to the experiments in your paper/original repo. That way, we can verify that the PEFT implementation is correct.

For the docs, you could check this PR and how the docs are updated there (ignore the rest).
For the example, that would be a good addition. Let's make it easy for users to pick this method up and adapt it to their own needs.
The aforementioned issues have all been resolved.
I'm sorry: due to a conflict with the official PEFT branch, I mistakenly merged it, which caused my previous repository to be deleted and the PR to be forcibly closed. I have now fixed the issue. Should I resubmit, or can you reopen my old PR?
If you can somehow salvage this branch and re-open this PR, it would be better, since that allows us to have the whole discussion in one place. But don't worry if this doesn't work; we can also continue the discussion on the new PR. Just let me know your decision.
I am really sorry, but I could not recover the old branch. Let's discuss further questions under the new PR, which is ready to be tested.
In the paper https://arxiv.org/pdf/2409.15371, we introduce a brand-new PEFT (parameter-efficient fine-tuning) method: Bone (Block Affine). This is a completely new structure that differs from the LoRA series. In terms of performance, it has already surpassed PiSSA.
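For readers who want to try the method, here is a usage sketch, assuming this PR exposes a `BoneConfig` with `r` and `target_modules` parameters analogous to other PEFT configs (the model ID and hyperparameters are illustrative):

```python
from transformers import AutoModelForCausalLM
from peft import BoneConfig, get_peft_model  # BoneConfig is added by this PR

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# r is the Bone block size; the targeted weight dimensions must divide evenly by it.
config = BoneConfig(r=64, target_modules=["q_proj", "k_proj", "v_proj", "o_proj"])
model = get_peft_model(model, config)
model.print_trainable_parameters()
```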