Fix Conv1D merge error for IA3 #1014
Conversation
This looks pretty good to me, thanks for identifying the issue and providing a fix that mirrors what we have for LoRA.
The only issue I have is with the choice of tolerances for the tests. I see your argument, but I think it's potentially dangerous to use a low tolerance (high precision) to check that two outputs are different, and then a high tolerance (low precision) to check that two outputs are identical. That way, we could theoretically prove that the same outputs are both identical and different at the same time!
import torch
x = torch.tensor([0.0001])
y = torch.tensor([0.0002])
# x and y are different
assert not torch.allclose(x, y, atol=1e-5)
# x and y are the same
assert torch.allclose(x, y, atol=1e-3)
The other way around, high tolerance to check for differences and low tolerance to check for identity, would be okay, but like this, it's an issue IMO.
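For illustration, here is a small sketch of that safer combination on the same toy values; it spells out why no single pair of outputs can pass both checks at once (this is just an illustration, not part of the test suite):

import torch

x = torch.tensor([0.0001])
y = torch.tensor([0.0002])

# "different" check with the larger tolerance: outputs must be far apart to count as different
different = not torch.allclose(x, y, atol=1e-3)  # False here
# "identical" check with the smaller tolerance: outputs must be very close to count as identical
identical = torch.allclose(x, y, atol=1e-5)      # False here
# with this combination, the two checks can never both succeed for the same pair
assert not (different and identical)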
I feel your pain in trying to balance the tolerances just right so that they work for all tests; this is very fiddly. The real reason why this issue occurs is something else though, so we may be able to find a solution that avoids tuning tolerances. From my observation, the real problem is that the conv1d architecture is very slow to train, requiring a higher learning rate/number of epochs to move the needle. I'm not sure why that is, honestly.
I tried addressing this problem by increasing the learning rate for that model, but by itself, that's still not enough. Some further testing showed, however, that switching from SGD to Adam seems to do the trick. Could you please check if that works for you as well? I still think there must be some fundamental reason why that model is so slow to learn, but I haven't figured it out yet.
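For reference, a minimal sketch of the suggested optimizer swap (the model here is a stand-in nn.Linear, not the actual test model, and the learning rate is only illustrative):

import torch
import torch.nn as nn

model = nn.Linear(5, 1)  # stand-in for the PEFT test model
lr = 0.01
# optimizer = torch.optim.SGD(model.parameters(), lr=lr)  # current setup, slow to train the conv1d model
optimizer = torch.optim.Adam(model.parameters(), lr=lr)    # suggested replacement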
Okay, small update @BenjaminBossan: I tried to dig deeper into the issue with the slow convergence of EmbConv1D. FYI, the PR is not ready yet. I believe there are two reasons why EmbConv1D + IA3 is giving these issues:
from peft import get_peft_model, IA3Config, LoraConfig
import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D
import math


class Conv1DNewInit(Conv1D):
    """
    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
    Basically works like a linear layer but the weights are transposed.
    This is a modified version with uniform initialization as in nn.Linear.

    Args:
        nf (`int`): The number of output features.
        nx (`int`): The number of input features.
    """

    def __init__(self, nf, nx):
        super().__init__(nf, nx)
        self.nf = nf
        self.weight = nn.Parameter(torch.empty(nx, nf))
        self.bias = nn.Parameter(torch.zeros(nf))
        # nn.init.normal_(self.weight, std=0.02)  # older init
        nn.init.uniform_(self.weight, -math.sqrt(1 / nx), math.sqrt(1 / nx))  # new init as in nn.Linear
        nn.init.uniform_(self.bias, -math.sqrt(1 / nx), math.sqrt(1 / nx))  # new init as in nn.Linear

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)
        return x


class ModelEmbConv1D(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(100, 5)
        self.conv1d = Conv1D(1, 5)
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        self.lin0 = nn.Linear(10, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = self.emb(X)
        X = self.conv1d(X)
        X = self.relu(X)
        X = self.flat(X)
        X = self.lin0(X)
        X = self.sm(X)
        return X


class ModelEmbConv1DNewInit(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(100, 5)
        self.conv1d = Conv1DNewInit(1, 5)
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        self.lin0 = nn.Linear(10, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = self.emb(X)
        X = self.conv1d(X)
        X = self.relu(X)
        X = self.flat(X)
        X = self.lin0(X)
        X = self.sm(X)
        return X


class ModelEmbLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(100, 5)
        self.conv1d = nn.Linear(5, 1)
        self.relu = nn.ReLU()
        self.flat = nn.Flatten()
        self.lin0 = nn.Linear(10, 2)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = self.emb(X)
        X = self.conv1d(X)
        X = self.relu(X)
        X = self.flat(X)
        X = self.lin0(X)
        X = self.sm(X)
        return X


class MockTransformerWrapper:
    """Mock class to behave like a transformers model.

    This is needed because the tests initialize the model by calling transformers_class.from_pretrained.
    """

    @classmethod
    def from_pretrained(cls, model_id):
        # set the seed so that from_pretrained always returns the same model
        torch.manual_seed(0)
        if model_id == "EmbConv1D":
            return ModelEmbConv1D()
        if model_id == "EmbLinear":
            return ModelEmbLinear()
        if model_id == "EmbConv1DNewInit":
            return ModelEmbConv1DNewInit()
        raise ValueError(f"model_id {model_id} not implemented")


all_configs = [
    ("EmbConv1D", IA3Config, {"target_modules": ["conv1d"], "feedforward_modules": ["conv1d"]}),
    ("EmbConv1DNewInit", IA3Config, {"target_modules": ["conv1d"], "feedforward_modules": ["conv1d"]}),
    ("EmbLinear", IA3Config, {"target_modules": ["conv1d"], "feedforward_modules": ["conv1d"]}),
]

transformers_class = MockTransformerWrapper
all_outputs_after = []
all_outputs_before = []
X = {"X": torch.arange(90).view(9, 10)}

for model_id, config_cls, config_kwargs in all_configs:
    model = transformers_class.from_pretrained(model_id)
    print(model_id)
    config = config_cls(
        base_model_name_or_path=model_id,
        **config_kwargs,
    )
    model = get_peft_model(model, config)
    model.print_trainable_parameters()

    model.eval()
    outputs_before = model(**X)

    model.train()
    lr = 0.01
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
    # breaking of some LoRA layers that are initialized with constants)
    for _ in range(3):
        optimizer.zero_grad()
        y_pred = model(**X)
        y = torch.arange(len(y_pred)) % 2
        loss = nn.functional.nll_loss(y_pred, y)
        loss.backward()
        optimizer.step()

    model.eval()
    outputs_after = model(**X)

    with model.disable_adapter():
        outputs_disabled = model(**X)

    # check that after leaving the disable_adapter context, everything is enabled again
    outputs_enabled_after_disable = model(**X)

    all_outputs_after.append(outputs_after)
    all_outputs_before.append(outputs_before)

    atol = 1e-5
    rtol = 1e-5
    print(torch.allclose(outputs_before, outputs_after, atol=atol, rtol=rtol))
    print(torch.allclose(outputs_before, outputs_disabled))
    print(torch.allclose(outputs_after, outputs_enabled_after_disable))

Print results:
Optimizer: Because of 1. and 2., the optimizer choice matters. As you mentioned, Adam does better than SGD, and all tests pass for Conv1D if you switch to Adam; no tolerance hacks needed.

Since initialization isn't something to hack around with, I tried to see if the number of parameters can improve performance. I didn't find much of an improvement after going to a weight size of 20x1 (20 trainable parameters, since you get a 20x1 vector). I've switched to using Adam, and this should be fine IMO, since we know that the core issue is not with IA3.

I'm currently working on fixing the Conv2D tests - if you use tolerances of 1e-5, the Conv2D + IA3 checks start failing, even with Adam. I need to look into this more and see if we can avoid increasing the tolerance (simply using a tolerance of 1e-3 for Conv2D works as of now).

Btw, at this scale of tiny models, it does look like different PEFT configs can be a little tricky to manage. For example, even LoRA with config kwargs
Thanks for digging deeper into this. I agree that with these toy examples, issues may occur that don't manifest in real problems. We could make the examples more realistic, but then tests would slow down, so it's a trade-off. Since it's just a toy test, I'm fine with having different tolerances for different models, like 1e-3 for conv1d and 1e-5 for everything else. Is that, together with using Adam, enough to make the tests pass?
For Conv1D, we don't need to touch the tolerances anymore, even for IA3. Conv2D + IA3 is still a problem, and we need a tolerance of 1e-3 to make sure the tests pass. I've made that change now.
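For reference, a hypothetical sketch of what a per-layer-type tolerance lookup in a test helper could look like; the names TOLERANCES and get_tol are made up for illustration and are not the actual peft test code:

# default tolerance of 1e-5, with a looser 1e-3 only where Conv2D + IA3 is involved
TOLERANCES = {"Conv2d": 1e-3}

def get_tol(model_id: str, default: float = 1e-5) -> float:
    for key, tol in TOLERANCES.items():
        if key in model_id:
            return tol
    return default

# e.g. torch.allclose(outputs_before, outputs_after, atol=get_tol(model_id), rtol=get_tol(model_id))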
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Thanks for the hard work you put into this. It looks good now.
What does this PR do?
Fixes the error after merging a Conv1D layer, as described in #972 (comment). Currently, PEFT subclasses nn.Linear and injects trainable parameters. For Conv1D layers, we use fan_in_fan_out to make note of the fact that the weights need to be transposed for correct behaviour. Thus, this Linear layer has to be converted back into a Conv1D layer when you merge and convert back into the base model, but this wasn't happening before with IA3. This meant that the base model returned by merge_and_unload had all Conv1D layers replaced by Linear layers. I've used the LoRA implementation as a reference to fix this and added a flag is_target_conv1d_layer to IA3's Linear layer. @BenjaminBossan I've simply uncommented all your tests for Conv1D merging; they seem comprehensive for now.

Also, there was an error being raised with test_disable_adapters_with_merging for IA3: the assertFalse checks were failing because the tolerance used was too high. The high tolerance makes sense for the merge-related checks, because IA3 introduces instability since you lose some information when the vectors are multiplied with the weights. The assertFalse check, however, is for successful training. The outputs before and after a few training steps are going to be very close to each other, so using a high tolerance for torch.allclose leads to a True output, i.e. it says that the outputs before and after training are indeed close to each other, and the assertFalse fails. I think we should use the same tolerance for IA3 as for LoRA, so I've made that change. All tests pass now.
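To illustrate the weight-orientation point behind the fix, here is a small, self-contained sketch; it is not the actual peft implementation, and the helper linear_to_conv1d is made up for illustration. transformers' Conv1D stores its weight as (in_features, out_features), the transpose of nn.Linear's (out_features, in_features), which is why a Linear has to be converted back with a transpose when the merged model is unloaded:

import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D

def linear_to_conv1d(linear: nn.Linear) -> Conv1D:
    # Conv1D(nf, nx): nf = output features, nx = input features
    conv = Conv1D(linear.out_features, linear.in_features)
    # Conv1D keeps the weight transposed relative to nn.Linear, so transpose back
    conv.weight.data = linear.weight.data.t().contiguous()
    if linear.bias is not None:
        conv.bias.data = linear.bias.data.clone()
    return conv

lin = nn.Linear(5, 1)
conv = linear_to_conv1d(lin)
x = torch.randn(3, 5)
# both modules compute the same function
assert torch.allclose(lin(x), conv(x), atol=1e-6)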