Update OFT to fix merge bugs #1996

Merged
43 commits merged on Oct 1, 2024
Changes from 15 commits
Commits
8e48df0  updating oft (Aug 7, 2024)
42b72e6  update oft to be consistent with other peft methods (Aug 7, 2024)
6ade80f  Merge remote-tracking branch 'upstream/main' into oft-new2 (Aug 7, 2024)
ada7c69  update oft to be consistent with other peft methods (Aug 7, 2024)
b7f5b23  update oft to be consistent with other peft methods (Aug 8, 2024)
3d9a4bc  update oft to fix merge bugs, be consistent with other peft methods (Aug 8, 2024)
2bca964  update oft to fix merge bugs, be consistent with other peft methods (Aug 8, 2024)
76aba70  update oft to fix merge bugs, be consistent with other peft methods (Aug 8, 2024)
7b9a1af  Merge remote-tracking branch 'upstream/main' into oft-new2 (Aug 8, 2024)
7c116b5  addressing issues in config file (Aug 9, 2024)
3297ed6  update oft config to be compatible with previous peft 0.12.1.dev0 ver… (Aug 9, 2024)
fd3a7a2  update according to review (Aug 13, 2024)
27fc08c  run make style (Aug 13, 2024)
ddac2b1  Merge remote-tracking branch 'upstream/main' into oft-new2 (Aug 20, 2024)
b4637a3  added check_kwargs (Aug 26, 2024)
e62bcfd  update version check (Aug 27, 2024)
fa029c0  update oft config (Sep 17, 2024)
d5b2b5a  Merge branch 'main' into oft-new2 (Sep 17, 2024)
c9d8d58  update oft comments (Sep 17, 2024)
690f46f  running make style (Sep 24, 2024)
dac8472  manually update for make quality (Sep 25, 2024)
d763329  update from __future__ import annotations (Sep 25, 2024)
2422d7d  fix import error (Sep 25, 2024)
d226b09  update import (Sep 25, 2024)
615ee87  Merge remote-tracking branch 'upstream/main' into oft-new2 (Sep 25, 2024)
74c4b27  update for passing test (Sep 25, 2024)
bc149fa  update to fix the low_cpu_mem_usage error (Sep 26, 2024)
eb9887f  update to fix the mixed peft errors (Sep 26, 2024)
2046baa  update to fix the mixed peft errors (Sep 26, 2024)
3838bdb  remove oft from mixed peft (Sep 26, 2024)
bd7a4db  remove oft from mixed peft + make quality (Sep 26, 2024)
2b105e3  resolve make test errors (Sep 27, 2024)
b794269  update to resolve make test errors (Sep 27, 2024)
65ce38b  update oft config (Sep 27, 2024)
92328c1  modify to resolve issues with make test (Sep 27, 2024)
e601692  update to solve test_feature_extraction_models (Sep 27, 2024)
b13410c  Merge remote-tracking branch 'upstream/main' into oft-new2 (Sep 27, 2024)
d6fc326  fetch upstream and make style (Sep 27, 2024)
c591015  updat test_deeply_nested precision (Sep 30, 2024)
097b2f8  skip test_deeply_nested for remote check (Sep 30, 2024)
aedaa0d  update oft check_kwargs (Oct 1, 2024)
0fd23cc  update oft check_kwargs (Oct 1, 2024)
64cd73f  update oft check_kwargs (Oct 1, 2024)
16 changes: 16 additions & 0 deletions src/peft/config.py
@@ -149,6 +149,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional

loaded_attributes = cls.from_json_file(config_file)
kwargs = {**class_kwargs, **loaded_attributes}
kwargs = cls.check_kwargs(**kwargs)
return cls.from_peft_type(**kwargs)

@classmethod
@@ -213,6 +214,21 @@ def _get_peft_type(
loaded_attributes = cls.from_json_file(config_file)
return loaded_attributes["peft_type"]

@classmethod
def check_kwargs(cls, **kwargs):
r"""
Check if the kwargs are valid for the configuration.

Args:
kwargs (additional keyword arguments, *optional*):
Additional keyword arguments passed along to the child class initialization.
"""
if "oft_block_size" in kwargs:
Reviewer comment (Member):
Hmm, so oft_block_size is the new parameter, right? So if we're loading an old OFT model, it should be missing. Therefore, should the check not be if "oft_block_size" not in kwargs?

warnings.warn(
'OFT has been updated since 0.12.1.dev0. Your trained adapter weights may not be compatible with the latest version of OFT. Please retrain your adapter weights.'
Reviewer comment (Member):
  1. Note that dev0 is not a real version, so let's not refer to that. The next release version will be 0.13.0, so let's use that version.
  2. "may not be compatible": We are pretty sure it is incompatible when trained, right? Let's phrase it as "is incompatible".
  3. Let's also mention that users can downgrade PEFT to version 0.12.0 and then the adapter will still work.

)
return kwargs

@property
def is_prompt_learning(self) -> bool:
r"""
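For readers following along, here is a minimal sketch of the load path this hunk changes: `from_pretrained` merges the JSON attributes with any class kwargs and now routes them through `check_kwargs` before instantiating the concrete config class. The adapter path below is a placeholder, not something from this PR.

```python
import warnings

from peft import PeftConfig

# Hypothetical adapter directory; any adapter_config.json containing
# "oft_block_size" would hit the new warning added in check_kwargs.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    config = PeftConfig.from_pretrained("path/to/old-oft-adapter")

for w in caught:
    print(w.category.__name__, w.message)
```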
11 changes: 6 additions & 5 deletions src/peft/tuners/boft/config.py
@@ -32,7 +32,7 @@ class BOFTConfig(PeftConfig):
boft_block_num (`int`): Number of BOFT blocks per injected layer.
boft_n_butterfly_factor (`int`): Number of butterfly factors across different layers.
target_modules (`Union[List[str],str]`): The names of the modules to apply the adapter to.
boft_dropout (`float`): The multiplicative dropout probability for BOFT layers.
boft_dropout (`float`): The multiplicative dropout probability, by setting OFT blocks to identity during training, similar to the dropout layer in LoRA.
fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out).
For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set
to `True`.
@@ -81,7 +81,7 @@ class BOFTConfig(PeftConfig):
"example": "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' ",
},
)
boft_dropout: float = field(default=0.0, metadata={"help": "BOFT multiplicative dropout"})
boft_dropout: float = field(default=0.0, metadata={"help": "BOFT multiplicative dropout, randomly setting blocks of OFT to be identity matrix, similar to the dropout layer in LoRA."})
fan_in_fan_out: bool = field(
default=False,
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
@@ -125,9 +125,10 @@ def __post_init__(self):
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
)
if self.boft_block_size == 0 and self.boft_block_num == 0:
raise ValueError("You must specify either boft_block_size or boft_block_num.")
raise ValueError(
f"Either `boft_block_size` or `boft_block_num` must be non-zero. Currently, boft_block_size = {self.boft_block_size} and boft_block_num = {self.boft_block_num}."
)
if not (self.boft_block_size != 0) ^ (self.boft_block_num != 0):
raise ValueError(
f"You can only specify either boft_block_size ({self.boft_block_size}) or boft_block_num ({self.boft_block_num}), "
"but not both simultaneously, because boft_block_size x boft_block_num != in_features."
f"You can only specify either boft_block_size ({self.boft_block_size}) or boft_block_num ({self.boft_block_num}), but not both simultaneously, because boft_block_size x boft_block_num == in_features."
)
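To illustrate the tightened `__post_init__` checks: exactly one of `boft_block_size` and `boft_block_num` may be non-zero, since their product has to match the layer's `in_features`. A small usage sketch (target modules are placeholders):

```python
from peft import BOFTConfig

# Valid: only boft_block_size is set; boft_block_num is derived per layer
# inside update_layer so that boft_block_size * boft_block_num == in_features.
ok = BOFTConfig(boft_block_size=4, boft_block_num=0, target_modules=["q_proj", "v_proj"])

# Both zero -> the new "must be non-zero" ValueError.
try:
    BOFTConfig(boft_block_size=0, boft_block_num=0, target_modules=["q_proj", "v_proj"])
except ValueError as err:
    print(err)

# Both set -> the "but not both simultaneously" ValueError.
try:
    BOFTConfig(boft_block_size=4, boft_block_num=8, target_modules=["q_proj", "v_proj"])
except ValueError as err:
    print(err)
```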
35 changes: 19 additions & 16 deletions src/peft/tuners/boft/layer.py
@@ -316,10 +316,7 @@ def update_layer(
boft_block_num = int(self.in_features // boft_block_size)

else:
raise ValueError(
f"You can only specify either boft_block_size ({boft_block_size}) or boft_block_num ({boft_block_num}), but not both simultaneously or setting both"
"to be 0, because boft_block_size x boft_block_num != in_features."
)
raise ValueError("Something went wrong, please report this error: https://github.com/huggingface/peft/issues")

# In OFT you can specify the number of blocks to be 1
if boft_n_butterfly_factor != 0:
@@ -711,11 +708,6 @@ def update_layer(
conv_filter_dim = self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]

# Initialize the BOFT parameters.
if not (boft_block_size != 0) ^ (boft_block_num != 0):
raise ValueError(
f"You can only specify either boft_block_size ({boft_block_size}) or boft_block_num ({boft_block_num}), but not both simultaneously, because boft_block_size x boft_block_num != in_features."
)

if boft_block_size == 0 and boft_block_num != 0:
if conv_filter_dim % boft_block_num != 0:
raise ValueError(
Expand Down Expand Up @@ -753,7 +745,8 @@ def update_layer(
boft_block_num = int(conv_filter_dim // boft_block_size)

else:
raise ValueError("Unknown error!")
raise ValueError("Something went wrong, please report this error: https://github.com/huggingface/peft/issues")


# In OFT you can specify the number of blocks to be 1
if boft_n_butterfly_factor != 0:
@@ -777,7 +770,7 @@
self.boft_R[adapter_name] = nn.Parameter(
torch.zeros(boft_n_butterfly_factor + 1, boft_block_num, boft_block_size, boft_block_size)
)
self.boft_s[adapter_name] = nn.Parameter(torch.ones(1, int(self.out_features)))
self.boft_s[adapter_name] = nn.Parameter(torch.ones(int(self.out_features), 1))

self.reset_boft_parameters(adapter_name, init_weights)

@@ -816,9 +809,12 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
butterfly_oft_mat, boft_s = self.get_delta_weight(active_adapter)

orig_weight = orig_weight.view(
self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0], self.out_features
self.out_features,
self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
)
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = orig_weight * boft_s
orig_weight = orig_weight.view(
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
@@ -830,9 +826,12 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N

orig_weight = base_layer.weight.data.clone()
orig_weight = orig_weight.view(
self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0], self.out_features
self.out_features,
self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
)
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat, orig_weight)
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = orig_weight * boft_s
orig_weight = orig_weight.view(
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
@@ -856,10 +855,12 @@ def unmerge(self) -> None:

orig_weight = self.get_base_layer().weight.data.clone()
orig_weight = orig_weight.view(
self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0],
self.out_features,
self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0],
)
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = torch.mm(butterfly_oft_mat.t(), orig_weight)
orig_weight = torch.transpose(orig_weight, 0, 1)
orig_weight = orig_weight * (1 / boft_s)
orig_weight = orig_weight.view(
self.out_features,
@@ -918,7 +919,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
device=x.device,
dtype=x.dtype,
)
boft_scale = torch.ones((1, int(self.out_features)), device=x.device, dtype=x.dtype)
boft_scale = torch.ones((int(self.out_features), 1), device=x.device, dtype=x.dtype)

for active_adapter in self.active_adapters:
if active_adapter not in self.boft_R.keys():
@@ -955,10 +956,12 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:

orig_weight = self.base_layer.weight.data
orig_weight = orig_weight.view(
self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
self.out_features,
self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
)
orig_weight = torch.transpose(orig_weight, 0, 1)
rotated_weight = torch.mm(boft_rotation, orig_weight)
rotated_weight = torch.transpose(rotated_weight, 0, 1)

scaled_rotated_weight = rotated_weight * boft_scale

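The Conv2d merge/unmerge fix above boils down to flattening the kernel to `(out_features, in_features * k * k)`, transposing so the orthogonal butterfly matrix rotates the flattened input dimension, transposing back, and scaling with a `(out_features, 1)` `boft_s` that broadcasts per output channel. The earlier code viewed the `(out, in, k, k)` weight directly as `(in*k*k, out)`, which reinterprets the memory layout rather than transposing it. A standalone sketch of the corrected pattern (illustrative shapes, not the PEFT implementation itself):

```python
import torch

# Sketch of the reshape / rotate / scale pattern used in the fixed Conv2d merge path.
out_features, in_features, k = 8, 4, 3
rot_dim = in_features * k * k

weight = torch.randn(out_features, in_features, k, k)                  # Conv2d weight layout
butterfly_oft_mat = torch.linalg.qr(torch.randn(rot_dim, rot_dim))[0]  # orthogonal rotation
boft_s = torch.rand(out_features, 1) + 0.5                             # per-output-channel scale

# Merge: rotate the flattened input dimension, then scale each output channel.
w = weight.view(out_features, rot_dim)
w = torch.transpose(w, 0, 1)                  # (in*k*k, out)
w = torch.mm(butterfly_oft_mat, w)            # rotate input dimension
w = torch.transpose(w, 0, 1)                  # back to (out, in*k*k)
w = w * boft_s                                # (out, 1) broadcasts over columns
merged = w.view(out_features, in_features, k, k)

# Unmerge: the transpose of an orthogonal matrix undoes the rotation, 1/boft_s the scale.
u = merged.view(out_features, rot_dim)
u = torch.transpose(u, 0, 1)
u = torch.mm(butterfly_oft_mat.t(), u)
u = torch.transpose(u, 0, 1)
u = u * (1.0 / boft_s)
restored = u.view(out_features, in_features, k, k)

assert torch.allclose(restored, weight, atol=1e-5)
```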
57 changes: 51 additions & 6 deletions src/peft/tuners/oft/config.py
@@ -15,18 +15,19 @@
from dataclasses import dataclass, field
from typing import List, Optional, Union

from peft.tuners.lycoris_utils import LycorisConfig
from peft.config import PeftConfig
from peft.utils import PeftType


@dataclass
class OFTConfig(LycorisConfig):
class OFTConfig(PeftConfig):
"""
This is the configuration class to store the configuration of a [`OFTModel`].

Args:
r (`int`): OFT rank.
module_dropout (`int`): The dropout probability for disabling OFT modules during training.
r (`int`): OFT rank, number of OFT blocks per injected layer.
oft_block_size (`int`): OFT block size across different layers.
module_dropout (`float`): The multiplicative dropout probability, by setting OFT blocks to identity during training, similar to the dropout layer in LoRA.
target_modules (`Optional[Union[List[str], str]]`):
The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
names will be replaced. When passing a string, a regex match will be performed. When passing a list of
@@ -35,6 +36,10 @@ class OFTConfig(LycorisConfig):
the output layer. If this is not specified, modules will be chosen according to the model architecture. If
the architecture is not known, an error will be raised -- in this case, you should specify the target
modules manually.
fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out).
bias (`str`): Bias type for OFT. Can be 'none', 'all' or 'oft_only'. If 'all' or 'oft_only', the
corresponding biases will be updated during training. Be aware that this means that, even when disabling
the adapters, the model will not produce the same output as the base model would have without adaptation.
init_weights (`bool`):
Whether to perform initialization of OFT weights.
layers_to_transform (`Union[List[int], int]`):
@@ -56,9 +61,16 @@ class OFTConfig(LycorisConfig):
Whether to share the OFT parameters between blocks or not. This is `False` by default.
"""

r: int = field(default=8, metadata={"help": "OFT rank"})
r: int = field(default=8, metadata={"help": "OFT rank, number of OFT blocks per injected layer."})
oft_block_size: int = field(
default=0,
metadata={
"help": "OFT block size across different layers.",
"note": "You can only specify either r or oft_block_size, but not both simultaneously, because r x oft_block_size = layer dimension.",
},
)
module_dropout: float = field(
default=0.0, metadata={"help": "The dropout probability for disabling OFT modules during training"}
default=0.0, metadata={"help": "OFT multiplicative dropout, randomly setting blocks of OFT to be identity matrix, similar to the dropout layer in LoRA."}
)
target_modules: Optional[Union[List[str], str]] = field(
default=None,
Expand All @@ -68,6 +80,11 @@ class OFTConfig(LycorisConfig):
"This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer."
},
)
fan_in_fan_out: bool = field(
default=False,
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
)
bias: str = field(default="none", metadata={"help": "Bias type for OFT. Can be 'none', 'all' or 'oft_only'"})
init_weights: bool = field(
default=True,
metadata={
@@ -111,9 +128,37 @@ class OFTConfig(LycorisConfig):
default=False,
metadata={"help": "Whether to share the OFT parameters between blocks or not."},
)
rank_pattern: Optional[dict] = field(
default_factory=dict,
metadata={
"help": (
"The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. "
"For example, `{model.decoder.layers.0.encoder_attn.k_proj: 8`}"
"Important: the rank pattern won't be applied to the layers after 0.12.1.dev0!"
)
},
)
alpha_pattern: Optional[dict] = field(
default_factory=dict,
metadata={
"help": (
"The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `alpha`. "
"For example, `{model.decoder.layers.0.encoder_attn.k_proj: 32`}"
"Important: the alpha pattern won't be applied to the layers after 0.12.1.dev0!"
)
},
)

def __post_init__(self):
self.peft_type = PeftType.OFT
self.target_modules = (
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
)
if self.r == 0 and self.oft_block_size == 0:
raise ValueError(
f"Either `r` or `oft_block_size` must be non-zero. Currently, r = {self.r} and oft_block_size = {self.oft_block_size}."
)
if not (self.r != 0) ^ (self.oft_block_size != 0):
raise ValueError(
f"You can only specify either r ({self.r}) or oft_block_size ({self.oft_block_size}), but not both simultaneously, because r x oft_block_size == in_features."
)
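As with BOFT, the new `__post_init__` makes `r` and `oft_block_size` mutually exclusive, since their product must equal the dimension of each injected layer. A quick sketch under those assumptions (target modules are placeholders):

```python
from peft import OFTConfig

# Specify the adapter either by number of blocks (r) or by block size, not both.
cfg_by_rank = OFTConfig(r=8, oft_block_size=0, target_modules=["q_proj", "v_proj"])
cfg_by_block = OFTConfig(r=0, oft_block_size=32, target_modules=["q_proj", "v_proj"])

try:
    OFTConfig(r=8, oft_block_size=32)  # both non-zero -> ValueError from __post_init__
except ValueError as err:
    print(err)
```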