Added the option to use the corrected scaling factor for LoRA, based on new research. #1244

Merged 7 commits on Dec 15, 2023
7 changes: 5 additions & 2 deletions docs/source/conceptual_guides/lora.md
@@ -76,8 +76,9 @@ As with other methods supported by PEFT, to fine-tune a model using LoRA, you ne

- `r`: the rank of the update matrices, expressed in `int`. Lower rank results in smaller update matrices with fewer trainable parameters.
- `target_modules`: The modules (for example, attention blocks) to apply the LoRA update matrices to.
- `alpha`: LoRA scaling factor.
- `lora_alpha`: LoRA scaling factor.
- `bias`: Specifies if the `bias` parameters should be trained. Can be `'none'`, `'all'` or `'lora_only'`.
- `use_rslora`: When set to True, uses [Rank-Stabilized LoRA](https://doi.org/10.48550/arXiv.2312.03732), which sets the adapter scaling factor to `lora_alpha/math.sqrt(r)`, as this was shown to work better. Otherwise, the original default value of `lora_alpha/r` is used. (A configuration sketch follows this list.)
- `modules_to_save`: List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. These typically include the model's custom head that is randomly initialized for the fine-tuning task.
- `layers_to_transform`: List of layers to be transformed by LoRA. If not specified, all layers in `target_modules` are transformed.
- `layers_pattern`: Pattern to match layer names in `target_modules`, if `layers_to_transform` is specified. By default `PeftModel` will look at common layer patterns (`layers`, `h`, `blocks`, etc.); use this option for exotic and custom models.
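A minimal configuration sketch tying these options together (the module names `q_proj`, `v_proj`, and `classifier` are illustrative and depend on the base model):

```python
from peft import LoraConfig

config = LoraConfig(
    r=16,                                 # rank of the update matrices
    lora_alpha=32,                        # LoRA scaling factor
    target_modules=["q_proj", "v_proj"],  # modules to adapt (model-dependent)
    bias="none",                          # do not train any bias parameters
    use_rslora=True,                      # scale adapters by lora_alpha / sqrt(r)
    modules_to_save=["classifier"],       # extra modules to train and save (illustrative)
)
```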
@@ -111,4 +112,6 @@ lora_config = LoraConfig(..., init_lora_weights="loftq", loftq_config=loftq_conf
peft_model = get_peft_model(base_model, lora_config)
```

Finally, there is also an option to set `init_lora_weights=False`. When choosing this option, the LoRA weights are initialized such that they do *not* result in an identity transform. This is useful for debugging and testing purposes and should not be used otherwise.
There is also an option to set `init_lora_weights=False`. When choosing this option, the LoRA weights are initialized such that they do *not* result in an identity transform. This is useful for debugging and testing purposes and should not be used otherwise.
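A minimal sketch of what this looks like in practice, mirroring the check in the test suite below (the tiny model and the module name `linear` are just for illustration):

```python
import torch
from torch import nn
from peft import LoraConfig, get_peft_model

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

config = LoraConfig(target_modules=["linear"], init_lora_weights=False)
peft_model = get_peft_model(TinyModel(), config)

# With init_lora_weights=False, lora_B is not zero-initialized, so the adapter
# contribution (lora_B @ lora_A) is non-zero from the very first forward pass.
weight_B = peft_model.linear.lora_B["default"].weight
print(torch.allclose(weight_B, torch.zeros_like(weight_B)))  # expected: False
```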

Finally, the LoRA architecture scales each adapter during every forward pass by a fixed scalar that is set at initialization and depends on the rank `r`. The original LoRA method uses the scaling `lora_alpha/r`, but [Rank-Stabilized LoRA](https://doi.org/10.48550/arXiv.2312.03732) shows that scaling by `lora_alpha/math.sqrt(r)` instead stabilizes the adapters and unlocks the increased performance potential of higher ranks. Set `use_rslora=True` to use the rank-stabilized scaling `lora_alpha/math.sqrt(r)`.
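As a rough numerical illustration of the two rules (the values of `lora_alpha` and `r` below are hypothetical):

```python
import math

lora_alpha, r = 16, 64

default_scaling = lora_alpha / r            # 16 / 64 = 0.25, shrinks linearly with the rank
rslora_scaling = lora_alpha / math.sqrt(r)  # 16 / 8  = 2.0, shrinks only with the square root of the rank
```

At higher ranks the default rule drives the adapter contribution toward zero, which is one reason it can limit the benefit of increasing `r`.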
6 changes: 4 additions & 2 deletions src/peft/tuners/lora/bnb.py
@@ -37,12 +37,13 @@ def __init__(
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
**kwargs,
) -> None:
super().__init__()
LoraLayer.__init__(self, base_layer)

self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
@@ -194,12 +195,13 @@ def __init__(
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
**kwargs,
) -> None:
super().__init__()
LoraLayer.__init__(self, base_layer)

self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
15 changes: 15 additions & 0 deletions src/peft/tuners/lora/config.py
@@ -57,6 +57,10 @@ class LoraConfig(PeftConfig):
bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'. If 'all' or 'lora_only', the
corresponding biases will be updated during training. Be aware that this means that, even when disabling
the adapters, the model will not produce the same output as the base model would have without adaptation.
use_rslora (`bool`):
When set to True, uses <a href='https://doi.org/10.48550/arXiv.2312.03732'>Rank-Stabilized LoRA</a> which
sets the adapter scaling factor to `lora_alpha/math.sqrt(r)`, since it was proven to work better.
Otherwise, it will use the original default value of `lora_alpha/r`.
modules_to_save (`List[str]`): List of modules apart from LoRA layers to be set as trainable
and saved in the final checkpoint.
layers_to_transform (`Union[List[int],int]`):
@@ -89,6 +93,17 @@ class LoraConfig(PeftConfig):
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
)
bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"})
use_rslora: bool = field(
default=False,
metadata={
"help": (
"When set to True, uses Rank-Stabilized LoRA doi.org/10.48550/arXiv.2312.03732"
" which sets the adapter scaling factor to `lora_alpha/math.sqrt(r)`, since it"
" was proven to work better. Otherwise, it will use the original default"
" value of `lora_alpha/r`."
)
},
)
modules_to_save: Optional[List[str]] = field(
default=None,
metadata={
3 changes: 2 additions & 1 deletion src/peft/tuners/lora/gptq.py
@@ -27,6 +27,7 @@ def __init__(
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: bool = True,
use_rslora: bool = False,
**kwargs,
):
super().__init__()
@@ -35,7 +36,7 @@ def __init__(
# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
# for backwards compatibility
self.quant_linear_module = base_layer
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

def forward(self, x: torch.Tensor):
# note: logic differs from default Linear because merging is not supported
30 changes: 21 additions & 9 deletions src/peft/tuners/lora/layer.py
@@ -71,7 +71,7 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self.in_features = in_features
self.out_features = out_features

def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
self.r[adapter_name] = r
@@ -86,7 +86,10 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
if r > 0:
self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)
self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=False)
self.scaling[adapter_name] = lora_alpha / r
if use_rslora:
self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
else:
self.scaling[adapter_name] = lora_alpha / r

if init_lora_weights == "loftq":
self.loftq_init(adapter_name)
@@ -102,7 +105,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
self.to(weight.device)
self.set_adapter(self.active_adapters)

def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
self.r[adapter_name] = r
@@ -121,7 +124,10 @@ def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lo
padding = base_layer.padding
self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)
self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)
self.scaling[adapter_name] = lora_alpha / r
if use_rslora:
self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
else:
self.scaling[adapter_name] = lora_alpha / r

if init_lora_weights == "loftq":
self.loftq_init(adapter_name)
@@ -134,7 +140,7 @@ def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lo
self.to(base_layer.weight.device, dtype=weight.dtype)
self.set_adapter(self.active_adapters)

def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
self.r[adapter_name] = r
@@ -151,7 +157,10 @@ def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init
weight_B = torch.randn((self.out_features, r))
self.lora_embedding_A[adapter_name] = nn.Parameter(weight_A)
self.lora_embedding_B[adapter_name] = nn.Parameter(weight_B)
self.scaling[adapter_name] = lora_alpha / r
if use_rslora:
self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
else:
self.scaling[adapter_name] = lora_alpha / r

if init_lora_weights == "loftq":
self.loftq_init(adapter_name)
@@ -254,14 +263,15 @@ def __init__(
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
is_target_conv_1d_layer: bool = False,
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
**kwargs,
) -> None:
super().__init__()
LoraLayer.__init__(self, base_layer, **kwargs)
self.fan_in_fan_out = fan_in_fan_out

self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
self.is_target_conv_1d_layer = is_target_conv_1d_layer

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
@@ -390,13 +400,14 @@ def __init__(
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
**kwargs,
) -> None:
super().__init__()
LoraLayer.__init__(self, base_layer)

self._active_adapter = adapter_name
self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
@@ -533,13 +544,14 @@ def __init__(
lora_alpha: int = 1,
lora_dropout: float = 0.0,
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
**kwargs,
) -> None:
super().__init__()
LoraLayer.__init__(self, base_layer)

self._active_adapter = adapter_name
self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
"""
4 changes: 4 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -153,6 +153,7 @@ def _create_and_replace(
"lora_dropout": lora_config.lora_dropout,
"fan_in_fan_out": lora_config.fan_in_fan_out,
"init_lora_weights": lora_config.init_lora_weights,
"use_rslora": lora_config.use_rslora,
}
kwargs["loaded_in_8bit"] = optional_kwargs.pop("loaded_in_8bit", False)
kwargs["loaded_in_4bit"] = optional_kwargs.pop("loaded_in_4bit", False)
@@ -170,6 +171,7 @@ def _create_and_replace(
alpha,
lora_config.lora_dropout,
lora_config.init_lora_weights,
lora_config.use_rslora,
)
elif isinstance(target, Embedding):
target.update_layer_embedding(
@@ -178,6 +180,7 @@ def _create_and_replace(
alpha,
lora_config.lora_dropout,
lora_config.init_lora_weights,
lora_config.use_rslora,
)
elif isinstance(target, Linear):
target.update_layer(
@@ -186,6 +189,7 @@ def _create_and_replace(
alpha,
lora_config.lora_dropout,
lora_config.init_lora_weights,
lora_config.use_rslora,
)
else:
new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
8 changes: 7 additions & 1 deletion src/peft/tuners/lora/tp_layer.py
@@ -25,6 +25,7 @@ def __init__(
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False,
init_lora_weights: bool = True,
use_rslora: bool = False,
**kwargs,
):
super().__init__()
@@ -52,6 +53,7 @@ def __init__(
lora_alpha,
lora_dropout,
init_lora_weights,
use_rslora,
init_method,
input_is_parallel,
gather_output,
@@ -67,6 +69,7 @@ def update_layer(
lora_alpha,
lora_dropout,
init_lora_weights,
use_rslora,
init_method=init.xavier_normal_,
input_is_parallel=True,
gather_output=False,
@@ -109,7 +112,10 @@ def update_layer(
)
self.lora_A[adapter_name] = lora_a
self.lora_B[adapter_name] = lora_b
self.scaling[adapter_name] = lora_alpha / r
if use_rslora:
self.scaling[adapter_name] = lora_alpha / (r**0.5)
else:
self.scaling[adapter_name] = lora_alpha / r
if init_lora_weights:
self.reset_lora_parameters(adapter_name, init_lora_weights)

86 changes: 86 additions & 0 deletions tests/test_initialization.py
@@ -230,3 +230,89 @@ def test_lora_conv2d_false(self):
# with init_lora_weights=False, weight B should *not* be zero. We don't care so much about the actual values
# as long as they are not zero, in order to avoid identity transformation.
self.assertFalse(torch.allclose(weight_B, torch.zeros_like(weight_B)))

def test_lora_scaling_default(self):
# use_rslora defaults to False
torch.manual_seed(0)

model = self.get_model()

# check scaling factor use_rslora=False
config = LoraConfig(target_modules=["linear", "embed", "conv2d"], lora_alpha=3, r=16, use_rslora=False)
model = get_peft_model(model, config)

expected_scaling = config.lora_alpha / config.r

self.assertTrue(model.linear.scaling["default"] == expected_scaling)
self.assertTrue(model.embed.scaling["default"] == expected_scaling)
self.assertTrue(model.conv2d.scaling["default"] == expected_scaling)

def test_rslora_scaling(self):
# use_rslora=True switches to rank-stabilized scaling
torch.manual_seed(0)

model = self.get_model()

# check scaling factor use_rslora=True
config = LoraConfig(target_modules=["linear", "embed", "conv2d"], lora_alpha=3, r=16, use_rslora=True)
model = get_peft_model(model, config)

expected_scaling = config.lora_alpha / (config.r**0.5)

self.assertTrue(model.linear.scaling["default"] == expected_scaling)
self.assertTrue(model.embed.scaling["default"] == expected_scaling)
self.assertTrue(model.conv2d.scaling["default"] == expected_scaling)

def test_lora_default_scaling_pattern(self):
# default scaling combined with rank_pattern and alpha_pattern
torch.manual_seed(0)

model = self.get_model()

# check scaling factor use_rslora=False with rank and alpha pattern
config = LoraConfig(
target_modules=["linear", "embed", "conv2d"],
rank_pattern={"embed": 9, "conv2d": 16},
alpha_pattern={"linear": 11, "conv2d": 13},
lora_alpha=17,
r=25,
use_rslora=False,
)
model = get_peft_model(model, config)

expected_scaling = {
"linear": config.alpha_pattern["linear"] / config.r,
"embed": config.lora_alpha / config.rank_pattern["embed"],
"conv2d": config.alpha_pattern["conv2d"] / config.rank_pattern["conv2d"],
}

self.assertTrue(model.linear.scaling["default"] == expected_scaling["linear"])
self.assertTrue(model.embed.scaling["default"] == expected_scaling["embed"])
self.assertTrue(model.conv2d.scaling["default"] == expected_scaling["conv2d"])

def test_rslora_scaling_pattern(self):
# rank-stabilized scaling combined with rank_pattern and alpha_pattern
torch.manual_seed(0)

model = self.get_model()

# check scaling factor use_rslora=True with rank and alpha pattern
config = LoraConfig(
target_modules=["linear", "embed", "conv2d"],
rank_pattern={"embed": 9, "conv2d": 16},
alpha_pattern={"linear": 11, "conv2d": 13},
lora_alpha=17,
r=25,
use_rslora=True,
)
model = get_peft_model(model, config)

expected_scaling = {
"linear": config.alpha_pattern["linear"] / (config.r**0.5),
"embed": config.lora_alpha / (config.rank_pattern["embed"] ** 0.5),
"conv2d": config.alpha_pattern["conv2d"] / (config.rank_pattern["conv2d"] ** 0.5),
}

self.assertTrue(model.linear.scaling["default"] == expected_scaling["linear"])
self.assertTrue(model.embed.scaling["default"] == expected_scaling["embed"])
self.assertTrue(model.conv2d.scaling["default"] == expected_scaling["conv2d"])
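Outside the test suite, one way to check which scaling factor a model ended up with is to walk its LoRA layers (a sketch; `peft_model` is assumed to be any model returned by `get_peft_model`, and the import path follows the file layout shown in this diff):

```python
from peft.tuners.lora.layer import LoraLayer

# Print the per-adapter scaling of every LoRA layer; with use_rslora=True these
# values should equal lora_alpha / sqrt(r) for the corresponding rank.
for name, module in peft_model.named_modules():
    if isinstance(module, LoraLayer):
        print(name, module.scaling)
```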