From e9f1fc64f697cc4b73f8f6339d2cd627f14fbbf6 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Fri, 29 Sep 2023 14:05:50 +0200
Subject: [PATCH 1/2] Fix issues with merging multiple adapters
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This should resolve the failing slow test test_4bit_merge_and_disable_lora.

While investigating, I also noticed that merging multiple adapters was not
correct for IA³. I added a test that should catch this bug and provided a fix
for it too. However, the test does not check IA³ at the moment because the
test parameters do not contain IA³. For this, #972 needs to be merged too,
which adds IA³ to the test parameters.
---
 src/peft/tuners/ia3/layer.py |  4 +++-
 src/peft/tuners/lora/bnb.py  |  1 +
 tests/test_custom_models.py  | 39 ++++++++++++++++++++++++++++++++++++
 3 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/src/peft/tuners/ia3/layer.py b/src/peft/tuners/ia3/layer.py
index 0012fb6ba6..80977a567b 100644
--- a/src/peft/tuners/ia3/layer.py
+++ b/src/peft/tuners/ia3/layer.py
@@ -100,6 +100,7 @@ def merge(self) -> None:
                 self.weight = transpose(self.weight, self.fan_in_fan_out)
                 self.weight.data = torch.mul(self.weight.data, self.ia3_l[active_adapter].data)
                 self.weight = transpose(self.weight, self.fan_in_fan_out)
+                self.merged_adapters.append(active_adapter)
         self.merged = True
 
     def unmerge(self) -> None:
@@ -108,7 +109,8 @@ def unmerge(self) -> None:
             return
 
         warnings.warn("Unmerge result can be inaccurate for (IA)^3.")
-        for active_adapter in self.active_adapters:
+        while len(self.merged_adapters) > 0:
+            active_adapter = self.merged_adapters.pop()
             if active_adapter in self.ia3_l.keys():
                 self.weight = transpose(self.weight, self.fan_in_fan_out)
                 # divide by (IA)^3 vector. Add tolerace to avoid division by zero
diff --git a/src/peft/tuners/lora/bnb.py b/src/peft/tuners/lora/bnb.py
index 3007564ad5..d6e17f8785 100644
--- a/src/peft/tuners/lora/bnb.py
+++ b/src/peft/tuners/lora/bnb.py
@@ -214,6 +214,7 @@ def merge(self):
             lora_data = self.get_delta_weight(active_adapter)
             w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + lora_data
             self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device)
+            self.merged_adapters.append(active_adapter)
         self.merged = True
 
     def unmerge(self):
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index 6f6cc1b51b..ba990b798b 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -412,6 +412,45 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
         self.assertTrue(torch.allclose(outputs_before, outputs_disabled))
         self.assertTrue(torch.allclose(outputs_after, outputs_enabled_after_disable))
 
+    @parameterized.expand(TEST_CASES)
+    def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, config_kwargs):
+        # same as test_disable_adapters, but with merging
+        X = self.prepare_inputs_for_testing()
+        model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        model = get_peft_model(model, config)
+        model.eval()
+        outputs_before = model(**X)
+
+        model.train()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+        # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
+        # breaking of some LoRA layers that are initialized with constants)
+        for _ in range(3):
+            optimizer.zero_grad()
+            y_pred = model(**X)
+            loss = y_pred.sum()
+            loss.backward()
+            optimizer.step()
+
+        model.merge_adapter()
+        model.eval()
+        outputs_after = model(**X)
+
+        with model.disable_adapter():
+            outputs_disabled = model(**X)
+
+        # check that after leaving the disable_adapter context, everything is enabled again
+        outputs_enabled_after_disable = model(**X)
+
+        self.assertFalse(torch.allclose(outputs_before, outputs_after))
+        self.assertTrue(torch.allclose(outputs_before, outputs_disabled))
+        self.assertTrue(torch.allclose(outputs_after, outputs_enabled_after_disable))
+
     @parameterized.expand(TEST_CASES)
     def test_disable_adapter_with_bias_warns(self, test_name, model_id, config_cls, config_kwargs):
         # When training biases in lora, disabling adapters does not reset the biases, so the output is not what users

From 3784849812e100285094f1bcb86f7225e525d192 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Fri, 29 Sep 2023 15:30:16 +0200
Subject: [PATCH 2/2] Small adjustments to tests

Previously, the tests had some exploding gradients, making them unstable.
---
 tests/test_custom_models.py | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index ba990b798b..0c799c6ec4 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -178,6 +178,7 @@ def __init__(self):
         self.relu = nn.ReLU()
         self.flat = nn.Flatten()
         self.lin0 = nn.Linear(10, 2)
+        self.sm = nn.LogSoftmax(dim=-1)
 
     def forward(self, X):
         X = self.emb(X)
@@ -185,6 +186,7 @@ def forward(self, X):
         X = self.relu(X)
         X = self.flat(X)
         X = self.lin0(X)
+        X = self.sm(X)
         return X
 
 
@@ -195,6 +197,7 @@ def __init__(self):
         self.relu = nn.ReLU()
         self.flat = nn.Flatten()
         self.lin0 = nn.Linear(10, 2)
+        self.sm = nn.LogSoftmax(dim=-1)
 
     def forward(self, X):
         X = X.float().reshape(2, 5, 3, 3)
@@ -202,6 +205,7 @@ def forward(self, X):
         X = self.relu(X)
         X = self.flat(X)
         X = self.lin0(X)
+        X = self.sm(X)
         return X
 
 
@@ -388,14 +392,17 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
         outputs_before = model(**X)
 
         model.train()
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        # EmbConv1D is slow to learn for some reason
+        lr = 0.01 if model_id != "EmbConv1D" else 0.1
+        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
 
         # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
         # breaking of some LoRA layers that are initialized with constants)
         for _ in range(3):
             optimizer.zero_grad()
             y_pred = model(**X)
-            loss = y_pred.sum()
+            y = torch.arange(len(y_pred)).to(self.torch_device) % 2
+            loss = nn.functional.nll_loss(y_pred, y)
             loss.backward()
             optimizer.step()
 
@@ -426,19 +433,22 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co
         outputs_before = model(**X)
 
         model.train()
-        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        # EmbConv1D is slow to learn for some reason
+        lr = 0.01 if model_id != "EmbConv1D" else 0.1
+        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
 
         # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
         # breaking of some LoRA layers that are initialized with constants)
         for _ in range(3):
             optimizer.zero_grad()
             y_pred = model(**X)
-            loss = y_pred.sum()
+            y = torch.arange(len(y_pred)).to(self.torch_device) % 2
+            loss = nn.functional.nll_loss(y_pred, y)
             loss.backward()
             optimizer.step()
 
-        model.merge_adapter()
         model.eval()
+        model.merge_adapter()
         outputs_after = model(**X)
 
         with model.disable_adapter():
@@ -447,9 +457,10 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co
         # check that after leaving the disable_adapter context, everything is enabled again
         outputs_enabled_after_disable = model(**X)
 
-        self.assertFalse(torch.allclose(outputs_before, outputs_after))
-        self.assertTrue(torch.allclose(outputs_before, outputs_disabled))
-        self.assertTrue(torch.allclose(outputs_after, outputs_enabled_after_disable))
+        atol, rtol = 1e-5, 1e-5  # merging introduces some numerical instability
+        self.assertFalse(torch.allclose(outputs_before, outputs_after, atol=atol, rtol=rtol))
+        self.assertTrue(torch.allclose(outputs_before, outputs_disabled, atol=atol, rtol=rtol))
+        self.assertTrue(torch.allclose(outputs_after, outputs_enabled_after_disable, atol=atol, rtol=rtol))
 
     @parameterized.expand(TEST_CASES)
     def test_disable_adapter_with_bias_warns(self, test_name, model_id, config_cls, config_kwargs):
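
Note on the fix in PATCH 1/2 (illustration only, not part of the patches): the bookkeeping pattern the first commit introduces can be shown in isolation with a minimal, self-contained sketch. ToyIA3Linear and its attributes below are made up for this example and are not PEFT's actual classes; the real implementation lives in src/peft/tuners/ia3/layer.py and src/peft/tuners/lora/bnb.py. The point is that merge() records each adapter it folds into the weight on a merged_adapters stack, and unmerge() pops that stack instead of iterating over active_adapters, so every merged adapter is undone exactly once even if the set of active adapters changes in between.

import torch
import torch.nn as nn


class ToyIA3Linear(nn.Linear):
    """Toy layer (not PEFT's implementation) with IA3-style multiplicative adapters."""

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        self.ia3_l = {}              # adapter name -> learned scaling vector
        self.active_adapters = []    # adapters currently applied
        self.merged_adapters = []    # stack of adapters folded into self.weight

    def update_layer(self, adapter_name):
        # random positive scaling vector, one entry per input feature
        self.ia3_l[adapter_name] = torch.rand(self.in_features) + 0.5

    def merge(self):
        for adapter in self.active_adapters:
            if adapter in self.ia3_l:
                # fold the adapter into the base weight ...
                self.weight.data = self.weight.data * self.ia3_l[adapter]
                # ... and remember that we did so
                self.merged_adapters.append(adapter)

    def unmerge(self):
        # undo exactly what was merged, in reverse order, regardless of
        # which adapters happen to be active right now
        while self.merged_adapters:
            adapter = self.merged_adapters.pop()
            self.weight.data = self.weight.data / (self.ia3_l[adapter] + 1e-8)


layer = ToyIA3Linear(4, 3)
layer.update_layer("default")
layer.update_layer("other")
layer.active_adapters = ["default", "other"]

original = layer.weight.data.clone()
layer.merge()
layer.active_adapters = ["default"]   # active set changes after merging
layer.unmerge()                       # both merged adapters are still undone
print(torch.allclose(layer.weight.data, original, atol=1e-5))  # True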
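
Note on PATCH 2/2 (illustration only, under assumed toy shapes, not the models in tests/test_custom_models.py): the motivation for swapping `loss = y_pred.sum()` for log-softmax outputs plus NLL loss is that a loss which simply sums the raw logits has no minimum, so each SGD step keeps pushing the weights in the same direction and they grow without bound (the "exploding gradients" mentioned in the commit message), whereas NLL against dummy labels is a bounded classification objective whose gradients shrink once the labels are fit.

import torch
import torch.nn as nn

torch.manual_seed(0)
X = torch.randn(8, 10)
y = torch.arange(8) % 2   # alternating dummy labels, as in the adjusted tests


def max_weight_after_training(use_nll_loss: bool, steps: int = 20) -> float:
    """Train a tiny model for a few steps and report the largest weight magnitude."""
    model = nn.Sequential(
        nn.Linear(10, 2),
        nn.LogSoftmax(dim=-1) if use_nll_loss else nn.Identity(),
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(steps):
        optimizer.zero_grad()
        y_pred = model(X)
        # old objective: sum the outputs; new objective: NLL against dummy labels
        loss = nn.functional.nll_loss(y_pred, y) if use_nll_loss else y_pred.sum()
        loss.backward()
        optimizer.step()
    return model[0].weight.abs().max().item()


print("sum-of-logits loss :", max_weight_after_training(use_nll_loss=False))
print("log-softmax + NLL  :", max_weight_after_training(use_nll_loss=True))
# The first number keeps growing as `steps` increases; the second stays bounded.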