Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for an optional initialization strategy OLoRA #1828

Merged
merged 11 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/developer_guides/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ lora_config = LoraConfig(init_lora_weights="pissa_niter_[number of iters]", ...)
```
For detailed instruction on using PiSSA, please follow [these instructions](https://github.com/fxmeng/peft/tree/main/examples/pissa_finetuning).

### OLoRA
[OLoRA](https://arxiv.org/abs/2406.01775) initializes the LoRA adapter using QR decomposition, which significantly improves stability, accelerates convergence speed, and achieves superior performance.
You just need to pass a single additional option to use OLoRA:
```python
from peft import LoraConfig
config = LoraConfig(init_lora_weights="olora", ...)
```

### LoftQ

#### Standard approach
Expand Down
81 changes: 81 additions & 0 deletions examples/olora_finetuning/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# OLoRA: Orthonormal Low Rank Adaptation of Large Language Models

## Introduction
[OLoRA](https://arxiv.org/abs/2406.01775) is a novel approach that leverages orthonormal low rank adaptation through QR decomposition. Unlike the default LoRA implementation, OLoRA decomposes original weights into their $\mathbf{Q}$ and $\mathbf{R}$ parts, and then uses the first `rank` rows of $\mathbf{R}$ and the first `rank` columns of $\mathbf{Q}$ to initialize $\mathbf{A}$ and $\mathbf{B}$, respectively. This results in significantly faster convergence, more stable training, and superior performance.

## Quick start
```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train[:1%]")
lora_config = LoraConfig(
init_lora_weights="olora"
)
peft_model = get_peft_model(model, lora_config)
trainer = SFTTrainer(
tokenizer-decode marked this conversation as resolved.
Show resolved Hide resolved
model=peft_model,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
tokenizer=tokenizer,
)
trainer.train()
peft_model.save_pretrained("olora-opt-350m")
```

There is no additional change needed to your standard LoRA procedure, except for specifying `init_lora_weights = "olora"` option in your lora configuration.

Additionally you can refer to olora finetuning script.
tokenizer-decode marked this conversation as resolved.
Show resolved Hide resolved
Run the script simply by running:
```bash
python3 examples/olora_finetuning/olora_finetuning.py --base_model facebook/opt-350m
```
OLoRA also supports quantization. To use 4-bit quantization try:
```bash
python3 examples/olora_finetuning/olora_finetuning.py --base_model facebook/opt-350m --quantize
```


## Use the model
You can load and use the model as any other 🤗 PEFT model
```python
from peft import PeftModel
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
olora_model = PeftModel.from_pretrained(model, "olora-opt-350m")
print(olora_model)
```

## OLoRA and LoRA
Unlike LoRA, OLoRA mutates the original weights. If you want to use multiple adapters simultaneously you can use `path_initial_model_for_weight_conversion` option. First save your model that was initialized with OLoRA **before** performing any training:
```python
init_path = <path-to-untrained-olora-model>
olora_model.save_pretrained(init_path)
tokenizer-decode marked this conversation as resolved.
Show resolved Hide resolved
```
# After training
tokenizer-decode marked this conversation as resolved.
Show resolved Hide resolved
Then you can specify the path of the initialized adapter in the `save_pretrained` method:
```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
olora_model = get_peft_model(model, lora_config)
tokenizer-decode marked this conversation as resolved.
Show resolved Hide resolved
olora_model.save_pretrained(save_dir, path_initial_model_for_weight_conversion=init_path)
```
Thus converting OLoRA to LoRA to use with multiple adapters.

## Citation
```
@misc{büyükakyüz2024olora,
title={OLoRA: Orthonormal Low-Rank Adaptation of Large Language Models},
author={Kerim Büyükakyüz},
year={2024},
eprint={2406.01775},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
tokenizer-decode marked this conversation as resolved.
Show resolved Hide resolved
148 changes: 148 additions & 0 deletions examples/olora_finetuning/olora_finetuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# Copyright 2023-present the HuggingFace Inc. team.
tokenizer-decode marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List
tokenizer-decode marked this conversation as resolved.
Show resolved Hide resolved

import fire
import torch
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from peft import (
LoraConfig,
get_peft_model,
)


def train(
base_model: str = "path/to/model",
data_path: str = "yahma/alpaca-cleaned",
output_dir: str = "olora",
batch_size: int = 16,
micro_batch_size: int = 4,
num_epochs: int = 1,
learning_rate: float = 3e-4,
cutoff_len: int = 256,
val_set_size: int = 16,
quantize: bool = False,
eval_step: int = 100,
save_step: int = 100,
device_map: str = "auto",
lora_r: int = 32,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_target_modules: List[str] = None,
init_lora_weights="olora",
):
gradient_accumulation_steps = batch_size // micro_batch_size
model = AutoModelForCausalLM.from_pretrained(
base_model,
device_map=device_map,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
if quantize
else None,
torch_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

def tokenize(prompt, add_eos_token=True):
result = tokenizer(
prompt,
truncation=True,
max_length=cutoff_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != tokenizer.eos_token_id
and len(result["input_ids"]) < cutoff_len
and add_eos_token
):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)

result["labels"] = result["input_ids"].copy()

return result

def generate_and_tokenize_prompt(example):
full_prompt = generate_prompt(example)
tokenized_full_prompt = tokenize(full_prompt)
return tokenized_full_prompt

config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
init_lora_weights=init_lora_weights,
)
model = get_peft_model(model, config)

data = load_dataset(data_path)

train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)

trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_steps=100,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
fp16=True,
logging_steps=100,
optim="adamw_torch",
evaluation_strategy="steps",
save_strategy="steps",
eval_steps=eval_step,
save_steps=save_step,
output_dir=output_dir,
save_total_limit=3,
load_best_model_at_end=True,
),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
trainer.train()
model.save_pretrained(output_dir)


def generate_prompt(example):
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{example["instruction"]}
### Response:
{example["output"]}"""


if __name__ == "__main__":
torch.manual_seed(42)
fire.Fire(train)
63 changes: 40 additions & 23 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def save_pretrained(
save_embedding_layers: Union[str, bool] = "auto",
is_main_process: bool = True,
convert_pissa_to_lora: Optional[str] = None,
path_initial_model_for_weight_conversion: Optional[str] = None,
**kwargs: Any,
) -> None:
r"""
Expand All @@ -215,13 +216,15 @@ def save_pretrained(
is_main_process (`bool`, *optional*):
Whether the process calling this is the main process or not. Will default to `True`. Will not save the
checkpoint if not on the main process, which is important for multi device setups (e.g. DDP).
convert_pissa_to_lora (`str`):
The path to the initialized PiSSA adapter, which is obtained after initializing the model with PiSSA
and before performing any training. When `convert_pissa_to_lora` is not None, the difference in PISSA
before and after fine-tuning is calculated. This difference can be represented as the parameters of a
of a standard LoRA adapter. Using this converted adapter does not require changes to the base model,
thus conveniently allowing the use of multiple PISSA and LoRA adapters, and the activation or
deactivation of any adapters.
convert_pissa_to_lora (`str, *optional*`):
Deprecated. Use `path_initial_model_for_weight_conversion` instead.
path_initial_model_for_weight_conversion (`str, *optional*`):
The path to the initialized adapter, which is obtained after initializing the model with PiSSA or OLoRA
and before performing any training. When `path_initial_model_for_weight_conversion` is not None, the
difference in adapter before and after fine-tuning is calculated. This difference can be represented
as the parameters of a standard LoRA adapter. Using this converted adapter does not require changes
to the base model, thus conveniently allowing the use of multiple PiSSA or OLoRA adapters with
LoRA adapters, and the activation or deactivation of any adapters.
kwargs (additional keyword arguments, *optional*):
Additional keyword arguments passed along to the `push_to_hub` method.
"""
Expand All @@ -239,20 +242,34 @@ def save_pretrained(
f"You passed an invalid `selected_adapters` arguments, current supported adapter names are"
f" {list(self.peft_config.keys())} - got {selected_adapters}."
)
# TODO: remove deprecated parameter in PEFT v0.14.0
if convert_pissa_to_lora is not None:
warnings.warn(
tokenizer-decode marked this conversation as resolved.
Show resolved Hide resolved
"`convert_pissa_to_lora` is deprecated and will be removed in a future version. "
"Use `path_initial_model_for_weight_conversion` instead."
)
path_initial_model_for_weight_conversion = convert_pissa_to_lora

def save_pissa_as_lora(peft_config, convert_pissa_to_lora, output_state_dict, kwargs):
if not str(peft_config.init_lora_weights).startswith("pissa"):
warnings.warn("`convert_pissa_to_lora` only works for converting a PiSSA adapter to a LoRA adapter")
initial_adapter = os.path.basename(convert_pissa_to_lora)
def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs):
tokenizer-decode marked this conversation as resolved.
Show resolved Hide resolved
if not any(str(peft_config.init_lora_weights).lower().startswith(prefix) for prefix in ["pissa", "olora"]):
warnings.warn(
"`path_initial_model_for_weight_conversion` only works for converting a PiSSA or OLoRA adapter to a LoRA adapter"
)
initial_adapter = os.path.basename(path_initial_model_for_weight_conversion)
self.load_adapter(
os.path.dirname(convert_pissa_to_lora), subfolder=initial_adapter, adapter_name=initial_adapter
os.path.dirname(path_initial_model_for_weight_conversion),
subfolder=initial_adapter,
adapter_name=initial_adapter,
)
if str(self.peft_config[initial_adapter].init_lora_weights).startswith("pissa"):
if any(
str(self.peft_config[initial_adapter].init_lora_weights).lower().startswith(prefix)
for prefix in ["pissa", "olora"]
):
raise ValueError(
"The `init_lora_weights` parameter of the initial PiSSA adapter should be set to `True`. "
"Otherwise, `self.load_adapter` will subtract the principal singular value and vector again based on the residual model."
"The `init_lora_weights` parameter of the initial adapter should be set to `True`. "
"Otherwise, `self.load_adapter` will subtract the decomposed values again based on the residual model."
)
output_state_dict = self.base_model.subtract_pissa_init(output_state_dict, initial_adapter, kwargs)
output_state_dict = self.base_model.subtract_mutated_init(output_state_dict, initial_adapter, kwargs)
self.delete_adapter(adapter_name)
return output_state_dict

Expand Down Expand Up @@ -294,19 +311,19 @@ def save_pissa_as_lora(peft_config, convert_pissa_to_lora, output_state_dict, kw
# not supported in safetensors.
for shared_tensor_name in names[1:]:
output_state_dict[shared_tensor_name] = output_state_dict[shared_tensor_name].clone()
if convert_pissa_to_lora is not None:
output_state_dict = save_pissa_as_lora(
peft_config, convert_pissa_to_lora, output_state_dict, kwargs
if path_initial_model_for_weight_conversion is not None:
output_state_dict = save_mutated_as_lora(
peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs
)
safe_save_file(
output_state_dict,
os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME),
metadata={"format": "pt"},
)
elif is_main_process:
if convert_pissa_to_lora is not None:
output_state_dict = save_pissa_as_lora(
peft_config, convert_pissa_to_lora, output_state_dict, kwargs
if path_initial_model_for_weight_conversion is not None:
output_state_dict = save_mutated_as_lora(
peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs
)
torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME))

Expand Down Expand Up @@ -335,7 +352,7 @@ def save_pissa_as_lora(peft_config, convert_pissa_to_lora, output_state_dict, kw
auto_mapping_dict = None

if is_main_process:
if convert_pissa_to_lora is not None:
if path_initial_model_for_weight_conversion is not None:
peft_config.init_lora_weights = True
peft_config.r *= 2
peft_config.lora_alpha *= 2
Expand Down
Loading
Loading