relora: magnitude pruning of the optimizer #1245

Merged (15 commits) · Feb 6, 2024
9 changes: 9 additions & 0 deletions src/axolotl/core/trainer_builder.py
@@ -126,6 +126,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many anneal steps to take before each reset for ReLoRA"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
@@ -478,10 +482,14 @@ def create_scheduler(
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
warmup_steps,
anneal_steps,
)
else:
@@ -893,6 +901,7 @@ def build(self, total_num_steps):
] = self.cfg.micro_batch_size
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
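The three relora_* values forwarded here come straight from the user config. A minimal sketch of the corresponding settings with illustrative values (only the fallbacks of 10 warmup steps and 1 anneal step are visible in the scheduler hunk above):

```python
# Illustrative only -- these keys mirror the cfg.relora_* fields read above;
# the values are examples, not defaults from this PR.
relora_settings = {
    "relora_steps": 150,        # merge the LoRA adapter into the base weights every 150 steps
    "relora_warmup_steps": 10,  # LR warmup after each reset (falls back to 10 if unset)
    "relora_anneal_steps": 1,   # LR anneal just before each reset (falls back to 1 if unset)
}
```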
120 changes: 97 additions & 23 deletions src/axolotl/monkeypatch/relora.py
@@ -4,14 +4,16 @@
import logging
import os.path
import shutil
from functools import partial
from pathlib import Path
from typing import Dict, List, Sequence
from typing import Dict, List, Sequence, Union

import bitsandbytes as bnb
import peft
import safetensors.torch as st
import torch
from huggingface_hub import snapshot_download
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from transformers import (
@@ -23,23 +25,50 @@
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.distributed import barrier, is_main_process

LOG = logging.getLogger("axolotl.relora")


def reset_optimizer(optimizer: torch.optim.Optimizer):
for group in optimizer.param_groups:
for param in group["params"]:
param_state = optimizer.state[param]
for key in param_state:
if "qmap" in key:
continue
if key == "step" and isinstance(param_state[key], int):
param_state[key] = 0
else:
param_state[key] = torch.zeros_like(param_state[key])

@torch.no_grad()
def magnitude_pruning_(tensor, prune_ratio):
tensor_magnitude = torch.abs(tensor)
threshold = torch.quantile(
tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
).to(dtype=tensor.dtype)

mask = tensor_magnitude > threshold
tensor.mul_(mask.to(dtype=tensor.dtype))


def reset_optimizer(
optimizer: torch.optim.Optimizer,
*,
reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str],
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
n_zeros = 0
n_total = 0

optimizer_state = optimizer.state
if isinstance(optimizer, ZeroRedundancyOptimizer):
optimizer_state = optimizer.optim.state

for param in reset_params:
param_state = optimizer_state[param]
if len(param_state) == 0: # no state for this param, happens for ZeRO optimizer
continue
for key in optimizer_state_keys:
pruning_fn(
param_state[key]
) # pruning fn has to be inplace to keep the same keys in the dict
n_total += param_state[key].numel()
n_zeros += torch.sum(param_state[key] == 0).item()

_zeroed = n_zeros / (1e-7 + n_total) * 100
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
LOG.info(f"absolute n of optimizer states zeroed: {n_zeros}")


class ReLoRACallback(TrainerCallback):
@@ -97,6 +126,25 @@ def on_step_begin(
"relora",
)

if "adam" in args.optim.lower():
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
else:
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")

lora_params = [
n
for n, p in model.named_parameters()
if p.requires_grad and "lora_" in n
]

model.save_pretrained(
os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
"adapter",
),
safe_serialization=True,
)
with torch.no_grad():
merge_and_save(
model,
@@ -107,7 +155,11 @@ def on_step_begin(
actually_save=is_main_process(),
cpu_offload=self.cpu_offload,
)
reset_optimizer(optimizer)
reset_optimizer(
optimizer,
reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys,
)

if self.quantized:
self.last_full_model = checkpoint_folder
@@ -197,11 +249,13 @@ def __init__(
inner_schedule: LRScheduler,
relora_steps: int,
warmup_steps: int,
anneal_steps: int = 1,
min_lr_scale: float = 0.001,
) -> None:
self.inner_schedule = inner_schedule
self.relora_steps = relora_steps
self.warmup_steps = warmup_steps
self.anneal_steps = anneal_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)

@@ -210,10 +264,20 @@ def get_lr(self) -> float:

original = self.inner_schedule.get_lr()
step = self.last_epoch

if step < self.relora_steps:
scale = 1
else:
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
per_relora_progress = step % self.relora_steps
if per_relora_progress < self.warmup_steps:
cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
elif per_relora_progress > (self.relora_steps - self.anneal_steps):
cycle_t = min(
1.0,
(self.relora_steps - per_relora_progress) / self.anneal_steps,
)
else:
cycle_t = 1
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale

if isinstance(original, Sequence):
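Taken together, the branches above give each ReLoRA cycle a linear warmup after the reset, a plateau at the inner schedule's LR, and a short linear anneal before the next reset. A standalone sketch of that scale factor (the default values here are illustrative, not the PR's):

```python
def relora_scale(step, relora_steps=100, warmup_steps=10, anneal_steps=5,
                 min_lr_scale=0.001):
    """Scale applied on top of the inner schedule's LR, per ReLoRA cycle."""
    if step < relora_steps:                # first cycle: leave the inner schedule alone
        return 1.0
    t = step % relora_steps                # progress within the current cycle
    if t < warmup_steps:                   # ramp back up right after a reset
        cycle_t = min(1.0, t / warmup_steps)
    elif t > relora_steps - anneal_steps:  # wind down just before the next reset
        cycle_t = min(1.0, (relora_steps - t) / anneal_steps)
    else:                                  # plateau in between
        cycle_t = 1.0
    return cycle_t * (1 - min_lr_scale) + min_lr_scale

# e.g. relora_scale(200) ~= 0.001 (just reset), relora_scale(205) ~= 0.5 (mid-warmup),
# relora_scale(250) == 1.0 (plateau), relora_scale(298) ~= 0.4 (annealing)
```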
@@ -238,7 +302,11 @@ def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:

def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
adapter = layer.active_adapter
adapter: Union[List[str], str] = layer.active_adapter
if isinstance(adapter, list):
if len(adapter) > 1:
raise ValueError("unhandled relora for multiple adapters")
adapter = adapter[0]
return (
peft.utils.transpose(
layer.lora_B[adapter].weight.detach().to(device)
@@ -248,7 +316,7 @@ def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor
* layer.scaling[adapter]
)

return layer.get_delta_weight().to(device)
raise ValueError("unhandled lora layer type")


def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
@@ -273,9 +341,9 @@ def update_weights(
):
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
target.reset_lora_parameters(adapter_name, True)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
target.reset_lora_parameters(adapter_name, True)

if isinstance(target, peft.tuners.lora.Linear4bit):
# This could be faster, but the quantization of Linear4bit weights occurs
@@ -286,7 +354,9 @@ def update_weights(
target.weight.data = new_weight.cpu()
target.to(device)
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
target.weight.data = (
bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
)
else:
target.weight.data = new_weight.to(device)

@@ -304,14 +374,17 @@ def merge_and_save(

if not quantized:
for module_name, target in modules.items():
update = target.get_delta_weight(target.active_adapter).detach()
active_adapter = target.active_adapter
if isinstance(active_adapter, list):
active_adapter = active_adapter[0]
update = target.get_delta_weight(active_adapter).detach()
target.weight.data += update

if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name)
target.reset_lora_parameters(adapter_name, True)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name)
target.reset_lora_parameters(adapter_name, True)
return

os.makedirs(model_dst, exist_ok=True)
@@ -363,6 +436,7 @@ def merge_and_save(
LOG.info(f"saving tensors to {shard_fn}")
st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})

barrier()
del in_tensors
del out_tensors
torch.cuda.empty_cache()
33 changes: 33 additions & 0 deletions src/axolotl/prompt_strategies/instruct.py
@@ -0,0 +1,33 @@
"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional

from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
strategy = InstructShareGPTPromptTokenizingStrategy(
# pylint: disable=duplicate-code
ShareGPTPrompterV2(
conversation=conversation,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
return strategy


class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
"""

def get_conversation_thread(self, prompt):
return [
{"from": "human", "value": prompt["instruction"]},
{"from": "gpt", "value": prompt["output"]},
]
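For a sense of what this new strategy consumes, here is a hypothetical sample row and the ShareGPT-style thread that get_conversation_thread builds from it (the optional conversation key in the dataset config only selects the prompter's conversation template):

```python
# Hypothetical row -- plain instruction/output pairs, no multi-turn structure needed.
row = {
    "instruction": "Name three primary colors.",
    "output": "Red, yellow, and blue.",
}

# What get_conversation_thread returns for that row before tokenization.
thread = [
    {"from": "human", "value": row["instruction"]},
    {"from": "gpt", "value": row["output"]},
]
```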
1 change: 1 addition & 0 deletions src/axolotl/utils/chat_templates.py
@@ -19,6 +19,7 @@ def chat_templates(user_choice: str):
"""

templates = {
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
}
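A hedged usage sketch for the new "alpaca" entry, assuming chat_templates(...) returns the raw Jinja string for the chosen key and using a placeholder model name (neither is spelled out in this diff):

```python
from transformers import AutoTokenizer

from axolotl.utils.chat_templates import chat_templates

# Placeholder model; any tokenizer with an eos_token works the same way.
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
messages = [
    {"role": "user", "content": "Name three primary colors."},
    {"role": "assistant", "content": "Red, yellow, and blue."},
]
text = tokenizer.apply_chat_template(
    messages, chat_template=chat_templates("alpaca"), tokenize=False
)
# Expected (eos_token is "</s>" for LLaMA-style tokenizers):
# "### Instruction: Name three primary colors.\n\n### Response: Red, yellow, and blue.</s>"
```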
6 changes: 5 additions & 1 deletion src/axolotl/utils/config.py
@@ -447,7 +447,11 @@ def validate_config(cfg):
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
)

if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
if (
cfg.val_set_size == 0
and (cfg.eval_steps or cfg.evaluation_strategy)
and not cfg.test_datasets
):
raise ValueError(
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
)
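A rough illustration of what the relaxed check now allows; the keys inside each test_datasets entry are hypothetical, not taken from this diff:

```python
# Accepted: no val split carved out of the training data, but an explicit
# held-out dataset to evaluate on.
cfg_ok = {"val_set_size": 0, "eval_steps": 100,
          "test_datasets": [{"path": "data/eval.jsonl", "split": "test"}]}

# Still rejected: eval_steps set with nothing to evaluate on.
cfg_bad = {"val_set_size": 0, "eval_steps": 100}
```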
2 changes: 1 addition & 1 deletion src/axolotl/utils/data.py
@@ -140,7 +140,7 @@ def load_tokenized_prepared_datasets(
+ "|".join(
sorted(
[
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
for d in cfg_datasets
]
)
15 changes: 12 additions & 3 deletions src/axolotl/utils/models.py
@@ -8,7 +8,13 @@
import bitsandbytes as bnb
import torch
import transformers
from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training
from peft import (
LoftQConfig,
PeftConfig,
PeftModel,
PeftModelForCausalLM,
prepare_model_for_kbit_training,
)
from peft.tuners.lora import QuantLinear
from transformers import ( # noqa: F401
AddedToken,
@@ -628,6 +634,9 @@ def load_model(
LOG.exception(err)
raise err

if isinstance(model, (PeftModel, PeftModelForCausalLM)):
model = model.merge_and_unload()

embeddings_len = (
math.ceil(len(tokenizer) / 32) * 32
if cfg.resize_token_embeddings_to_32x
@@ -782,7 +791,7 @@ def load_adapter(model, cfg, adapter, inference=False):

def load_llama_adapter(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import AdaptionPromptConfig, PeftModel, get_peft_model
from peft import AdaptionPromptConfig, get_peft_model

peft_config = AdaptionPromptConfig(
adapter_layers=cfg.peft_adapter.layers, # layers (L)
@@ -828,7 +837,7 @@ def find_all_linear_names(model):
def load_lora(model, cfg, inference=False, config_only=False):
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]

from peft import LoraConfig, PeftModel, get_peft_model
from peft import LoraConfig, get_peft_model

lora_target_modules = list(cfg.lora_target_modules or [])
