FIX: issues with (un)merging multiple LoRA and IA³ adapters #976

Merged
4 changes: 3 additions & 1 deletion src/peft/tuners/ia3/layer.py
@@ -100,6 +100,7 @@ def merge(self) -> None:
self.weight = transpose(self.weight, self.fan_in_fan_out)
self.weight.data = torch.mul(self.weight.data, self.ia3_l[active_adapter].data)
self.weight = transpose(self.weight, self.fan_in_fan_out)
self.merged_adapters.append(active_adapter)
self.merged = True

def unmerge(self) -> None:
@@ -108,7 +109,8 @@ def unmerge(self) -> None:
return

warnings.warn("Unmerge result can be inaccurate for (IA)^3.")
for active_adapter in self.active_adapters:
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.ia3_l.keys():
self.weight = transpose(self.weight, self.fan_in_fan_out)
# divide by (IA)^3 vector. Add tolerance to avoid division by zero
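The core of this fix is the `merged_adapters` bookkeeping: `merge` pushes each adapter it folds into the weight onto a stack, and `unmerge` pops that stack until it is empty, so repeated merges of different adapters are all undone instead of only the ones that happen to be active at unmerge time. A minimal, self-contained sketch of that pattern (toy tensors and names, not the actual PEFT layer) could look like this:

```python
import torch


class ToyIA3Layer:
    """Toy illustration of the merged_adapters stack used in this PR (not PEFT code)."""

    def __init__(self, weight: torch.Tensor):
        self.weight = weight
        self.ia3_l = {}            # adapter name -> learned (IA)^3 scaling vector
        self.merged_adapters = []  # stack of adapters already folded into self.weight
        self.merged = False

    def merge(self, adapter_names):
        for name in adapter_names:
            if name in self.ia3_l:
                self.weight = self.weight * self.ia3_l[name]
                self.merged_adapters.append(name)  # remember what was merged
        self.merged = True

    def unmerge(self):
        # Pop in LIFO order so *every* previously merged adapter is undone,
        # not just whichever adapters happen to be active right now.
        while len(self.merged_adapters) > 0:
            name = self.merged_adapters.pop()
            if name in self.ia3_l:
                self.weight = self.weight / (self.ia3_l[name] + 1e-8)  # tolerance as in the real layer
        self.merged = False
```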
1 change: 1 addition & 0 deletions src/peft/tuners/lora/bnb.py
@@ -214,6 +214,7 @@ def merge(self):
lora_data = self.get_delta_weight(active_adapter)
w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + lora_data
self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device)
self.merged_adapters.append(active_adapter)
self.merged = True

def unmerge(self):
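For 4-bit quantized layers the weight cannot be updated in place, so the merge above dequantizes the frozen weight, adds the LoRA delta, and re-wraps the result as a fresh `Params4bit`. A hedged sketch of that round trip, reusing only the bitsandbytes calls already shown in the diff (the `quant_kwargs` argument is a stand-in for the quantization settings that the real code captures from the existing parameter):

```python
import bitsandbytes as bnb


def apply_delta_to_4bit(weight, delta, quant_kwargs, sign=1):
    """Sketch only: fold a LoRA delta into (sign=+1) or out of (sign=-1) a 4-bit weight."""
    # Dequantize to full precision, apply the delta, then re-quantize by
    # constructing a new Params4bit and moving it back to the original device.
    w_data = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) + sign * delta
    return bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **quant_kwargs).to(weight.device)
```

An unmerge that walks `merged_adapters` in reverse and applies this with `sign=-1` would mirror the merge path above.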
54 changes: 52 additions & 2 deletions tests/test_custom_models.py
@@ -178,13 +178,15 @@ def __init__(self):
self.relu = nn.ReLU()
self.flat = nn.Flatten()
self.lin0 = nn.Linear(10, 2)
self.sm = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = self.emb(X)
X = self.conv1d(X)
X = self.relu(X)
X = self.flat(X)
X = self.lin0(X)
X = self.sm(X)
return X


@@ -195,13 +197,15 @@ def __init__(self):
self.relu = nn.ReLU()
self.flat = nn.Flatten()
self.lin0 = nn.Linear(10, 2)
self.sm = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = X.float().reshape(2, 5, 3, 3)
X = self.conv2d(X)
X = self.relu(X)
X = self.flat(X)
X = self.lin0(X)
X = self.sm(X)
return X


@@ -388,14 +392,17 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
outputs_before = model(**X)

model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# EmbConv1D is slow to learn for some reason
lr = 0.01 if model_id != "EmbConv1D" else 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
# breaking of some LoRA layers that are initialized with constants)
for _ in range(3):
optimizer.zero_grad()
y_pred = model(**X)
loss = y_pred.sum()
y = torch.arange(len(y_pred)).to(self.torch_device) % 2
loss = nn.functional.nll_loss(y_pred, y)
loss.backward()
optimizer.step()

@@ -412,6 +419,49 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
self.assertTrue(torch.allclose(outputs_before, outputs_disabled))
self.assertTrue(torch.allclose(outputs_after, outputs_enabled_after_disable))

@parameterized.expand(TEST_CASES)
def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, config_kwargs):
# same as test_disable_adapters, but with merging
X = self.prepare_inputs_for_testing()
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
model = get_peft_model(model, config)
model.eval()
outputs_before = model(**X)

model.train()
# EmbConv1D is slow to learn for some reason
lr = 0.01 if model_id != "EmbConv1D" else 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
# breaking of some LoRA layers that are initialized with constants)
for _ in range(3):
optimizer.zero_grad()
y_pred = model(**X)
y = torch.arange(len(y_pred)).to(self.torch_device) % 2
loss = nn.functional.nll_loss(y_pred, y)
loss.backward()
optimizer.step()

model.eval()
model.merge_adapter()
outputs_after = model(**X)

with model.disable_adapter():
outputs_disabled = model(**X)

# check that after leaving the disable_adapter context, everything is enabled again
outputs_enabled_after_disable = model(**X)

atol, rtol = 1e-5, 1e-5 # merging introduces some numerical instability
self.assertFalse(torch.allclose(outputs_before, outputs_after, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(outputs_before, outputs_disabled, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(outputs_after, outputs_enabled_after_disable, atol=atol, rtol=rtol))

@parameterized.expand(TEST_CASES)
def test_disable_adapter_with_bias_warns(self, test_name, model_id, config_cls, config_kwargs):
# When training biases in lora, disabling adapters does not reset the biases, so the output is not what users
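As a usage-level illustration of the behaviour the new merging test exercises, a rough sketch (the base model, target modules, and tolerances here are placeholders, not taken from the PR):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Placeholder setup: any small model with matching target_modules works for this sketch.
base = AutoModelForCausalLM.from_pretrained("gpt2")
# gpt2's c_attn is a transformers Conv1D, hence fan_in_fan_out=True.
model = get_peft_model(base, LoraConfig(r=8, target_modules=["c_attn"], fan_in_fan_out=True))

with torch.no_grad():
    before = {n: p.clone() for n, p in model.named_parameters() if "lora_" not in n}

model.merge_adapter()    # fold the LoRA weights into the base weights
model.unmerge_adapter()  # with this fix, every merged adapter is popped again

restored = {n: p for n, p in model.named_parameters() if "lora_" not in n}
assert all(torch.allclose(before[n], restored[n], atol=1e-5) for n in before)
```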