49 changes: 49 additions & 0 deletions examples/riemannian_lora/riemannian_lora.py
@@ -0,0 +1,49 @@
import torch
from datasets import load_dataset, load_metric
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

from peft import LoraConfig, TaskType, create_riemannian_optimizer, get_peft_model


model_checkpoint = "microsoft/deberta-v3-base"
dataset = load_dataset("glue", "cola")
metric = load_metric("glue", "cola")
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
task_to_keys = {"cola": ("sentence", None)}
sentence1_key, sentence2_key = task_to_keys["cola"]


def preprocess_function(examples):
if sentence2_key is None:
return tokenizer(examples[sentence1_key], truncation=True)
return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)


encoded_dataset = dataset.map(preprocess_function, batched=True)
num_labels = 2
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
r=4,
lora_alpha=8,
lora_dropout=0.01,
target_modules=["query_proj", "key_proj", "value_proj"],
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
metric_name = "matthews_correlation"

args = TrainingArguments(
"glue_tune", save_strategy="epoch", per_device_train_batch_size=8, num_train_epochs=3, logging_steps=10, seed=0
)

optim_config = {"lr": 1e-5, "eps": 1e-6, "betas": (0.9, 0.999), "weight_decay": 0.01}

optimizer = create_riemannian_optimizer(
model=model, optimizer_cls=torch.optim.AdamW, optimizer_kwargs=optim_config, reg=1e-2
)
trainer = Trainer(
model, args, train_dataset=encoded_dataset["train"], tokenizer=tokenizer, optimizers=[optimizer, None]
)
trainer.train()
2 changes: 2 additions & 0 deletions src/peft/__init__.py
@@ -96,3 +96,5 @@
cast_mixed_precision_params,
)
from .config import PeftConfig, PromptLearningConfig

from .optimizers import create_riemannian_optimizer
20 changes: 20 additions & 0 deletions src/peft/optimizers/__init__.py
@@ -0,0 +1,20 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all

# coding=utf-8
Comment on lines +1 to +5
These lines can be removed. At the bottom of the file, add __all__ = ["create_riemannian_optimizer"]

# Copyright 2023-present the HuggingFace Inc. team.
Suggested change
# Copyright 2023-present the HuggingFace Inc. team.
# Copyright 2024-present the HuggingFace Inc. team.

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .riemannian import create_riemannian_optimizer
294 changes: 294 additions & 0 deletions src/peft/optimizers/riemannian.py
@@ -0,0 +1,294 @@
import math
Let's add the copyright notice to all new files.

from operator import attrgetter
from typing import Callable, Iterable, Tuple

import torch
import torch.nn as nn
from torch.optim import Optimizer
from transformers.utils.versions import require_version

from ..peft_model import PeftModel


def create_riemannian_optimizer(
model: PeftModel,
optimizer_cls: type[Optimizer],
optimizer_kwargs: dict,
Since you probably took this from the LoRA+ PR, let me refer to the comment I put there:

A suggestion: Let's remove optimizer_kwargs and just add **kwargs. IMO, that makes calling this function easier, as we can use create_riemannian_optimizer(..., weight_decay=1e-3) instead of create_riemannian_optimizer(..., optimizer_kwargs={..., "weight_decay": 1e-3}). And since lr is not optional, let's make this a normal arg of create_riemannian_optimizer.
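For illustration, the suggested signature might then look roughly like this (a sketch of the suggestion, not a final API; the function body and RiemannianAdamW are unchanged, lr becomes a required argument, and remaining AdamW keyword arguments such as weight_decay are forwarded via **kwargs):

def create_riemannian_optimizer(
    model: PeftModel,
    optimizer_cls: type[Optimizer],
    lr: float,
    lr_embedding: float = 1e-6,
    reg: float = 1e-4,
    **kwargs,
) -> Optimizer:
    ...
    # build optimizer_grouped_parameters as in the current implementation, then:
    optimizer = RiemannianAdamW(optimizer_grouped_parameters, lr=lr, **kwargs)
    return optimizer

A call would then read create_riemannian_optimizer(model, torch.optim.AdamW, lr=1e-5, weight_decay=1e-3).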

lr_embedding: float = 1e-6,
reg: float = 1e-4,
) -> Optimizer:
"""
Creates a Riemannian optimizer. Used only for LoRA. Implementation based on:
https://github.com/pilancilab/Riemannian_Preconditioned_LoRA. Paper reference: https://arxiv.org/pdf/2402.02347.
Let's add that the optimizer is a modified version of AdamW.


Args:
model (`torch.nn.Module`): The model to be optimized.
optimizer_cls (`torch.optim.Optimizer`): The optimizer class to be used.
optimizer_kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.
- lr_embedding (`float`): The learning rate to be used for the embedding layer. Defaults to lr_embedding
Let's use the same indentation and syntax as the other parameters. Also, let's add docs for reg.

done

Hmm, indentation is still wrong. It should be:

        optimizer_kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.
        lr_embedding (`float`): The learning rate to be used for the embedding layer. Defaults to lr_embedding
        reg (`float`): Regularization parameter for Riemannian preconditioner. Included for lora parameters only

- reg (`float`): Regularization parameter for Riemannian preconditioner. Included for lora parameters only
and is needed for invertibility guarantees
"""

# TEST VERSION FOR ADAMW
if not issubclass(optimizer_cls, torch.optim.AdamW):
raise TypeError("TEST version only supports AdamW optimizer")
Comment on lines +34 to +35
Since the optimizer_cls argument is not actually used for anything except raising an error, how about removing it completely?
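If it is removed (together with the **kwargs suggestion above), the call in the example script would shrink to something like this sketch:

optimizer = create_riemannian_optimizer(
    model=model, lr=1e-5, reg=1e-2, eps=1e-6, betas=(0.9, 0.999), weight_decay=0.01
)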

from ..tuners.lora.layer import Embedding

param_groups = {"lora_params": {}, "other_params": {}, "embedding": {}}

for name, param in model.named_parameters():
if not param.requires_grad:
continue
module = attrgetter(name)(model)
if isinstance(module, Embedding):
param_groups["embedding"][name] = param
elif "lora" in name:
param_groups["lora_params"][name] = param
else:
param_groups["other_params"][name] = param

lr = optimizer_kwargs["lr"]
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)

optimizer_grouped_parameters = [
{
"params": list(param_groups["lora_params"].values()),
"weight_decay": weight_decay,
"lr": lr,
"is_lora": True,
"reg": reg,
},
{
"params": list(param_groups["embedding"].values()),
"weight_decay": weight_decay,
"lr": lr_embedding,
"is_lora": False,
},
{
"params": list(param_groups["other_params"].values()),
"weight_decay": weight_decay,
"lr": lr,
"is_lora": False,
},
]

optimizer = RiemannianAdamW(optimizer_grouped_parameters, **optimizer_kwargs)
return optimizer


class RiemannianAdamW(Optimizer):
"""
This implements a variant of AdamW which adds a Riemannian preconditioner and is specifically designed for
LoRA models. The basic AdamW workflow follows the implementation of transformers.optimization.AdamW, with the
exception that Riemannian preconditioners are applied to the LoRA gradients. More specifically, for each LoRA
parameter pair (lora_up, lora_down), we use a preconditioned gradient of the form
grad(lora_up) = grad(lora_up) @ inverse(lora_down.T @ lora_down + reg * I)
and vice versa.

Let's mention that this is modified from the transformers AdamW implementation, not the torch one. It would also be nice to mark the lines in the function body that were changed compared to the original with a comment, if possible (probably everything inside of if group["is_lora"]:).

Parameters:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 0.001):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
Adam's betas parameters (b1, b2).
eps (`float`, *optional*, defaults to 1e-06):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0.0):
Decoupled weight decay to apply.
correct_bias (`bool`, *optional*, defaults to `True`):
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
"""

def __init__(
Instead of implementing a custom __init__, I wonder if this class could not inherit from torch.optim.AdamW and re-use the __init__ from there.

Also, note that torch.optim.AdamW has a custom __setstate__:

https://github.com/pytorch/pytorch/blob/97ff6cfd9c86c5c09d7ce775ab64ec5c99230f5d/torch/optim/adamw.py#L74-L89

I'm not sure why that's needed, but in case we decide not to inherit, we should probably copy that function over too.

Not sure if _init_group is also needed:

https://github.com/pytorch/pytorch/blob/97ff6cfd9c86c5c09d7ce775ab64ec5c99230f5d/torch/optim/adamw.py#L91

Thanks for the detailed checking! I'm actually replicating transformers' AdamW implementation here: https://github.com/huggingface/transformers/blob/v4.42.0/src/transformers/optimization.py#L558, and slightly modifying it to incorporate our desired gradient change. I'm not sure whether we want to follow the torch.optim implementation or whether it's fine to follow the transformers implementation?

Oh, I didn't know this. In that case, it's fine.

self,
params: Iterable[nn.parameter.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6,
weight_decay: float = 0.0,
correct_bias: bool = True,
):
require_version("torch>=1.5.0") # add_ with alpha
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
super().__init__(params, defaults)

@torch.no_grad()
def step(self, closure: Callable = None):
"""
Performs a single optimization step.

Arguments:
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
if group["is_lora"]:
for p1, p2 in list(zip(group["params"], group["params"][1:]))[::2]:
Comment on lines +139 to +141
As discussed in the other comment, this is indeed error prone. For this, the logic here:

https://github.com/huggingface/peft/pull/1807/files#diff-4730f831ea49f19ef126ffa6d712865c57a477585e4098b74acb6026d3056d5aR46-R47

should be improved. I think it's better if we create two separate groups for lora_A and lora_B. After the loop there, let's also check that both groups have the same length and that the length is > 0. In the optimizer_grouped_parameters, we can set "is_lora_A": True and "is_lora_B": True accordingly.

After making this change, the line here could be simplified to:

# this works because there is exactly one lora_A and one lora_B group
lora_A_params = next(group for group in self.param_groups if group["is_lora_A"])
lora_B_params = next(group for group in self.param_groups if group["is_lora_B"])
for p1, p2 in zip(lora_A_params, lora_B_params):
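One way to build those two groups in create_riemannian_optimizer, sketched under the assumption that PEFT's usual "lora_A"/"lora_B" parameter naming is used (the is_lora_A/is_lora_B keys are the ones suggested above):

# relies on named_parameters() yielding the lora_A and lora_B factor of each layer in matching order
lora_A_params, lora_B_params = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if "lora_A" in name:
        lora_A_params.append(param)
    elif "lora_B" in name:
        lora_B_params.append(param)

if len(lora_A_params) != len(lora_B_params) or len(lora_A_params) == 0:
    raise ValueError("Expected a matching, non-empty set of lora_A and lora_B parameters.")

optimizer_grouped_parameters = [
    {"params": lora_A_params, "lr": lr, "weight_decay": weight_decay, "reg": reg, "is_lora_A": True, "is_lora_B": False},
    {"params": lora_B_params, "lr": lr, "weight_decay": weight_decay, "reg": reg, "is_lora_A": False, "is_lora_B": True},
    # embedding and other trainable parameters would be appended as before, with both flags set to False
]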

grad = p1.grad
if grad.is_sparse:
raise RuntimeError(
"Adam does not support sparse gradients, please consider SparseAdam instead"
)

state = self.state[p1]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p1)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p1)

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
scaler = p2.data
scaler_temp = p1.data
try:
reg_I = group["reg"] * torch.eye(min(p2.shape)).to(p2.device)
scaler = (
torch.inverse(scaler @ scaler.T + reg_I)
if p2.shape[0] < p2.shape[1]
else torch.inverse(scaler.T @ scaler + reg_I)
)
assert scaler.shape[0] == min(p2.data.shape), "wrong dimension"
Let's not use assert, instead raise a proper ValueError with a helpful message.
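For example, a sketch of such a check:

if scaler.shape[0] != min(p2.shape):
    raise ValueError(
        f"Riemannian preconditioner has shape {tuple(scaler.shape)}, "
        f"but a square matrix of size {min(p2.shape)} was expected."
    )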

except RuntimeError:
Could you explain why this is needed? Could we instead check the condition and do something like if valid_condition: ... else: scaler = None. Let's completely avoid printing messages.
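One possible restructuring, sketched here: since reg * I (with reg > 0) is added to a positive semi-definite Gram matrix, the result is always invertible, so the try/except and the print should not be needed at all:

reg_I = group["reg"] * torch.eye(min(p2.shape), device=p2.device)
# Gram matrix of the partner LoRA factor, taken over its smaller dimension
gram = scaler @ scaler.T if p2.shape[0] < p2.shape[1] else scaler.T @ scaler
scaler = torch.inverse(gram + reg_I)
# apply the Riemannian preconditioner to the gradient
grad = grad @ scaler if grad.shape[1] == scaler.shape[0] else scaler @ grad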

print("invalid condition")
scaler = None

# apply riemannian conditioner
if scaler is not None:
grad = grad @ scaler if grad.shape[1] == scaler.shape[0] else scaler @ grad
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])

step_size = group["lr"]
if group["correct_bias"]: # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** state["step"]
bias_correction2 = 1.0 - beta2 ** state["step"]
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

p1.addcdiv_(exp_avg, denom, value=-step_size)
if group["weight_decay"] > 0.0:
p1.add_(p1, alpha=(-group["lr"] * group["weight_decay"]))

grad = p2.grad
if grad.is_sparse:
raise RuntimeError(
"Adam does not support sparse gradients, please consider SparseAdam instead"
)

state = self.state[p2]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p2)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p2)

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
scaler = scaler_temp
try:
reg_I = group["lr"] * torch.eye(min(p1.shape)).to(p1.device)
scaler = (
torch.inverse(scaler @ scaler.T + reg_I)
if p1.shape[0] < p1.shape[1]
else torch.inverse(scaler.T @ scaler + reg_I)
)
assert scaler.shape[0] == min(p1.data.shape), "wrong dimension"
Let's not use assert, instead raise a proper ValueError with a helpful message.

except RuntimeError:
print("invalid condition")
Could you explain why this is needed? Could we instead check the condition and do something like if valid_condition: ... else: scaler = None. Let's completely avoid printing messages.

scaler = None

# apply riemannian conditioner
if scaler is not None:
grad = grad @ scaler if grad.shape[1] == scaler.shape[0] else scaler @ grad
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])

step_size = group["lr"]
if group["correct_bias"]: # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** state["step"]
bias_correction2 = 1.0 - beta2 ** state["step"]
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

p2.addcdiv_(exp_avg, denom, value=-step_size)

if group["weight_decay"] > 0.0:
p2.add_(p2, alpha=(-group["lr"] * group["weight_decay"]))

else:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"Adam does not support sparse gradients, please consider SparseAdam instead"
)

state = self.state[p]

# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p)

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]

state["step"] += 1

# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])

step_size = group["lr"]
if group["correct_bias"]: # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** state["step"]
bias_correction2 = 1.0 - beta2 ** state["step"]
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

p.addcdiv_(exp_avg, denom, value=-step_size)

# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group["weight_decay"] > 0.0:
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

return loss