
Update OFT to fix merge bugs #1996

Merged · 43 commits · Oct 1, 2024
Commits (43)
8e48df0
updating oft
Aug 7, 2024
42b72e6
update oft to be consistent with other peft methods
Aug 7, 2024
6ade80f
Merge remote-tracking branch 'upstream/main' into oft-new2
Aug 7, 2024
ada7c69
update oft to be consistent with other peft methods
Aug 7, 2024
b7f5b23
update oft to be consistent with other peft methods
Aug 8, 2024
3d9a4bc
update oft to fix merge bugs, be consistent with other peft methods
Aug 8, 2024
2bca964
update oft to fix merge bugs, be consistent with other peft methods
Aug 8, 2024
76aba70
update oft to fix merge bugs, be consistent with other peft methods
Aug 8, 2024
7b9a1af
Merge remote-tracking branch 'upstream/main' into oft-new2
Aug 8, 2024
7c116b5
addressing issues in config file
Aug 9, 2024
3297ed6
update oft config to be compatible with previous peft 0.12.1.dev0 ver…
Aug 9, 2024
fd3a7a2
update according to review
Aug 13, 2024
27fc08c
run make style
Aug 13, 2024
ddac2b1
Merge remote-tracking branch 'upstream/main' into oft-new2
Aug 20, 2024
b4637a3
added check_kwargs
Aug 26, 2024
e62bcfd
update version check
Aug 27, 2024
fa029c0
update oft config
Sep 17, 2024
d5b2b5a
Merge branch 'main' into oft-new2
Sep 17, 2024
c9d8d58
update oft comments
Sep 17, 2024
690f46f
running make style
Sep 24, 2024
dac8472
manually update for make quality
Sep 25, 2024
d763329
update from __future__ import annotations
Sep 25, 2024
2422d7d
fix import error
Sep 25, 2024
d226b09
update import
Sep 25, 2024
615ee87
Merge remote-tracking branch 'upstream/main' into oft-new2
Sep 25, 2024
74c4b27
update for passing test
Sep 25, 2024
bc149fa
update to fix the low_cpu_mem_usage error
Sep 26, 2024
eb9887f
update to fix the mixed peft errors
Sep 26, 2024
2046baa
update to fix the mixed peft errors
Sep 26, 2024
3838bdb
remove oft from mixed peft
Sep 26, 2024
bd7a4db
remove oft from mixed peft + make quality
Sep 26, 2024
2b105e3
resolve make test errors
Sep 27, 2024
b794269
update to resolve make test errors
Sep 27, 2024
65ce38b
update oft config
Sep 27, 2024
92328c1
modify to resolve issues with make test
Sep 27, 2024
e601692
update to solve test_feature_extraction_models
Sep 27, 2024
b13410c
Merge remote-tracking branch 'upstream/main' into oft-new2
Sep 27, 2024
d6fc326
fetch upstream and make style
Sep 27, 2024
c591015
update test_deeply_nested precision
Sep 30, 2024
097b2f8
skip test_deeply_nested for remote check
Sep 30, 2024
aedaa0d
update oft check_kwargs
Oct 1, 2024
0fd23cc
update oft check_kwargs
Oct 1, 2024
64cd73f
update oft check_kwargs
Oct 1, 2024
12 changes: 3 additions & 9 deletions src/peft/config.py
@@ -216,17 +216,11 @@ def _get_peft_type(

     @classmethod
     def check_kwargs(cls, **kwargs):
-        r"""
-        Check if the kwargs are valid for the configuration.
-
-        Args:
-            kwargs (additional keyword arguments, *optional*):
-                Additional keyword arguments passed along to the child class initialization.
-        """
-        if "oft_block_size" in kwargs:
-            warnings.warn(
-                'OFT has been updated since 0.12.1.dev0. Your trained adapter weights may not be compatible with the latest version of OFT. Please retrain your adapter weights.'
-            )
+        """Check kwargs before initializing the config instance.
+
+        Subclasses can override this method to add specific checks.
+        """
         return kwargs

     @property
19 changes: 19 additions & 0 deletions src/peft/tuners/oft/config.py
@@ -18,6 +18,8 @@
 from peft.config import PeftConfig
 from peft.utils import PeftType

+import warnings
+

 @dataclass
 class OFTConfig(PeftConfig):
@@ -162,3 +164,20 @@ def __post_init__(self):
             raise ValueError(
                 f"You can only specify either r ({self.r}) or oft_block_size ({self.oft_block_size}), but not both simultaneously, because r x oft_block_size == in_features."
             )
+
+    @classmethod
+    def check_kwargs(cls, **kwargs):
+        r"""
+        Check if the kwargs are valid for the configuration.
+
+        Args:
+            kwargs (additional keyword arguments, *optional*):
+                Additional keyword arguments passed along to the child class initialization.
+        """
+        if "oft_block_size" not in kwargs:
+            raise ValueError(
+                'OFT has been updated since PEFT 0.13.0. Your trained adapter weights is incompatible with the latest version of OFT. Please retrain your adapter weights with newer PEFT versions.'
+                'Downgrade PEFT to version 0.12.0 to merge the old adapter weights.'
+            )
+        return super().check_kwargs(**kwargs)

Review comment (Member), on the first message line:
Suggested change:
-                'OFT has been updated since PEFT 0.13.0. Your trained adapter weights is incompatible with the latest version of OFT. Please retrain your adapter weights with newer PEFT versions.'
+                'OFT has been updated since PEFT 0.13.0. Your trained adapter weights are incompatible with the latest version of OFT. Please retrain your adapter weights with newer PEFT versions.'
Also, let's ensure the 120 char line limit.

Review comment (Member), on the second message line:
Suggested change:
-                'Downgrade PEFT to version 0.12.0 to merge the old adapter weights.'
+                'Alternatively, downgrade PEFT to version 0.12.0 to use the old adapter weights.'
Not specific to merging, right?
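For context, here is a minimal sketch of what this guard means when restoring a saved adapter config (illustrative, hand-written kwargs; in practice they come from the saved adapter_config.json):

    from peft import OFTConfig

    new_style = {"r": 0, "oft_block_size": 32}  # saved with PEFT >= 0.13.0
    old_style = {"r": 8}                        # saved before 0.13.0, lacks oft_block_size

    OFTConfig.check_kwargs(**new_style)  # passes, defers to PeftConfig.check_kwargs
    OFTConfig.check_kwargs(**old_style)  # raises ValueError: retrain, or downgrade to 0.12.0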

48 changes: 36 additions & 12 deletions src/peft/tuners/oft/layer.py
@@ -169,12 +169,16 @@ def update_layer(
         self.oft_dropout.update(nn.ModuleDict({adapter_name: oft_dropout_layer}))

         if r == 0 and oft_block_size != 0:
-            if self.in_features % oft_block_size != 0:
-                raise ValueError(f"Input features ({self.in_features}) should be divisible by `oft_block_size` ({oft_block_size})")
+            if self.in_features % oft_block_size != 0 or oft_block_size > self.in_features:
+                warnings.warn(f"Invalid `oft_block_size` ({oft_block_size})!")
+                oft_block_size = self.adjust_oft_parameters(self.in_features, oft_block_size)
+                warnings.warn(f"Adjusted `oft_block_size` to ({oft_block_size}).")

Review comment (Member):
Instead of two warnings, let's make this one warning by prepending the first message to the second.

             r = int(self.in_features // oft_block_size)
         elif r != 0 and oft_block_size == 0:
-            if self.in_features % r != 0:
-                raise ValueError(f"Input features ({self.in_features}) should be divisible by `r` ({r})!")
+            if self.in_features % r != 0 or r > self.in_features:
+                warnings.warn(f"Invalid `r` ({r})!")
+                r = self.adjust_oft_parameters(self.in_features, r)
+                warnings.warn(f"Adjusted `r` to ({r}).")

Review comment (Member):
Same about having a single warning.

             oft_block_size = int(self.in_features // r)
         else:
             raise ValueError("Something went wrong, please report this error: https://github.com/huggingface/peft/issues")
@@ -263,6 +267,26 @@ def _project_batch(self, oft_r, eps=1e-5):
         out = torch.where(mask, oft_r, I + eps * (diff / norm_diff))
         return out

+    def adjust_oft_parameters(self, in_features, params):
+        """
+        Adjust the OFT parameters to be divisible by the in_features dimension.
+        """
+        if params < in_features:
+            higher_params = params
+            while higher_params <= in_features and in_features % higher_params != 0:
+                higher_params += 1
+        else:
+            return in_features
+
+        lower_params = params
+        while lower_params > 1 and in_features % lower_params != 0:
+            lower_params -= 1
+
+        if (params - lower_params) <= (higher_params - params):
+            return lower_params
+        else:
+            return higher_params
+

 class Linear(nn.Module, OFTLayer):
     """OFT implemented in Linear layer"""
@@ -490,16 +514,16 @@ def update_layer(
         conv_filter_dim = self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]

         if r == 0 and oft_block_size != 0:
-            if conv_filter_dim % oft_block_size != 0:
-                raise ValueError(
-                    f"Convolutional kernel dimension ({conv_filter_dim}) must be divisible by conv_filter_dim ({conv_filter_dim})!"
-                )
+            if conv_filter_dim % oft_block_size != 0 or oft_block_size > conv_filter_dim:
+                warnings.warn(f"Invalid `oft_block_size` ({oft_block_size})!")
+                oft_block_size = self.adjust_oft_parameters(conv_filter_dim, oft_block_size)
+                warnings.warn(f"Adjusted `oft_block_size` to ({oft_block_size}).")

Review comment (Member):
Same comment about having a single warning.

             r = int(conv_filter_dim // oft_block_size)
         elif r != 0 and oft_block_size == 0:
-            if conv_filter_dim % r != 0:
-                raise ValueError(
-                    f"Convolutional kernel dimension ({conv_filter_dim}) must be divisible by r ({r})!"
-                )
+            if conv_filter_dim % r != 0 or r > conv_filter_dim:
+                warnings.warn(f"Invalid `r` ({r})!")
+                r = self.adjust_oft_parameters(conv_filter_dim, r)
+                warnings.warn(f"Adjusted `r` to ({r}).")

Review comment (Member):
Same comment about having a single warning.

             oft_block_size = int(conv_filter_dim // r)
         else:
             raise ValueError("Something went wrong, please report this error: https://github.com/huggingface/peft/issues")
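A worked example of the Conv2d path (hypothetical shapes, reusing the nearest_divisor sketch from above):

    conv_filter_dim = 64 * 3 * 3             # Conv2d with 64 input channels, 3x3 kernel -> 576
    assert nearest_divisor(576, 10) == 9     # invalid r=10 (576 % 10 != 0) adjusts to 9
    assert 576 // 9 == 64                    # so oft_block_size becomes 64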