ENH Argument to enable bias for LoRA B #2237

Merged
16 changes: 15 additions & 1 deletion src/peft/tuners/lora/aqlm.py
@@ -35,13 +35,27 @@ def __init__(
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
Member

Should this be lora_B_bias? I find that to be a bit more informative.

Member Author

Indeed, I considered this. My main reasoning for going with the more generic lora_bias was that it leaves the door open for extending this argument in the future. Say someone finds that LoRA works much better when also adding a bias to LoRA A; then we can adapt this argument to allow that too. Otherwise, we'd have to add a new argument (and we don't want to rename arguments for obvious reasons). LMK what you think of that reasoning.

Member
@sayakpaul Nov 27, 2024

> Otherwise, we'd have to add a new argument (and we don't want to rename arguments for obvious reasons).

I think that would still be preferable to having a single argument for controlling the bias setup for LoRAs, as I think it's still in its infancy.

Later, if it becomes a common standard to add biases for both LoRA matrices, we can deprecate lora_B_bias and lora_A_bias (if we introduce such an argument) in favor of a single argument called lora_bias.

This is where I stand, but I am not too opinionated about it.

Collaborator

Do we care about reproducibility after upgrading PEFT? If so, it seems detrimental to possibly merge control of the A and B biases into one flag in the future, and they should be separated into two flags from the start.

Otherwise, in terms of opportunity cost for experimentation on the user's side, I think having two separate parameters (lora_bias_A, lora_bias_B) is better. That said, having only one parameter appears to be simpler: let the implementation decide what the current best thing is for adding biases. So if you just want to follow current LoRA best practice, it would be helpful to have only one flag. This becomes harder with two flags, since there is no obvious 'no bias at all' vs. 'best-practice' setting. If we put simplicity first (and don't care about reproducibility after upgrading), then one parameter is the way to go, I think. What's the stance here?

Ideally there would be another, lower-level layer of abstraction that has two bias parameters, and one above it that decides what the best choice is at the moment, i.e. BaseLoRA(..., lora_bias_A, lora_bias_B) -> LoRA(..., lora_bias).

Member Author

To clarify, my idea is that if we want to later add the possibility for a bias for LoRA A, the option would be something like lora_bias="a", or for both, lora_bias="both". We should not change the meaning of lora_bias=True, in order to ensure reproducibility, as you mentioned.

If we find that the parameter gets overloaded, we can add the option for a sub-config, so LoraConfig(..., lora_bias=LoraBiasConfig(bias_a=True, bias_b=True, ...)).

Member

Seems like lora_bias should be fine for now.

Member Author

Thanks for the feedback, I merged the PR as is.
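As a concrete illustration of what the flag discussed above controls, here is a minimal sketch (not the actual PEFT implementation) of how lora_bias is expected to be wired into the adapter sub-layers: only the lora_B up-projection receives a bias term, while lora_A stays bias-free.

import torch.nn as nn

# Sketch only: in_features, out_features, and r are placeholders.
def build_lora_sublayers(in_features: int, out_features: int, r: int, lora_bias: bool):
    lora_A = nn.Linear(in_features, r, bias=False)       # down-projection, always bias-free
    lora_B = nn.Linear(r, out_features, bias=lora_bias)  # up-projection, optional bias
    return lora_A, lora_B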

**kwargs,
):
if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

super().__init__()
LoraLayer.__init__(self, base_layer)

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.update_layer(
adapter_name,
r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
)

def forward(self, x: torch.Tensor):
# note: logic differs from default Linear because merging is not supported
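Quantized wrappers such as the AQLM layer above cannot merge the adapter into the base weights, so the LoRA path (and its optional bias inside lora_B) is always applied at forward time. A rough sketch of that forward logic, with names taken from the diff and dtype handling and per-adapter bookkeeping omitted:

import torch

# Simplified sketch of the quantized-layer forward pass; not the exact PEFT code.
def lora_forward(quant_linear_module, lora_A, lora_B, dropout, scaling, x: torch.Tensor) -> torch.Tensor:
    result = quant_linear_module(x)                          # frozen, quantized base layer
    result = result + lora_B(lora_A(dropout(x))) * scaling   # LoRA update; bias (if any) lives in lora_B
    return result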
16 changes: 15 additions & 1 deletion src/peft/tuners/lora/awq.py
@@ -36,8 +36,13 @@ def __init__(
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
**kwargs,
):
if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

super().__init__()
LoraLayer.__init__(self, base_layer)

@@ -46,7 +51,16 @@ def __init__(
self.quant_linear_module = base_layer

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.update_layer(
adapter_name,
r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
)

def forward(self, x: torch.Tensor):
result = self.quant_linear_module(x)
25 changes: 25 additions & 0 deletions src/peft/tuners/lora/bnb.py
@@ -41,6 +41,7 @@ def __init__(
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
**kwargs,
) -> None:
super().__init__()
@@ -56,6 +57,7 @@ def __init__(
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
)

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
@@ -118,6 +120,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
self.get_base_layer().weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
).to(weight.device)
if self.lora_bias[active_adapter]:
bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias
if safe_merge and not torch.isfinite(bias_data).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
self.get_base_layer().bias.data = bias_data

state.reset_grads()
self.merged_adapters.append(active_adapter)

@@ -154,6 +164,9 @@ def unmerge(self) -> None:
self.get_base_layer().weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
).to(weight.device)

if self.lora_bias[active_adapter]:
self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias
state.reset_grads()

def get_delta_weight(self, adapter):
@@ -298,6 +311,7 @@ def __init__(
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
**kwargs,
) -> None:
super().__init__()
Expand All @@ -313,6 +327,7 @@ def __init__(
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
)

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
@@ -372,6 +387,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
kwargs["requires_grad"] = False
kwargs.pop("data", None)
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
if self.lora_bias[active_adapter]:
bias_data = self.get_base_layer().bias.data + self.lora_B[active_adapter].bias
if safe_merge and not torch.isfinite(bias_data).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
self.get_base_layer().bias.data = bias_data

self.merged_adapters.append(active_adapter)

def unmerge(self) -> None:
@@ -407,6 +430,8 @@ def unmerge(self) -> None:
kwargs["requires_grad"] = False
kwargs.pop("data", None)
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
if self.lora_bias[active_adapter]:
self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias

def get_delta_weight(self, adapter):
return (
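The bias handling added to merge() and unmerge() above reduces to simple addition and subtraction on the base layer's bias. A minimal sketch of that arithmetic, assuming the base layer already owns a bias tensor (plain tensors stand in for the real bitsandbytes parameters):

import torch

def merge_bias(base_bias: torch.Tensor, lora_B_bias: torch.Tensor, safe_merge: bool = False) -> torch.Tensor:
    merged = base_bias + lora_B_bias
    if safe_merge and not torch.isfinite(merged).all():
        raise ValueError("NaNs detected in the merged bias")
    return merged

def unmerge_bias(merged_bias: torch.Tensor, lora_B_bias: torch.Tensor) -> torch.Tensor:
    return merged_bias - lora_B_bias  # exact inverse of merge_bias, up to floating-point error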
23 changes: 23 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -212,6 +212,10 @@ class LoraConfig(PeftConfig):
all have separate LoRA adapters attached to them.
runtime_config (`LoraRuntimeConfig`):
Runtime configurations (which are not saved or restored).
lora_bias (`bool`):
Defaults to `False`. Whether to enable the bias term for the LoRA B parameter. Typically, this should be
disabled. The main use case for this is when the LoRA weights were extracted from fully fine-tuned
parameters so the bias of those parameters can be taken into account.
"""

r: int = field(default=8, metadata={"help": "Lora attention dimension"})
@@ -391,6 +395,16 @@ class LoraConfig(PeftConfig):
runtime_config: LoraRuntimeConfig = field(
default_factory=LoraRuntimeConfig, metadata={"help": "Runtime configurations"}
)
lora_bias: bool = field(
default=False,
metadata={
"help": (
"Whether to enable the bias term for the LoRA B parameter. Typically, this should be disabled. The "
"main use case for this is when the LoRA weights were extracted from fully fine-tuned parameters so "
"the bias of those parameters can be taken into account."
)
},
)

def to_dict(self):
"""
@@ -446,6 +460,15 @@ def __post_init__(self):
elif self.init_lora_weights != "eva" and self.eva_config is not None:
warnings.warn("`eva_config` specified but will be ignored when `init_lora_weights` is not 'eva'.")

if self.lora_bias:
if self.init_lora_weights not in (True, False):
raise ValueError(
f"The argument lora_bias=True is only supported with init_lora_weights=True or False, got "
f"init_lora_weights={self.init_lora_weights} instead."
)
if self.use_dora:
raise ValueError("The argument lora_bias=True is not supported for DoRA, please pass use_dora=False")

# Using post training conversion of modified base weights to restore their initial values (PiSSA, OLoRA) cannot
# be correctly done when using rslora + rank_pattern/alpha_pattern. We can't really know if the user intends
# this when they'll eventually call save_pretrained (i.e. if they'll pass
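Putting the new field and the __post_init__ checks together, a usage sketch (the target module names are placeholders): lora_bias=True works with the default boolean weight init, but is rejected together with DoRA or with non-boolean init_lora_weights values such as "pissa".

from peft import LoraConfig

config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],  # placeholder target modules
    lora_bias=True,                       # add a trainable bias to each LoRA B projection
)

# These combinations are rejected in __post_init__:
# LoraConfig(lora_bias=True, use_dora=True)              -> ValueError
# LoraConfig(lora_bias=True, init_lora_weights="pissa")  -> ValueError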
16 changes: 15 additions & 1 deletion src/peft/tuners/lora/eetq.py
@@ -33,8 +33,13 @@ def __init__(
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
**kwargs,
):
if use_dora:
raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

super().__init__()
LoraLayer.__init__(self, base_layer)

@@ -43,7 +48,16 @@ def __init__(
self.quant_linear_module = base_layer

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.update_layer(
adapter_name,
r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
)

def forward(self, x: torch.Tensor):
result = self.quant_linear_module(x)
1 change: 1 addition & 0 deletions src/peft/tuners/lora/eva.py
@@ -497,6 +497,7 @@ def _load_eva_state_dict(
"lora_dropout": peft_config.lora_dropout,
"use_rslora": peft_config.use_rslora,
"use_dora": peft_config.use_dora,
"lora_bias": peft_config.lora_bias,
}
missing_eva_inits = []
new_target_modules = []
2 changes: 2 additions & 0 deletions src/peft/tuners/lora/gptq.py
@@ -32,6 +32,7 @@ def __init__(
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
**kwargs,
):
super().__init__()
@@ -52,6 +53,7 @@ def __init__(
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
)

def forward(self, x: torch.Tensor):
5 changes: 5 additions & 0 deletions src/peft/tuners/lora/hqq.py
@@ -41,8 +41,12 @@ def __init__(
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
lora_bias: bool = False,
**kwargs,
) -> None:
if lora_bias:
raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False")

super().__init__()
LoraLayer.__init__(self, base_layer)
self.fan_in_fan_out = False
Expand All @@ -56,6 +60,7 @@ def __init__(
init_lora_weights=init_lora_weights,
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
)

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: