-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FEAT Add OLoRA initialization strategy to LoRA (#1828)
- Loading branch information
1 parent
8843a76
commit 2f5360a
Showing
9 changed files
with
576 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# OLoRA: Orthonormal Low Rank Adaptation of Large Language Models | ||
|
||
## Introduction | ||
[OLoRA](https://arxiv.org/abs/2406.01775) is a novel approach that leverages orthonormal low rank adaptation through QR decomposition. Unlike the default LoRA implementation, OLoRA decomposes original weights into their $\mathbf{Q}$ and $\mathbf{R}$ parts, and then uses the first `rank` rows of $\mathbf{R}$ and the first `rank` columns of $\mathbf{Q}$ to initialize $\mathbf{A}$ and $\mathbf{B}$, respectively. This results in significantly faster convergence, more stable training, and superior performance. | ||
|
||
## Quick start | ||
```python | ||
import torch | ||
from peft import LoraConfig, get_peft_model | ||
from transformers import AutoTokenizer, AutoModelForCausalLM | ||
from trl import SFTTrainer | ||
from datasets import load_dataset | ||
|
||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16, device_map="auto") | ||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") | ||
dataset = load_dataset("imdb", split="train[:1%]") | ||
lora_config = LoraConfig( | ||
init_lora_weights="olora" | ||
) | ||
peft_model = get_peft_model(model, lora_config) | ||
trainer = SFTTrainer( | ||
model=peft_model, | ||
train_dataset=dataset, | ||
dataset_text_field="text", | ||
max_seq_length=512, | ||
tokenizer=tokenizer, | ||
) | ||
trainer.train() | ||
peft_model.save_pretrained("olora-opt-350m") | ||
``` | ||
|
||
There is no additional change needed to your standard LoRA procedure, except for specifying `init_lora_weights = "olora"` option in your lora configuration. | ||
|
||
Additionally you can refer to olora finetuning script. | ||
Run the script simply by running: | ||
```bash | ||
python3 examples/olora_finetuning/olora_finetuning.py --base_model facebook/opt-350m | ||
``` | ||
OLoRA also supports quantization. To use 4-bit quantization try: | ||
```bash | ||
python3 examples/olora_finetuning/olora_finetuning.py --base_model facebook/opt-350m --quantize | ||
``` | ||
|
||
|
||
## Use the model | ||
You can load and use the model as any other 🤗 PEFT model | ||
```python | ||
from peft import PeftModel | ||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") | ||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") | ||
olora_model = PeftModel.from_pretrained(model, "olora-opt-350m") | ||
``` | ||
|
||
## OLoRA and LoRA | ||
OLoRA differs from LoRA in that it mutates the original weights. To utilize multiple adapters simultaneously, you can leverage the `path_initial_model_for_weight_conversion` option. Below is a simple template illustrating how to convert OLoRA to conventional LoRA: | ||
```python | ||
base_model = AutoModel.from_pretrained("facebook/opt-350m") | ||
olora_config = LoraConfig( | ||
... | ||
init_lora_weights = "olora" # Initialize the model with OLoRA | ||
) | ||
olora_model = get_peft_model(base_model, olora_config) | ||
init_path = <path-to-untrained-olora-model> | ||
olora_model.save_pretrained(init_path) # Save the model *before* performing any training | ||
|
||
# Train the model | ||
train(olora_model) # Your training loop | ||
|
||
#Save the model after training | ||
olora_model.save_pretrained(output_dir, path_initial_model_for_weight_conversion=init_path) | ||
``` | ||
After completing training, you can save and convert your OLoRA model to a conventional LoRA model by setting `path_initial_model_for_weight_conversion` to `init_path`, that is the path of your untrained OLoRA model. This conversion enables you to use multiple adapters with your LoRA model. | ||
|
||
## Citation | ||
``` | ||
@misc{büyükakyüz2024olora, | ||
title={OLoRA: Orthonormal Low-Rank Adaptation of Large Language Models}, | ||
author={Kerim Büyükakyüz}, | ||
year={2024}, | ||
eprint={2406.01775}, | ||
archivePrefix={arXiv}, | ||
primaryClass={cs.CL} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
# Copyright 2024-present the HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from typing import List | ||
|
||
import torch | ||
import transformers | ||
from datasets import load_dataset | ||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | ||
|
||
from peft import ( | ||
LoraConfig, | ||
get_peft_model, | ||
) | ||
|
||
|
||
def train( | ||
base_model: str = "path/to/model", | ||
data_path: str = "yahma/alpaca-cleaned", | ||
output_dir: str = "olora", | ||
batch_size: int = 16, | ||
num_epochs: int = 1, | ||
learning_rate: float = 3e-4, | ||
cutoff_len: int = 256, | ||
val_set_size: int = 16, | ||
quantize: bool = False, | ||
eval_step: int = 100, | ||
save_step: int = 100, | ||
device_map: str = "auto", | ||
lora_r: int = 32, | ||
lora_alpha: int = 16, | ||
lora_dropout: float = 0.05, | ||
lora_target_modules: List[str] = None, | ||
init_lora_weights="olora", | ||
): | ||
model = AutoModelForCausalLM.from_pretrained( | ||
base_model, | ||
device_map=device_map, | ||
quantization_config=BitsAndBytesConfig( | ||
load_in_4bit=True, | ||
bnb_4bit_compute_dtype=torch.bfloat16, | ||
bnb_4bit_use_double_quant=True, | ||
bnb_4bit_quant_type="nf4", | ||
) | ||
if quantize | ||
else None, | ||
torch_dtype=torch.float16, | ||
) | ||
|
||
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) | ||
|
||
def tokenize(prompt, add_eos_token=True): | ||
result = tokenizer( | ||
prompt, | ||
truncation=True, | ||
max_length=cutoff_len, | ||
padding=False, | ||
return_tensors=None, | ||
) | ||
if ( | ||
result["input_ids"][-1] != tokenizer.eos_token_id | ||
and len(result["input_ids"]) < cutoff_len | ||
and add_eos_token | ||
): | ||
result["input_ids"].append(tokenizer.eos_token_id) | ||
result["attention_mask"].append(1) | ||
|
||
result["labels"] = result["input_ids"].copy() | ||
|
||
return result | ||
|
||
def generate_and_tokenize_prompt(example): | ||
full_prompt = generate_prompt(example) | ||
tokenized_full_prompt = tokenize(full_prompt) | ||
return tokenized_full_prompt | ||
|
||
config = LoraConfig( | ||
r=lora_r, | ||
lora_alpha=lora_alpha, | ||
target_modules=lora_target_modules, | ||
lora_dropout=lora_dropout, | ||
bias="none", | ||
task_type="CAUSAL_LM", | ||
init_lora_weights=init_lora_weights, | ||
) | ||
model = get_peft_model(model, config) | ||
|
||
data = load_dataset(data_path) | ||
|
||
train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42) | ||
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt) | ||
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt) | ||
|
||
trainer = transformers.Trainer( | ||
model=model, | ||
train_dataset=train_data, | ||
eval_dataset=val_data, | ||
args=transformers.TrainingArguments( | ||
per_device_train_batch_size=batch_size, | ||
warmup_steps=100, | ||
num_train_epochs=num_epochs, | ||
learning_rate=learning_rate, | ||
fp16=True, | ||
logging_steps=100, | ||
optim="adamw_torch", | ||
evaluation_strategy="steps", | ||
save_strategy="steps", | ||
eval_steps=eval_step, | ||
save_steps=save_step, | ||
output_dir=output_dir, | ||
save_total_limit=3, | ||
load_best_model_at_end=True, | ||
), | ||
data_collator=transformers.DataCollatorForSeq2Seq( | ||
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True | ||
), | ||
) | ||
trainer.train() | ||
model.save_pretrained(output_dir) | ||
|
||
|
||
def generate_prompt(example): | ||
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. | ||
### Instruction: | ||
{example["instruction"]} | ||
### Response: | ||
{example["output"]}""" | ||
|
||
|
||
if __name__ == "__main__": | ||
import argparse | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--base_model", type=str, default="path/to/model") | ||
parser.add_argument("--data_path", type=str, default="yahma/alpaca-cleaned") | ||
parser.add_argument("--output_dir", type=str, default="olora") | ||
parser.add_argument("--batch_size", type=int, default=16) | ||
parser.add_argument("--num_epochs", type=int, default=1) | ||
parser.add_argument("--learning_rate", type=float, default=3e-4) | ||
parser.add_argument("--cutoff_len", type=int, default=256) | ||
parser.add_argument("--val_set_size", type=int, default=16) | ||
parser.add_argument("--quantize", action="store_true") | ||
parser.add_argument("--eval_step", type=int, default=100) | ||
parser.add_argument("--save_step", type=int, default=100) | ||
parser.add_argument("--device_map", type=str, default="auto") | ||
parser.add_argument("--lora_r", type=int, default=32) | ||
parser.add_argument("--lora_alpha", type=int, default=16) | ||
parser.add_argument("--lora_dropout", type=float, default=0.05) | ||
parser.add_argument("--lora_target_modules", type=str, default=None) | ||
parser.add_argument("--init_lora_weights", type=str, default="olora") | ||
|
||
args = parser.parse_args() | ||
|
||
train( | ||
base_model=args.base_model, | ||
data_path=args.data_path, | ||
output_dir=args.output_dir, | ||
batch_size=args.batch_size, | ||
num_epochs=args.num_epochs, | ||
learning_rate=args.learning_rate, | ||
cutoff_len=args.cutoff_len, | ||
val_set_size=args.val_set_size, | ||
quantize=args.quantize, | ||
eval_step=args.eval_step, | ||
save_step=args.save_step, | ||
device_map=args.device_map, | ||
lora_r=args.lora_r, | ||
lora_alpha=args.lora_alpha, | ||
lora_dropout=args.lora_dropout, | ||
lora_target_modules=args.lora_target_modules, | ||
init_lora_weights=args.init_lora_weights, | ||
) |
Oops, something went wrong.