Integrating Riemannian Preconditioner #1807
Changes from all commits: 0852cfa, 550553a, dd32e8b, 3dc3f9e, dd7e85c, 735e386, 445412c, 4526263, 3017a5d
New file, @@ -0,0 +1,49 @@:

```python
import torch
from datasets import load_dataset, load_metric
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

from peft import LoraConfig, TaskType, create_riemannian_optimizer, get_peft_model


model_checkpoint = "microsoft/deberta-v3-base"
dataset = load_dataset("glue", "cola")
metric = load_metric("glue", "cola")
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
task_to_keys = {"cola": ("sentence", None)}
sentence1_key, sentence2_key = task_to_keys["cola"]


def preprocess_function(examples):
    if sentence2_key is None:
        return tokenizer(examples[sentence1_key], truncation=True)
    return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)


encoded_dataset = dataset.map(preprocess_function, batched=True)
num_labels = 2
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=4,
    lora_alpha=8,
    lora_dropout=0.01,
    target_modules=["query_proj", "key_proj", "value_proj"],
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
metric_name = "matthews_correlation"

args = TrainingArguments(
    "glue_tune", save_strategy="epoch", per_device_train_batch_size=8, num_train_epochs=3, logging_steps=10, seed=0
)

optim_config = {"lr": 1e-5, "eps": 1e-6, "betas": (0.9, 0.999), "weight_decay": 0.01}

optimizer = create_riemannian_optimizer(
    model=model, optimizer_cls=torch.optim.AdamW, optimizer_kwargs=optim_config, reg=1e-2
)
trainer = Trainer(
    model, args, train_dataset=encoded_dataset["train"], tokenizer=tokenizer, optimizers=[optimizer, None]
)
trainer.train()
```
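The example loads the CoLA metric but never passes it to the `Trainer`, so no metric is reported during training. A minimal sketch of how it could be wired in, assuming the standard `datasets` metric API; the `compute_metrics` function and the use of a validation split are illustrative additions, not part of the diff:

```python
import numpy as np


def compute_metrics(eval_pred):
    # The Trainer passes a (logits, labels) pair for each evaluation run.
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # `metric` is the load_metric("glue", "cola") object created above.
    return metric.compute(predictions=predictions, references=labels)


# Hypothetical usage: add an eval split and the hook to the Trainer call, e.g.
# Trainer(model, args, train_dataset=encoded_dataset["train"],
#         eval_dataset=encoded_dataset["validation"],
#         tokenizer=tokenizer, compute_metrics=compute_metrics,
#         optimizers=[optimizer, None])
```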
New file, @@ -0,0 +1,20 @@:

```python
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all

# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
```
> **Member:** Suggested change …
```python
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .riemannian import create_riemannian_optimizer
```
New file, @@ -0,0 +1,294 @@:

```python
import math
```

> **Member:** Let's add the copyright notice to all new files.
```python
from operator import attrgetter
from typing import Callable, Iterable, Tuple

import torch
import torch.nn as nn
from torch.optim import Optimizer
from transformers.utils.versions import require_version

from ..peft_model import PeftModel


def create_riemannian_optimizer(
    model: PeftModel,
    optimizer_cls: type[Optimizer],
    optimizer_kwargs: dict,
```
> **Member:** Since you probably took this from the LoRA+ PR, let me refer to the comment I put there: A suggestion: Let's remove …
```python
    lr_embedding: float = 1e-6,
    reg: float = 1e-4,
) -> Optimizer:
    """
    Creates a Riemannian optimizer. Used only for LoRA. Implementation based on:
    https://github.com/pilancilab/Riemannian_Preconditioned_LoRA. Paper reference: https://arxiv.org/pdf/2402.02347.
```
> **Member:** Let's add that the optimizer is a modified version of …
```python

    Args:
        model (`torch.nn.Module`): The model to be optimized.
        optimizer_cls (`torch.optim.Optimizer`): The optimizer class to be used.
        optimizer_kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.
        - lr_embedding (`float`): The learning rate to be used for the embedding layer. Defaults to lr_embedding
```
> **Member:** Let's use the same indentation and syntax as the other parameters. Also, let's add docs for …
>
> **Author:** done
>
> **Member:** Hmm, indentation is still wrong. It should be: …
```python
        - reg (`float`): Regularization parameter for the Riemannian preconditioner. Included for lora parameters only
          and is needed for invertibility guarantees
    """

    # TEST VERSION FOR ADAMW
    if not issubclass(optimizer_cls, torch.optim.AdamW):
        raise TypeError("TEST version only supports AdamW optimizer")
```
> **Member (on lines +34 to +35):** Since the …
```python
    from ..tuners.lora.layer import Embedding

    param_groups = {"lora_params": {}, "other_params": {}, "embedding": {}}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        module = attrgetter(name)(model)
        if isinstance(module, Embedding):
            param_groups["embedding"][name] = param
        elif "lora" in name:
            param_groups["lora_params"][name] = param
        else:
            param_groups["other_params"][name] = param

    lr = optimizer_kwargs["lr"]
    weight_decay = optimizer_kwargs.get("weight_decay", 0.0)

    optimizer_grouped_parameters = [
        {
            "params": list(param_groups["lora_params"].values()),
            "weight_decay": weight_decay,
            "lr": lr,
            "is_lora": True,
            "reg": reg,
        },
        {
            "params": list(param_groups["embedding"].values()),
            "weight_decay": weight_decay,
            "lr": lr_embedding,
            "is_lora": False,
        },
        {
            "params": list(param_groups["other_params"].values()),
            "weight_decay": weight_decay,
            "lr": lr,
            "is_lora": False,
        },
    ]

    optimizer = RiemannianAdamW(optimizer_grouped_parameters, **optimizer_kwargs)
    return optimizer


class RiemannianAdamW(Optimizer):
    """
    This implements a variant of AdamW combined with a Riemannian preconditioner, specifically designed for LoRA
    models. The basic AdamW workflow follows the implementation of transformers.optimization.AdamW, with the
    exception that Riemannian preconditioners are added. More specifically, for each LoRA parameter pair
    (lora_up, lora_down), we use a gradient of the form
        grad(lora_up) = lora_up @ inverse(lora_down.T @ lora_down)
    and vice versa.
```
> **Member:** Let's mention that this is modified from the transformers AdamW implementation, not the torch one. It would also be nice to mark the lines in the function body that were changed compared to the original with a comment, if possible (probably everything inside of …
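For reference, the preconditioned gradients that the docstring above describes can be written out as follows. This is my paraphrase of the cited paper, not text from the PR: with LoRA factors $B$ (lora_up), $A$ (lora_down), and regularization strength $\delta$ (the `reg` argument),

$$
\tilde{\nabla}_B \mathcal{L} \;=\; \nabla_B \mathcal{L}\,\bigl(A A^\top + \delta I_r\bigr)^{-1},
\qquad
\tilde{\nabla}_A \mathcal{L} \;=\; \bigl(B^\top B + \delta I_r\bigr)^{-1}\,\nabla_A \mathcal{L}.
$$

This appears to be what the `scaler` computation in `step` below implements, with the shape test selecting whichever Gram matrix ($A A^\top$ or $B^\top B$) is the small $r \times r$ one.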
```python

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 0.001):
            The learning rate to use.
        betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
            Adam's betas parameters (b1, b2).
        eps (`float`, *optional*, defaults to 1e-06):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0.0):
            Decoupled weight decay to apply.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to correct bias in Adam (for instance, in the Bert TF repository they use `False`).
        no_deprecation_warning (`bool`, *optional*, defaults to `False`):
            A flag used to disable the deprecation warning (set to `True` to disable the warning).
    """

    def __init__(
```
> **Member:** Instead of implementing a custom … Also, note that … I'm not sure why that's needed, but in case we decide not to inherit, we should probably copy that function over too. Not sure if …
>
> **Author:** Thanks for the detailed checking! I'm actually replicating the transformers AdamW implementation here: https://github.com/huggingface/transformers/blob/v4.42.0/src/transformers/optimization.py#L558, and slightly modifying it to incorporate our desired gradient change. I'm not sure whether we want to follow the torch.optim implementation or whether it's fine to follow the transformers implementation?
>
> **Member:** Oh, I didn't know this. In that case, it's fine.
```python
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        require_version("torch>=1.5.0")  # add_ with alpha
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Callable = None):
```
```python
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            if group["is_lora"]:
                for p1, p2 in list(zip(group["params"], group["params"][1:]))[::2]:
```
> **Member (on lines +139 to +141):** As discussed in the other comment, this is indeed error prone. For this, the logic here … should be improved. I think it's better if we create two separate groups for … After making this change, the line here could be simplified to:
>
> ```python
> # this works because there is exactly one lora_A and one lora_B group
> lora_A_params = next(group for group in self.param_groups if group["is_lora_A"])
> lora_B_params = next(group for group in self.param_groups if group["is_lora_B"])
> for p1, p2 in zip(lora_A_params, lora_B_params):
> ```
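A possible sketch of the grouping the reviewer describes, splitting the LoRA A and B weights into separate parameter groups inside `create_riemannian_optimizer`. The `is_lora_A`/`is_lora_B` keys follow the reviewer's snippet, while the name-based matching on `"lora_A"`/`"lora_B"` is my assumption about PEFT's parameter naming, not code from the PR:

```python
# Sketch only: assumes trainable LoRA weights are named "...lora_A..." / "...lora_B..."
# and that named_parameters() yields the A and B factors of each layer in matching order.
param_groups = {"lora_A_params": {}, "lora_B_params": {}, "embedding": {}, "other_params": {}}

for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    module = attrgetter(name)(model)
    if isinstance(module, Embedding):
        param_groups["embedding"][name] = param
    elif "lora_A" in name:
        param_groups["lora_A_params"][name] = param
    elif "lora_B" in name:
        param_groups["lora_B_params"][name] = param
    else:
        param_groups["other_params"][name] = param

optimizer_grouped_parameters = [
    {
        "params": list(param_groups["lora_A_params"].values()),
        "weight_decay": weight_decay,
        "lr": lr,
        "is_lora_A": True,
        "is_lora_B": False,
        "reg": reg,
    },
    {
        "params": list(param_groups["lora_B_params"].values()),
        "weight_decay": weight_decay,
        "lr": lr,
        "is_lora_A": False,
        "is_lora_B": True,
        "reg": reg,
    },
    # the embedding and "other" groups stay as in the original diff,
    # with is_lora_A / is_lora_B set to False
]
```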
```python
                    grad = p1.grad
                    if grad.is_sparse:
                        raise RuntimeError(
                            "Adam does not support sparse gradients, please consider SparseAdam instead"
                        )

                    state = self.state[p1]
                    # State initialization
                    if len(state) == 0:
                        state["step"] = 0
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(p1)
                        # Exponential moving average of squared gradient values
                        state["exp_avg_sq"] = torch.zeros_like(p1)

                    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                    beta1, beta2 = group["betas"]
                    state["step"] += 1
                    scaler = p2.data
                    scaler_temp = p1.data
                    try:
                        reg_I = group["reg"] * torch.eye(min(p2.shape)).to(p2.device)
                        scaler = (
                            torch.inverse(scaler @ scaler.T + reg_I)
                            if p2.shape[0] < p2.shape[1]
                            else torch.inverse(scaler.T @ scaler + reg_I)
                        )
                        assert scaler.shape[0] == min(p2.data.shape), "wrong dimension"
```
> **Member:** Let's not use assert, instead raise a proper …
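The reviewer's sentence is cut off above, but the parallel comment further down asks for a `ValueError`. A minimal sketch of that change; the error message wording is mine:

```python
# Instead of: assert scaler.shape[0] == min(p2.data.shape), "wrong dimension"
if scaler.shape[0] != min(p2.shape):
    raise ValueError(
        f"Riemannian preconditioner has shape {tuple(scaler.shape)}, but its first "
        f"dimension was expected to equal min(p2.shape) = {min(p2.shape)}."
    )
```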
```python
                    except RuntimeError:
```
> **Member:** Could you explain why this is needed? Could we instead check the condition and do something like …
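The suggestion above is also truncated; one way to read it, sketched here as an assumption rather than the reviewer's actual proposal, is to check the conditioning of the regularized Gram matrix with `torch.linalg.cond` and fall back explicitly, instead of relying on `torch.inverse` raising a `RuntimeError`:

```python
# Hypothetical guard replacing the try/except around torch.inverse.
# At this point `scaler` is still p2.data and reg_I = group["reg"] * torch.eye(min(p2.shape)).
gram = (scaler @ scaler.T + reg_I) if p2.shape[0] < p2.shape[1] else (scaler.T @ scaler + reg_I)
if torch.linalg.cond(gram) > 1e12:  # the threshold is an arbitrary illustrative choice
    scaler = None  # skip the preconditioner and fall back to the plain AdamW update
else:
    scaler = torch.inverse(gram)
```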
```python
                        print("invalid condition")
                        scaler = None

                    # apply riemannian conditioner
                    if scaler is not None:
                        grad = grad @ scaler if grad.shape[1] == scaler.shape[0] else scaler @ grad
                    # Decay the first and second moment running average coefficient
                    # In-place operations to update the averages at the same time
                    exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                    step_size = group["lr"]
                    if group["correct_bias"]:  # No bias correction for Bert
                        bias_correction1 = 1.0 - beta1 ** state["step"]
                        bias_correction2 = 1.0 - beta2 ** state["step"]
                        step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                    p1.addcdiv_(exp_avg, denom, value=-step_size)
                    if group["weight_decay"] > 0.0:
                        p1.add_(p1, alpha=(-group["lr"] * group["weight_decay"]))

                    grad = p2.grad
                    if grad.is_sparse:
                        raise RuntimeError(
                            "Adam does not support sparse gradients, please consider SparseAdam instead"
                        )

                    state = self.state[p2]
                    # State initialization
                    if len(state) == 0:
                        state["step"] = 0
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(p2)
                        # Exponential moving average of squared gradient values
                        state["exp_avg_sq"] = torch.zeros_like(p2)

                    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                    beta1, beta2 = group["betas"]
                    state["step"] += 1
                    scaler = scaler_temp
                    try:
                        reg_I = group["reg"] * torch.eye(min(p1.shape)).to(p1.device)
                        scaler = (
                            torch.inverse(scaler @ scaler.T + reg_I)
                            if p1.shape[0] < p1.shape[1]
                            else torch.inverse(scaler.T @ scaler + reg_I)
                        )
                        assert scaler.shape[0] == min(p1.data.shape), "wrong dimension"
```
> **Member:** Let's not use assert, instead raise a proper ValueError with a helpful message.
```python
                    except RuntimeError:
                        print("invalid condition")
```
> **Member:** Could you explain why this is needed? Could we instead check the condition and do something like …
```python
                        scaler = None

                    # apply riemannian conditioner
                    if scaler is not None:
                        grad = grad @ scaler if grad.shape[1] == scaler.shape[0] else scaler @ grad
                    # Decay the first and second moment running average coefficient
                    # In-place operations to update the averages at the same time
                    exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                    step_size = group["lr"]
                    if group["correct_bias"]:  # No bias correction for Bert
                        bias_correction1 = 1.0 - beta1 ** state["step"]
                        bias_correction2 = 1.0 - beta2 ** state["step"]
                        step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                    p2.addcdiv_(exp_avg, denom, value=-step_size)

                    if group["weight_decay"] > 0.0:
                        p2.add_(p2, alpha=(-group["lr"] * group["weight_decay"]))

            else:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    grad = p.grad
                    if grad.is_sparse:
                        raise RuntimeError(
                            "Adam does not support sparse gradients, please consider SparseAdam instead"
                        )

                    state = self.state[p]

                    # State initialization
                    if len(state) == 0:
                        state["step"] = 0
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(p)
                        # Exponential moving average of squared gradient values
                        state["exp_avg_sq"] = torch.zeros_like(p)

                    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                    beta1, beta2 = group["betas"]

                    state["step"] += 1

                    # Decay the first and second moment running average coefficient
                    # In-place operations to update the averages at the same time
                    exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                    step_size = group["lr"]
                    if group["correct_bias"]:  # No bias correction for Bert
                        bias_correction1 = 1.0 - beta1 ** state["step"]
                        bias_correction2 = 1.0 - beta2 ** state["step"]
                        step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                    p.addcdiv_(exp_avg, denom, value=-step_size)

                    # Just adding the square of the weights to the loss function is *not*
                    # the correct way of using L2 regularization/weight decay with Adam,
                    # since that will interact with the m and v parameters in strange ways.
                    #
                    # Instead we want to decay the weights in a manner that doesn't interact
                    # with the m/v parameters. This is equivalent to adding the square
                    # of the weights to the loss with plain (non-momentum) SGD.
                    # Add weight decay at the end (fixed version)
                    if group["weight_decay"] > 0.0:
                        p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

        return loss
```
> **Member:** These lines can be removed. At the bottom of the file, add
>
> ```python
> __all__ = ["create_riemannian_optimizer"]
> ```