ENH Support Conv2d layers for IA³ #972
Conversation
Adds support for `Conv2d` layers to the IA³ tuner. Tests are added to check that they work.

Notes:

Unfortunately, when unmerging the `Conv2d` IA³ layers, there is quite a bit of rounding error. I had to increase the tolerances for this specific test case to make the tests pass. I'm not 100% sure why this is, but I could imagine that for `Conv2d`, small errors accumulate because of the convolution operation.

I also added tests for IA³ `Linear` layers for the custom models, which also pass. However, there is an error when using `Conv1D`: merging fails because of a shape mismatch when `fan_in_fan_out=True` (which is set automatically for `Conv1D`). I'm not sure how this should be fixed. For the time being, I commented these tests out.

I also noticed that I don't understand what the `feedforward_modules` parameter does in IA³. AFAICT, it always has to be set for all IA³ layers, i.e. `self.is_feedforward` must always be `True`. If it cannot be `False`, we could just as well remove `feedforward_modules` and set `is_feedforward` to `True` automatically.
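For context, a minimal sketch of how the new Conv2d support might be exercised. The toy model and its module names are invented for illustration; this is not code from the PR:

```python
import torch
from torch import nn
from peft import IA3Config, get_peft_model

# Toy model with Conv2d and Linear layers (names are illustrative only).
class ToyConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.lin = nn.Linear(16, 4)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = x.mean(dim=(2, 3))  # global average pool
        return self.lin(x)

# IA³ config targeting the Conv2d layers added in this PR.
config = IA3Config(
    target_modules=["conv1", "conv2", "lin"],
    feedforward_modules=["lin"],
)
model = get_peft_model(ToyConvNet(), config)
model.print_trainable_parameters()
```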
Ping @SumanthRH
Hi @BenjaminBossan, regarding the feedforward layers, this detail is in a gray area, as we can only go by the T-few paper. In the original IA³ formulation, attention is rescaled as softmax(Q(l_k ⊙ K)ᵀ/√d_k)(l_v ⊙ V) and the feedforward block as (l_ff ⊙ γ(W₁x))W₂. Following these two equations, the current implementation multiplies the IA³ vector with the layer input for feedforward modules and with the layer output for the attention (key/value) modules.

This follows the two equations given in the paper (I didn't think we could merge this, because multiplying in the input space gives you a vector with a different dimension than multiplying after the matmul). How do we know this is right? I initially went over the original implementation from the authors here. Their implementation is for T5 v1.1 only, and the feedforward logic is actually the same for all the layers: the activations (after the matmul) are multiplied by the IA³ vector.

Also, I'm happy to take a look at the Conv1D shape issue.
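To make the before/after distinction concrete, here is a minimal sketch of the forward logic described in this thread. It is simplified, not the exact PEFT code, and the function and argument names are made up for illustration:

```python
import torch
from torch import nn

def ia3_forward(base_layer: nn.Linear, ia3_vector: torch.Tensor,
                x: torch.Tensor, is_feedforward: bool) -> torch.Tensor:
    """Simplified IA³ forward pass for a Linear layer.

    is_feedforward=True : scale the *input* before the matmul
                          (the l_ff ⊙ γ(W1 x) term feeding into W2).
    is_feedforward=False: scale the *output* after the matmul
                          (as for the attention key/value projections).
    """
    if is_feedforward:
        return base_layer(x * ia3_vector)   # ia3_vector has shape (in_features,)
    result = base_layer(x)
    return result * ia3_vector              # ia3_vector has shape (out_features,)
```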
Thanks @SumanthRH for the detailed explanation. Mathematically, it was clear to me what the difference is, but there were still sources of confusion. I think the first one stems from the name "feedforward", which can mean different things depending on circumstances. The way it is used in the paper is to designate a very specific part of the model architecture, but this doesn't necessarily generalize to other architectures. In hindsight, I think we should have named the variable for what it actually does, namely decide whether the weight is multiplied before or after. I think it would also have been better to separate the target modules by where the scaling is applied, e.g.:

```python
# currently:
IA3Config(
    target_modules=["layer1", "layer2", "layer3", "layer4"],
    feedforward_modules=["layer2", "layer4"],
)

# might be better:
IA3Config(
    target_modules_before=["layer1", "layer3"],
    target_modules_after=["layer2", "layer4"],
)
```

I think this could have made things more obvious, but changing it now would be backwards incompatible, so we should keep it as is.
I think a crucial difference is that it is multiplied after the matmul and after the non-linearity, right? This means that if we merge the IA³ weights into the normal weights, the output is incorrect because the merge does not (and cannot) take the non-linearity into account. @SumanthRH do you agree with this? If true, this would explain why the tests fail when not all modules are included in `feedforward_modules`. So that statement is not quite correct.

Edit: Thinking a bit more about this, isn't our implementation even potentially incorrect for `is_feedforward=False`?
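To illustrate the point about the non-linearity, here is a quick numerical check (not code from the PR): a scaling applied directly to a linear layer's output can be folded into the weight and bias exactly, but a scaling applied after an activation function generally cannot.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn(4, 3)
b = torch.randn(4)
l = torch.rand(4) + 0.5    # IA³-style scaling vector
x = torch.randn(3)

# Scaling the plain linear output can be merged exactly
# (rescale the rows of W and the bias):
out_scaled = l * (W @ x + b)
out_merged = (l.unsqueeze(1) * W) @ x + l * b
print(torch.allclose(out_scaled, out_merged))   # True

# Scaling applied after a non-linearity cannot be merged this way:
act_scaled = l * F.gelu(W @ x + b)
act_merged = F.gelu((l.unsqueeze(1) * W) @ x + l * b)
print(torch.allclose(act_scaled, act_merged))   # False in general
```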
- correct merging of Conv2d for is_feedforward=True vs False
- correct merging for is_feedforward=False and Linear layer (take bias into account)
- extend test examples for custom models
- extend tests to include merge_adapter and unmerge_adapter
Okay, I made some progress: Merging IA³ with `is_feedforward=False` should now be correct; the bias is taken into account. I also noticed that we never tested `merge_adapter` and `unmerge_adapter`, so the tests now cover those as well.

I'm still getting errors for Conv1D, so if you could take a look @SumanthRH that would be great. Regarding the question of whether the `is_feedforward=False` case is even correct, that is still open (see my edit above).
Hi @BenjaminBossan,

Regarding your comment: I believe the current implementation is fine as long as you use the right set of target/feedforward modules! And you're right about the bug in merging! I should have written tests right away back then. We had tested the merge functionality on a GPT-2 model, but I think for the specific fine-tuned model we used, the logits were unaffected even if you ignored merging the bias term. I have boiled the bug down to this implementation detail: for `is_feedforward=False` the output, including the bias, is scaled, but merging only rescaled the weight.

Is this even right? From what I see in the T-few implementation, this looks to be the case: T-few adds the trainable parameters to the activations after the matmul. The solution for the merge bug would be to have different merging/unmerging logic depending on `is_feedforward`.
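A rough sketch of what such `is_feedforward`-dependent merging could look like for a Linear layer. This is a simplified illustration along the lines discussed here, not the actual PEFT implementation, and the function names are hypothetical:

```python
import torch
from torch import nn

@torch.no_grad()
def merge_ia3(base_layer: nn.Linear, ia3_vector: torch.Tensor, is_feedforward: bool) -> None:
    """Fold an IA³ scaling vector into a Linear layer's parameters (sketch).

    is_feedforward=True : the input is scaled, so the *columns* of the weight
                          are rescaled and the bias is untouched.
    is_feedforward=False: the output is scaled, so the *rows* of the weight
                          AND the bias must be rescaled.
    """
    if is_feedforward:
        # y = W (l ⊙ x) + b  ==  (W · diag(l)) x + b
        base_layer.weight.mul_(ia3_vector.view(1, -1))
    else:
        # y = l ⊙ (W x + b)  ==  (diag(l) · W) x + (l ⊙ b)
        base_layer.weight.mul_(ia3_vector.view(-1, 1))
        if base_layer.bias is not None:
            base_layer.bias.mul_(ia3_vector)

@torch.no_grad()
def unmerge_ia3(base_layer: nn.Linear, ia3_vector: torch.Tensor, is_feedforward: bool) -> None:
    """Undo merge_ia3 (the division re-introduces small rounding errors)."""
    if is_feedforward:
        base_layer.weight.div_(ia3_vector.view(1, -1))
    else:
        base_layer.weight.div_(ia3_vector.view(-1, 1))
        if base_layer.bias is not None:
            base_layer.bias.div_(ia3_vector)
```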
Ah yes, I see, it's a matter of perspective I guess :) My issue, which maybe I didn't put clearly, is rather what happens when a user wants to use IA³ on a different model architecture with a layer not set as feedforward: the merge could then be silently incorrect, e.g. when that layer has a bias. Also, I think my point stands that the "feedforward" naming is confusing outside of the architectures covered by the paper. Maybe this can all be solved by better documentation, but I also think the API is not optimal. I would think that most users would expect merging to just work for whatever modules they target.
Yes, I think merging should be correct now, and the tests seem to support this. The only setting that is still failing is when using Conv1D.
* Fix issues with merging multiple adapters

  This should resolve the failing slow test test_4bit_merge_and_disable_lora. While investigating, I also noticed that merging multiple adapters was not correct for IA³. I added a test that should catch this bug and provided a fix for it too. However, the test does not check IA³ at the moment because the test parameters do not contain IA³. For this, #972 needs to be merged too, which adds IA³ to the test parameters.

* Small adjustments to tests

  Previously, tests had some exploding gradients, making them unstable.
Hello, insightful discussion above. I want to add a few points on why the current API is correct.
I think the target and feedforward modules for the recently added Falcon model (in the other.py file) are incorrect: only the attention submodule is targeted, and the feedforward submodule is missing. This needs correction.
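For illustration only, a corrected entry could look roughly like the sketch below. The module names are taken from the transformers Falcon implementation and the dictionary names are assumptions; the exact mapping in other.py may differ and the final choice is up to the maintainers:

```python
# Hypothetical sketch of a corrected Falcon entry; names are assumptions.
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING = {
    # attention projection *and* the MLP projections
    "falcon": ["query_key_value", "dense_h_to_4h", "dense_4h_to_h"],
}
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING = {
    # the MLP projections are the ones scaled on the input side
    "falcon": ["dense_h_to_4h", "dense_4h_to_h"],
}
```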
- learning rate for IA³ EmbConv1D needs to be increased for changes in the output to become detectable
- transpose now works correctly with nn.Parameter

Note: IA³ + Conv1D tests are still commented out because they fail for different reasons (shape mismatch).
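As a rough illustration of the nn.Parameter point in the commit message (a sketch of the idea, not the exact PEFT helper): transposing an nn.Parameter naively returns a plain Tensor, so a transpose utility has to handle that case explicitly.

```python
import torch
from torch import nn

def transpose(weight: torch.Tensor, fan_in_fan_out: bool) -> torch.Tensor:
    """Return the (possibly transposed) weight, keeping nn.Parameter intact.

    A plain `weight.T` on an nn.Parameter yields a Tensor, which can break
    code that expects to assign the result back as a module parameter,
    hence the explicit re-wrapping.
    """
    if not fan_in_fan_out:
        return weight
    if isinstance(weight, nn.Parameter):
        return nn.Parameter(weight.T)
    return weight.T
```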
Thanks @pacman100 for investigating. To summarize for myself:

- The way IA³ is used in the specific transformers architectures we explicitly added, it works correctly because those layers have no bias.
- The exception is Falcon, which needs an adjustment to the target layers (could you make a PR for that?).
- When applying IA³ to other layers, there could be a bug when merging a layer that has a bias; that is fixed by this PR.
- There is still the open issue of Conv1D (shape mismatch when merging), which is not resolved here.
Was set high for testing IA³ Conv1D.
@pacman100 @younesbelkada I think the open questions have been discussed, so this PR should be ready for review.
Thank you @BenjaminBossan for adding Conv2d layer support for IA³, adding all the related tests, and fixing the (un)merge when using IA³ for layers with a bias 🚀. LGTM!
Just a mild concern about the very high tolerances for the tests when using IA³ for Conv2d layers, but that can be investigated later on.
Yes, for sure. I did a manual check that the results are correct in the sense that they're highly correlated with the expected output (not just some random values). I suspect that the convolution operation, as it is applied repeatedly, accumulates errors, leading to greater total deviation.
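For instance, a check along these lines; the tolerance values and tensor names are made up for illustration and are not the ones used in the test suite:

```python
import torch

def check_unmerged_output(output_after_unmerge: torch.Tensor,
                          expected_output: torch.Tensor) -> None:
    # Relaxed tolerances for the Conv2d case, since small numerical errors
    # accumulate across the convolution when merging/unmerging.
    assert torch.allclose(output_after_unmerge, expected_output, atol=1e-3, rtol=1e-2)

    # Sanity check that the outputs are genuinely close in direction,
    # not just accidentally within the loose tolerance.
    cos = torch.nn.functional.cosine_similarity(
        output_after_unmerge.flatten(), expected_output.flatten(), dim=0
    )
    assert cos > 0.99
```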