lora+ support #1352

Merged: 3 commits, Mar 5, 2024
43 changes: 43 additions & 0 deletions src/axolotl/core/trainer_builder.py
@@ -27,8 +27,10 @@
    TrainingArguments,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer

from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
@@ -54,6 +56,9 @@
    get_cosine_schedule_with_warmup_decay_constant,
)

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

try:
    import torch._dynamo  # pylint: disable=ungrouped-imports
except ImportError:
@@ -179,6 +184,13 @@ class AxolotlTrainingArguments(TrainingArguments):
            "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
        },
    )
    loraplus_lr_ratio: Optional[float] = field(
        default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
    )
    loraplus_lr_embedding: Optional[float] = field(
        default=1e-6,
        metadata={"help": "loraplus learning rate for lora embedding layers."},
    )


class AxolotlTrainer(Trainer):
@@ -203,6 +215,33 @@ def __init__(
        super().__init__(*_args, **kwargs)
        self.train_data_collator = self.data_collator

    def create_optimizer(self):
        if self.args.loraplus_lr_ratio is None:
            return super().create_optimizer()

        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if self.optimizer is None:  # pylint: disable=access-member-before-definition
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                self.args,
            )

            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
                opt_model,
                optimizer_cls,
                optimizer_kwargs,
                loraplus_lr_ratio,
                loraplus_lr_embedding,
            )

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
                self.optimizer
            )

        return self.optimizer

    def create_scheduler(
        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
    ):
@@ -915,6 +954,10 @@ def build(self, total_num_steps):
        training_arguments_kwargs["optim"] = (
            self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
        )
        training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
        training_arguments_kwargs[
            "loraplus_lr_embedding"
        ] = self.cfg.loraplus_lr_embedding
        training_arguments_kwargs["lr_scheduler_type"] = (
            self.cfg.lr_scheduler
            if self.cfg.lr_scheduler
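In effect, when loraplus_lr_ratio is set, AxolotlTrainer.create_optimizer skips the stock Trainer path and delegates to create_loraplus_optimizer. Below is a minimal sketch of the equivalent call outside the trainer; the base model, adapter settings, and hyperparameter values are illustrative, not taken from this PR.

# Sketch only: mirrors what AxolotlTrainer.create_optimizer does when
# args.loraplus_lr_ratio is set. Model choice and hyperparameters are toy values.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

from axolotl.loraplus import create_loraplus_optimizer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative base model
peft_model = get_peft_model(
    model, LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"])
)

base_lr = 2e-4  # lr_A, applied to LoRA A matrices (and any other trainable params)
optimizer = create_loraplus_optimizer(
    peft_model,
    torch.optim.AdamW,  # optimizer_cls, as Trainer.get_optimizer_cls_and_kwargs would resolve it
    {"lr": base_lr, "weight_decay": 0.1},
    loraplus_lr_ratio=16,  # lr_B = 16 * lr_A, the 2^4 ratio the config help suggests
    loraplus_lr_embedding=1e-6,  # lr for lora.Embedding adapter weights
)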
133 changes: 133 additions & 0 deletions src/axolotl/loraplus.py
@@ -0,0 +1,133 @@
"""Module for LoRA+"""

# MIT License
#
# Copyright (c) 2024 nikhil-ghosh-berkeley
# https://github.com/nikhil-ghosh-berkeley/loraplus

import logging
from functools import reduce

from peft.tuners import lora
from torch import nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

LOG = logging.getLogger("axolotl.loraplus")


def get_module(name, opt_model):
    """
    Retrieve a module from a model using its parameter name.
    Args:
        name (str): Full name of the parameter, typically including module path.
        opt_model (torch.nn.Module): The model from which to retrieve the module.

    Returns:
        Module corresponding to the given name.
    """
    parent_idx = 2 if "lora" in name else 1
    module_names = name.split(sep=".")[:-parent_idx]
    module = reduce(getattr, module_names, opt_model)
    return module
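
# Illustrative note (not part of the PR): for a LoRA embedding parameter named, say,
# "model.embed_tokens.lora_embedding_A.default", "lora" appears in the name, so the
# last two components are stripped and get_module returns the wrapping lora.Embedding
# at "model.embed_tokens". For a plain parameter such as "model.norm.weight", only
# "weight" is stripped, returning the module that owns the parameter.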


def create_loraplus_optimizer(
    opt_model,
    optimizer_cls,
    optimizer_kwargs,
    loraplus_lr_ratio,
    loraplus_lr_embedding=None,
):
    """
    Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.

    Args:
        opt_model (torch.nn.Module): The model for which the optimizer is being created.
        optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
        optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
        loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
        loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.

    Returns:
        An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
    """

    assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."

    if loraplus_lr_embedding is None:
        loraplus_lr_embedding = 1e-6

    decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    param_groups = {
        "groupA": {},
        "groupB": {},
        "groupB_no_decay": {},
        "embedding": {},
    }

    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue

        module = get_module(name, opt_model)
        if isinstance(module, lora.Embedding):
            param_groups["embedding"][name] = param
        elif "lora_B" in name or param.ndim == 1:
            if name in decay_parameters:
                param_groups["groupB"][name] = param
            else:
                param_groups["groupB_no_decay"][name] = param
        else:
            param_groups["groupA"][name] = param

    assigned_param_groups = ""
    for group, group_params in param_groups.items():
        assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
    LOG.info(assigned_param_groups)

    lr = optimizer_kwargs["lr"]  # pylint: disable=invalid-name
    weight_decay = optimizer_kwargs.get("weight_decay", 0.0)

    optimizer_grouped_parameters = [
        {
            "params": list(param_groups["groupA"].values()),
            "weight_decay": weight_decay,
            "lr": lr,
        },
        {
            "params": list(param_groups["embedding"].values()),
            "weight_decay": weight_decay,
            "lr": loraplus_lr_embedding,
        },
        {
            "params": list(param_groups["groupB"].values()),
            "weight_decay": weight_decay,
            "lr": lr * loraplus_lr_ratio,
        },
        {
            "params": list(param_groups["groupB_no_decay"].values()),
            "weight_decay": 0.0,
            "lr": lr * loraplus_lr_ratio,
        },
    ]

    optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
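    # Editorial note (not in the PR): the Adam8bit handling below mirrors transformers'
    # Trainer.create_optimizer; embedding modules are registered with bitsandbytes'
    # GlobalOptimManager so their optimizer state stays in 32-bit, the recommended
    # setting for large embedding layers when using an 8-bit optimizer.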
    if optimizer_cls.__name__ == "Adam8bit":
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum(
                    {p.data_ptr(): p.numel() for p in module.parameters()}.values()
                )
                LOG.info(f"skipped {module}: {skipped/2**20}M params")
                manager.register_module_override(module, "weight", {"optim_bits": 32})
                LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
        LOG.info(f"skipped: {skipped/2**20}M params")

    return optimizer
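A quick way to sanity-check the grouping (illustrative; assumes the optimizer built in the sketch above, with lr=2e-4, weight_decay=0.1, and loraplus_lr_ratio=16):

# Expected order of param_groups: [groupA, embedding, groupB, groupB_no_decay],
# i.e. lr 2e-4, 1e-6, 3.2e-3, 3.2e-3 with weight_decay 0.1, 0.1, 0.1, 0.0.
for idx, group in enumerate(optimizer.param_groups):
    print(idx, group["lr"], group["weight_decay"], len(group["params"]))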
11 changes: 11 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -183,6 +183,17 @@ class LoraConfig(BaseModel):
    gptq: Optional[bool] = None
    bnb_config_kwargs: Optional[Dict[str, Any]] = None

    loraplus_lr_ratio: Optional[float] = Field(
        default=None,
        metadata={
            "help": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
        },
    )
    loraplus_lr_embedding: Optional[float] = Field(
        default=1e-6,
        metadata={"help": "loraplus learning rate for lora embedding layers."},
    )

    merge_lora: Optional[bool] = None

    @model_validator(mode="before")
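For reference, these options are read from the user's config like any other LoRA setting. A hedged sketch of the relevant keys, expressed as the dict axolotl would parse from YAML (the surrounding values are plausible defaults, not part of this PR):

# Illustrative config fragment for enabling LoRA+ alongside a standard LoRA adapter.
cfg = {
    "adapter": "lora",
    "learning_rate": 2e-4,  # base lr; acts as lr_A for the LoRA A matrices
    "loraplus_lr_ratio": 16,  # lr_B / lr_A; the help text recommends 2^4
    "loraplus_lr_embedding": 1e-6,  # default for lora embedding layers
}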