ENH Support Conv2d layers for IA³ #972

Merged
Fix some minor issues:
- learning rate for IA³ EmbConv1D needs to be increased for changes in
  the output to become detectable
- transpose now works correctly with nn.Parameter

Note: IA³ + Conv1D tests are still commented out because they fail for
different reasons (shape mismatch).
BenjaminBossan committed Oct 4, 2023
commit 806c3a9a9792319a29288f034ba06a0c63af5fd7
7 changes: 6 additions & 1 deletion src/peft/utils/other.py
@@ -331,7 +331,12 @@ def lambda_policy_fn(module):


 def transpose(weight, fan_in_fan_out):
-    return weight.T if fan_in_fan_out else weight
+    if not fan_in_fan_out:
+        return weight
+
+    if isinstance(weight, torch.nn.Parameter):
+        return torch.nn.Parameter(weight.T)
+    return weight.T


def _is_valid_match(key: str, target_key: str):
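For context, a minimal sketch (not part of the PR) illustrating why the re-wrapping matters: transposing an nn.Parameter with .T returns a plain torch.Tensor, so callers that expect a Parameter would otherwise receive an unregistered tensor.

import torch

# Frozen base weight, as in an IA³ base layer (requires_grad=False).
weight = torch.nn.Parameter(torch.randn(3, 5), requires_grad=False)

# .T on an nn.Parameter returns a plain Tensor view, not a Parameter.
print(type(weight.T))  # <class 'torch.Tensor'>

# Re-wrapping restores the Parameter type, which is what the fixed
# transpose() helper now does for fan_in_fan_out weights.
rewrapped = torch.nn.Parameter(weight.T)
print(type(rewrapped))  # <class 'torch.nn.Parameter'>
print(rewrapped.shape)  # torch.Size([5, 3])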
4 changes: 2 additions & 2 deletions tests/test_custom_models.py
@@ -493,7 +493,7 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):

         model.train()
         # EmbConv1D is slow to learn for some reason
-        lr = 0.01 if model_id != "EmbConv1D" else 0.1
+        lr = 0.01 if model_id != "EmbConv1D" else 100.0
         optimizer = torch.optim.SGD(model.parameters(), lr=lr)

         # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
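For context, the affected tests roughly follow this pattern (a simplified sketch with hypothetical names, not the actual test code): train for a few steps, then check that the adapter measurably changed the model output. With too small a learning rate, the IA³ EmbConv1D outputs before and after training are equal within tolerance, so the test cannot detect that training happened.

import torch

def output_changed_after_training(model, inputs, lr, steps=3):
    # Hypothetical helper mirroring the test's detection logic: record the
    # output, train a few SGD steps, then check the output actually moved.
    before = model(inputs).detach().clone()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = model(inputs).sum()
        loss.backward()
        optimizer.step()
    after = model(inputs).detach()
    return not torch.allclose(before, after)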
@@ -534,7 +534,7 @@ def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, co

         model.train()
         # EmbConv1D is slow to learn for some reason
-        lr = 0.01 if model_id != "EmbConv1D" else 0.1
+        lr = 0.01 if model_id != "EmbConv1D" else 100.0
         optimizer = torch.optim.SGD(model.parameters(), lr=lr)

         # train at least 3 steps for all parameters to be updated (probably this is required because of symmetry