diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 614a94b6d0..d5e36f378e 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -102,6 +102,8 @@ title: Prefix tuning - local: package_reference/prompt_tuning title: Prompt tuning + - local: package_reference/vera + title: VeRA title: Adapters - sections: - local: package_reference/merge_utils diff --git a/docs/source/package_reference/vera.md b/docs/source/package_reference/vera.md new file mode 100644 index 0000000000..9677df2742 --- /dev/null +++ b/docs/source/package_reference/vera.md @@ -0,0 +1,41 @@ + + +# VeRA: Vector-based Random Matrix Adaptation + +[VeRA](https://huggingface.co/papers/2310.11454) is a parameter-efficient fine-tuning technique that is similar to LoRA but requires even fewer extra parameters while promising similar or even better performance. As such, it is particularly useful when the parameter budget is very limited, e.g. when scaling to very large models. The reduction of the count of trainable parameters is achieved by sharing the same low-rank matrices across all layers, and only training two additional vectors per layer. + +When saving the adapter parameters, it's possible to eschew storing the low rank matrices by setting `save_projection=False` on the `VeraConfig`. In that case, these matrices will be restored based on the fixed random seed from the `projection_prng_key` argument. This cuts down on the size of the checkpoint, but we cannot guarantee reproducibility on all devices and for all future versions of PyTorch. If you want to ensure reproducibility, set `save_projection=True` (which is the default). + +VeRA currently has the following constraints: + +- All targeted parameters must have the same shape. +- Only `nn.Linear` layers are supported. +- Quantized layers are not supported. + +If these constraints don't work for your use case, use LoRA instead. + +The abstract from the paper is: + +> Low-rank adapation (LoRA) is a popular method that reduces the number of trainable parameters when finetuning large language models, but still faces acute storage challenges when scaling to even larger models or deploying numerous per-user or per-task adapted models. In this work, we present Vector-based Random Matrix Adaptation (VeRA), which significantly reduces the number of trainable parameters compared to LoRA, yet maintains the same performance. It achieves this by using a single pair of low-rank matrices shared across all layers and learning small scaling vectors instead. We demonstrate its effectiveness on the GLUE and E2E benchmarks, image classification tasks, and show its application in instruction-tuning of 7B and 13B language models. + +## VeRAConfig + +[[autodoc]] tuners.vera.config.VeraConfig + +## VeRAModel + +[[autodoc]] tuners.vera.model.VeraModel diff --git a/examples/sequence_classification/VeRA.ipynb b/examples/sequence_classification/VeRA.ipynb new file mode 100644 index 0000000000..b917618db3 --- /dev/null +++ b/examples/sequence_classification/VeRA.ipynb @@ -0,0 +1,543 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d36e1e93-ae93-4a4e-93c6-68fd868d2882", + "metadata": {}, + "source": [ + "# Using VeRA for sequence classification" + ] + }, + { + "cell_type": "markdown", + "id": "ddfc0610-55f6-4343-a950-125ccf0f45ac", + "metadata": {}, + "source": [ + "In this example, we fine-tune Roberta on a sequence classification task using VeRA." 
+ ] + }, + { + "cell_type": "markdown", + "id": "45addd81-d4f3-4dfd-960d-3920d347f0a6", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a9935ae2", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader\n", + "from peft import (\n", + " get_peft_model,\n", + " VeraConfig,\n", + " PeftType,\n", + ")\n", + "\n", + "import evaluate\n", + "from datasets import load_dataset\n", + "from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed, AutoConfig\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "markdown", + "id": "62c959bf-7cc2-49e0-b97e-4c10ec3b9bf3", + "metadata": {}, + "source": [ + "## Parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e3b13308", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batch_size = 128\n", + "model_name_or_path = \"roberta-base\"\n", + "task = \"mrpc\"\n", + "peft_type = PeftType.VERA\n", + "device = \"cuda\"\n", + "num_epochs = 5 # for best results, increase this number\n", + "rank = 8 # for best results, increase this number\n", + "max_length = 128\n", + "torch.manual_seed(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0526f571", + "metadata": {}, + "outputs": [], + "source": [ + "peft_config = VeraConfig(\n", + " task_type=\"SEQ_CLS\", \n", + " r=rank,\n", + " d_initial=0.1,\n", + " target_modules=[\"query\", \"value\"],\n", + " save_projection=True,\n", + ")\n", + "head_lr = 1e-2\n", + "vera_lr = 2e-2" + ] + }, + { + "cell_type": "markdown", + "id": "c075c5d2-a457-4f37-a7f1-94fd0d277972", + "metadata": {}, + "source": [ + "## Loading data" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7bb52cb4-d1c3-4b04-8bf0-f39ca88af139", + "metadata": {}, + "outputs": [], + "source": [ + "if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n", + " padding_side = \"left\"\n", + "else:\n", + " padding_side = \"right\"\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n", + "if getattr(tokenizer, \"pad_token_id\") is None:\n", + " tokenizer.pad_token_id = tokenizer.eos_token_id" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e69c5e1f-d27b-4264-a41e-fc9b99d025e6", + "metadata": {}, + "outputs": [], + "source": [ + "datasets = load_dataset(\"glue\", task)\n", + "metric = evaluate.load(\"glue\", task)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0209f778-c93b-40eb-a4e0-24c25db03980", + "metadata": {}, + "outputs": [], + "source": [ + "def tokenize_function(examples):\n", + " # max_length=None => use the model max length (it's actually the default)\n", + " outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=max_length)\n", + " return outputs\n", + "\n", + "\n", + "tokenized_datasets = datasets.map(\n", + " tokenize_function,\n", + " batched=True,\n", + " remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n", + ")\n", + "\n", + "# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n", + "# transformers library\n", + "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")" + ] + }, + { + "cell_type": 
"code", + "execution_count": 7, + "id": "7453954e-982c-46f0-b09c-589776e6d6cb", + "metadata": {}, + "outputs": [], + "source": [ + "def collate_fn(examples):\n", + " return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n", + "\n", + "\n", + "# Instantiate dataloaders.\n", + "train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n", + "eval_dataloader = DataLoader(\n", + " tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f3b9b2e8-f415-4d0f-9fb4-436f1a3585ea", + "metadata": {}, + "source": [ + "## Preparing the VeRA model" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2ed5ac74", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "trainable params: 610,754 || all params: 125,257,924 || trainable%: 0.48759709605278145\n" + ] + } + ], + "source": [ + "model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True, max_length=None)\n", + "model = get_peft_model(model, peft_config)\n", + "model.print_trainable_parameters()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "0d2d0381", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = AdamW(\n", + " [\n", + " {\"params\": [p for n, p in model.named_parameters() if \"vera_lambda_\" in n], \"lr\": vera_lr},\n", + " {\"params\": [p for n, p in model.named_parameters() if \"classifier\" in n], \"lr\": head_lr},\n", + " ]\n", + ")\n", + "\n", + "# Instantiate scheduler\n", + "lr_scheduler = get_linear_schedule_with_warmup(\n", + " optimizer=optimizer,\n", + " num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs),\n", + " num_training_steps=(len(train_dataloader) * num_epochs),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c0dd5aa8-977b-4ac0-8b96-884b17bcdd00", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "fa0e73be", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/29 [00:00 list[str]: def forward(self, *args: Any, **kwargs: Any): return self.model.forward(*args, **kwargs) + def _pre_injection_hook(self, model: nn.Module, config: PeftConfig, adapter_name: str) -> None: + r""" + A hook to be called before the adapter is injected into the model. This method can be overridden by child + classes to perform any pre-injection operations. + + Args: + model (`nn.Module`): + The model to be adapted. + config (`PeftConfig`): + The adapter config. + adapter_name (`str`): + The adapter name. + """ + pass + @abstractmethod def _prepare_adapter_config(self, peft_config: PeftConfig, model_config: dict) -> PeftConfig: r""" @@ -415,9 +431,9 @@ class BaseTunerLayer(ABC): active_adapter = None # All names of layers that may contain adapter (trainable) weights - adapter_layer_names: tuple[str] = () + adapter_layer_names: tuple[str, ...] 
= () # All names of other parameters that may contain adapter-related parameters - other_param_names: tuple[str] = () + other_param_names: tuple[str, ...] = () # indicates whether all adapters should be disabled _disable_adapters: bool = False diff --git a/src/peft/tuners/vera/__init__.py b/src/peft/tuners/vera/__init__.py new file mode 100644 index 0000000000..cf35881834 --- /dev/null +++ b/src/peft/tuners/vera/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import VeraConfig +from .layer import Linear, VeraLayer +from .model import VeraModel + + +__all__ = ["VeraConfig", "VeraLayer", "Linear", "VeraModel"] diff --git a/src/peft/tuners/vera/buffer_dict.py b/src/peft/tuners/vera/buffer_dict.py new file mode 100644 index 0000000000..80b2307d1d --- /dev/null +++ b/src/peft/tuners/vera/buffer_dict.py @@ -0,0 +1,160 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# Adapted from https://botorch.org/api/_modules/botorch/utils/torch.html + +# TODO: To be removed once (if) https://github.com/pytorch/pytorch/pull/37385 lands + +from __future__ import annotations + +import collections +from collections import OrderedDict + +import torch +from torch.nn import Module + + +class BufferDict(Module): + r""" + Holds buffers in a dictionary. + + BufferDict can be indexed like a regular Python dictionary, but buffers it contains are properly registered, and + will be visible by all Module methods. `torch.nn.BufferDict` is an **ordered** dictionary that respects + + * the order of insertion, and + * in `torch.nn.BufferDict.update`, the order of the merged `OrderedDict` or another `torch.nn.BufferDict` (the + argument to `torch.nn.BufferDict.update`). + + Note that `torch.nn.BufferDict.update` with other unordered mapping types (e.g., Python's plain `dict`) does not + preserve the order of the merged mapping. + + Args: + buffers (iterable, optional): + a mapping (dictionary) of (string : `torch.Tensor`) or an iterable of key-value pairs of type (string, + `torch.Tensor`) + + ```python + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.buffers = nn.BufferDict({"left": torch.randn(5, 10), "right": torch.randn(5, 10)}) + + def forward(self, x, choice): + x = self.buffers[choice].mm(x) + return x + ``` + """ + + def __init__(self, buffers=None, persistent: bool = False): + r""" + Args: + buffers (`dict`): + A mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type + (string, `torch.Tensor`). 
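+            persistent (`bool`, *optional*):
+                Whether each buffer is registered as a persistent buffer, i.e. included in the module's
+                `state_dict` (the flag is passed through to `register_buffer`). Defaults to `False`.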
+ """ + super().__init__() + if buffers is not None: + self.update(buffers) + + self.persistent = persistent + + def __getitem__(self, key): + return self._buffers[key] + + def __setitem__(self, key, buffer): + self.register_buffer(key, buffer, persistent=self.persistent) + + def __delitem__(self, key): + del self._buffers[key] + + def __len__(self): + return len(self._buffers) + + def __iter__(self): + return iter(self._buffers.keys()) + + def __contains__(self, key): + return key in self._buffers + + def clear(self): + """Remove all items from the BufferDict.""" + self._buffers.clear() + + def pop(self, key): + r"""Remove key from the BufferDict and return its buffer. + + Args: + key (`str`): + Key to pop from the BufferDict + """ + v = self[key] + del self[key] + return v + + def keys(self): + r"""Return an iterable of the BufferDict keys.""" + return self._buffers.keys() + + def items(self): + r"""Return an iterable of the BufferDict key/value pairs.""" + return self._buffers.items() + + def values(self): + r"""Return an iterable of the BufferDict values.""" + return self._buffers.values() + + def update(self, buffers): + r""" + Update the `torch.nn.BufferDict` with the key-value pairs from a mapping or an iterable, overwriting existing + keys. + + Note: + If `buffers` is an `OrderedDict`, a `torch.nn.BufferDict`, or an iterable of key-value pairs, the order of + new elements in it is preserved. + + Args: + buffers (iterable): + a mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type + (string, `torch.Tensor`). + """ + if not isinstance(buffers, collections.abc.Iterable): + raise TypeError( + "BuffersDict.update should be called with an " + "iterable of key/value pairs, but got " + type(buffers).__name__ + ) + + if isinstance(buffers, collections.abc.Mapping): + if isinstance(buffers, (OrderedDict, BufferDict)): + for key, buffer in buffers.items(): + self[key] = buffer + else: + for key, buffer in sorted(buffers.items()): + self[key] = buffer + else: + for j, p in enumerate(buffers): + if not isinstance(p, collections.abc.Iterable): + raise TypeError( + "BufferDict update sequence element " + "#" + str(j) + " should be Iterable; is" + type(p).__name__ + ) + if not len(p) == 2: + raise ValueError( + "BufferDict update sequence element " + "#" + str(j) + " has length " + str(len(p)) + "; 2 is required" + ) + self[p[0]] = p[1] + + def extra_repr(self): + child_lines = [] + for k, p in self._buffers.items(): + size_str = "x".join(str(size) for size in p.size()) + device_str = "" if not p.is_cuda else f" (GPU {p.get_device()})" + parastr = f"Buffer containing: [{torch.typename(p)} of size {size_str}{device_str}]" + child_lines.append(" (" + k + "): " + parastr) + tmpstr = "\n".join(child_lines) + return tmpstr + + def __call__(self, input): + raise RuntimeError("BufferDict should not be called.") diff --git a/src/peft/tuners/vera/config.py b/src/peft/tuners/vera/config.py new file mode 100644 index 0000000000..1601c8c0e6 --- /dev/null +++ b/src/peft/tuners/vera/config.py @@ -0,0 +1,157 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from peft.config import PeftConfig +from peft.utils import PeftType + + +@dataclass +class VeraConfig(PeftConfig): + """ + This is the configuration class to store the configuration of a [`VeraModel`]. + + Paper: https://arxiv.org/abs/2310.11454. + + Args: + r (`int`, *optional*, defaults to `256`): + VeRA parameter dimension ("rank"). Choose higher values than LoRA ranks here, since VeRA uses far fewer + parameters than LoRA (see Table 1). + target_modules (`Union[List[str], str]`): + The names of the modules to apply Vera to. Only linear layers are supported. + projection_prng_key (`int`): + Vera PRNG init key. Used for initialising vera_A and vera_B for new models or when loading a checkpoint + that did not include these projections. Defaults to `0`. + save_projection (`bool`): + Whether to save the vera_A / vera_B projections in the state dict alongside per layer lambda_b / lambda_d + weights. This will increase the size of the checkpoint, but guarantee that we can reload the checkpoint on + all system configurations. Defaults to `True`. + vera_dropout (`float`): + The dropout probability for Vera layers. + d_initial (`float`, *optional*, defaults to `0.1`): + Initial init value for `vera_lambda_d` vector used when initializing the VeRA parameters. Small values + (<=0.1) are recommended (see Table 6c in the paper). + fan_in_fan_out (`bool`): + Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses + `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`. + bias (`str`): + Bias type for Vera. Can be 'none', 'all' or 'vera_only'. If 'all' or 'vera_only', the corresponding biases + will be updated during training. Be aware that this means that, even when disabling the adapters, the model + will not produce the same output as the base model would have without adaptation. + modules_to_save (`List[str]`): + List of modules apart from Vera layers to be set as trainable and saved in the final checkpoint. + init_weights (`bool`): + Whether to initialize the weights of the Vera layers with their default initialization. Don't change this + setting, except if you know exactly what you're doing. + layers_to_transform (`Union[List[int],int]`): + The layer indexes to transform, if this argument is specified, it will apply the Vera transformations on + the layer indexes that are specified in this list. If a single integer is passed, it will apply the Vera + transformations on the layer at this index. + layers_pattern (`str`): + The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer + pattern is not in the common layers pattern. + """ + + r: int = field(default=256, metadata={"help": "Vera attention dimension"}) + + target_modules: Optional[Union[List[str], str]] = field( + default=None, + metadata={ + "help": ( + "List of module names or regex expression of the module names to replace with Vera." 
+ "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. " + "Only linear layers are supported." + ) + }, + ) + projection_prng_key: int = field( + default=0, + metadata={ + "help": ( + "Vera PRNG init key. Used for initialising vera_A and vera_B for new models or when loading a " + "checkpoint that did not include these projections." + ) + }, + ) + save_projection: bool = field( + default=True, + metadata={ + "help": ( + "Whether to save the vera_A / vera_B projections in the state dict alongside per layer lambda_b / " + "lambda_d weights. This will increase the size of the checkpoint, but guarantee that we can reload " + "the checkpoint on all system configurations." + ) + }, + ) + vera_dropout: float = field(default=0.0, metadata={"help": "Vera dropout"}) + d_initial: float = field(default=0.1, metadata={"help": "Initial init value for d vector."}) + fan_in_fan_out: bool = field( + default=False, + metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, + ) + bias: str = field(default="none", metadata={"help": "Bias type for Vera. Can be 'none', 'all' or 'vera_only'"}) + modules_to_save: Optional[List[str]] = field( + default=None, + metadata={ + "help": ( + "List of modules apart from Vera layers to be set as trainable and saved in the final checkpoint. For" + " example, in Sequence Classification or Token Classification tasks, the final layer" + " `classifier/score` are randomly initialized and as such need to be trainable and saved." + ) + }, + ) + init_weights: bool = field( + default=True, + metadata={ + "help": ( + "Whether to initialize the weights of the Vera layers with their default initialization. Don't change " + "this setting, except if you know exactly what you're doing." + ), + }, + ) + layers_to_transform: Optional[Union[List[int], int]] = field( + default=None, + metadata={ + "help": ( + "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers" + " indexes that are specified inside this list. If a single integer is passed, PEFT will transform only" + " the layer at this index." + ) + }, + ) + layers_pattern: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer" + " pattern is not in the common layers pattern." + ) + }, + ) + + def __post_init__(self): + self.peft_type = PeftType.VERA + self.target_modules = ( + set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules + ) + + if not self.save_projection: + warnings.warn( + "Specified to not save vera_A and vera_B within the state dictionary, instead they will be restored " + "using the PRNG key store in `config.projection_prng_key`. Consider setting `config.save_projection` " + "to `True` to guarantee restoring the checkpoint correctly on all system configurations." + ) diff --git a/src/peft/tuners/vera/layer.py b/src/peft/tuners/vera/layer.py new file mode 100644 index 0000000000..934e2fdb5e --- /dev/null +++ b/src/peft/tuners/vera/layer.py @@ -0,0 +1,267 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.pytorch_utils import Conv1D + +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.other import transpose + +from .buffer_dict import BufferDict + + +class VeraLayer(BaseTunerLayer): + # List all names of layers that may contain adapter weights + adapter_layer_names = ("vera_lambda_b", "vera_lambda_d") + other_param_names = ("vera_A", "vera_B") + + def __init__(self, base_layer: nn.Module, **kwargs): + self.base_layer = base_layer + self.r = {} + self.vera_dropout = nn.ModuleDict({}) + + # For storing vector scale + self.vera_lambda_b = nn.ParameterDict({}) + self.vera_lambda_d = nn.ParameterDict({}) + + # Stores a reference to the vera_A/B BufferDict. + # Set to `None` otherwise to avoid computation with random weights + self.vera_A: Optional[BufferDict] = None + self.vera_B: Optional[BufferDict] = None + + # Mark the weight as unmerged + self._disable_adapters = False + self.merged_adapters = [] + + base_layer = self.get_base_layer() + if isinstance(base_layer, nn.Linear): + in_features, out_features = base_layer.in_features, base_layer.out_features + elif isinstance(base_layer, Conv1D): + in_features, out_features = ( + base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape + ) + + self.in_features = in_features + self.out_features = out_features + self.kwargs = kwargs + + @property + def merged(self) -> bool: + return bool(self.merged_adapters) + + def update_layer( + self, + adapter_name, + vera_A: BufferDict, + vera_B: BufferDict, + r, + vera_dropout, + init_weights, + d_initial: float = 0.1, + ): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + self.r[adapter_name] = r + if vera_dropout > 0.0: + vera_dropout_layer = nn.Dropout(p=vera_dropout) + else: + vera_dropout_layer = nn.Identity() + + self.vera_dropout.update(nn.ModuleDict({adapter_name: vera_dropout_layer})) + # Actual trainable parameters + self.vera_lambda_b[adapter_name] = nn.Parameter(torch.ones(self.out_features), requires_grad=True) + self.vera_lambda_d[adapter_name] = nn.Parameter(torch.randn(r), requires_grad=True) + + # non trainable references to vera_A/B buffers + self.vera_A = vera_A + self.vera_B = vera_B + if adapter_name not in vera_A: + # This means that this is not the first VeRA adapter. We have to add an entry in the dict for this adapter. + if len(self.vera_A) < 1: + raise ValueError( + "The `vera_A` and `vera_B` buffers are empty. This should not happen. Please report this issue." 
+ ) + # we can take any of the existing adapter's parameters, as they should all be identical + vera_A_param = list(self.vera_A.values())[0] + vera_B_param = list(self.vera_B.values())[0] + self.vera_A[adapter_name] = vera_A_param + self.vera_B[adapter_name] = vera_B_param + + if init_weights: + self.reset_vera_parameters(adapter_name, d_initial=d_initial) + + weight = getattr(self.get_base_layer(), "weight", None) + if weight is not None: + # the layer is already completely initialized, this is an update + if weight.dtype.is_floating_point or weight.dtype.is_complex: + self.to(weight.device, dtype=weight.dtype) + else: + self.to(weight.device) + + self.set_adapter(self.active_adapters) + + def reset_vera_parameters(self, adapter_name, d_initial: float = 0.1): + if adapter_name in self.vera_lambda_d.keys(): + with torch.no_grad(): + nn.init.zeros_(self.vera_lambda_d[adapter_name]).fill_(d_initial) + nn.init.zeros_(self.vera_lambda_b[adapter_name]) + + +class Linear(nn.Linear, VeraLayer): + # Vera implemented in a dense layer + def __init__( + self, + base_layer, + vera_A: BufferDict, + vera_B: BufferDict, + adapter_name: str, + r: int = 0, + vera_dropout: float = 0.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + is_target_conv_1d_layer: bool = False, + init_weights: bool = True, + d_initial: float = 0.1, + **kwargs, + ) -> None: + # this gets the init from nn.Linear's super perspective, i.e. nn.Module.__init__, which should always be called + super(nn.Linear, self).__init__() + VeraLayer.__init__(self, base_layer, **kwargs) + self.fan_in_fan_out = fan_in_fan_out + + self._active_adapter = adapter_name + self.update_layer(adapter_name, vera_A, vera_B, r, vera_dropout, init_weights, d_initial=d_initial) + self.is_target_conv_1d_layer = is_target_conv_1d_layer + + def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + for active_adapter in adapter_names: + if active_adapter in self.vera_lambda_d.keys(): + base_layer = self.get_base_layer() + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weights = base_layer.weight.data.clone() + + orig_weights += self.get_delta_weight(active_adapter) + + if not torch.isfinite(orig_weights).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + + base_layer.weight.data = orig_weights + else: + base_layer.weight.data += self.get_delta_weight(active_adapter) + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + if not self.merged: + warnings.warn("Already unmerged. 
Nothing to do.") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.vera_lambda_d.keys(): + self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) + + def get_delta_weight(self, adapter) -> torch.Tensor: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + vera_A = self.vera_A[adapter] + vera_B = self.vera_B[adapter] + + device = vera_B.device + dtype = vera_B.dtype + + # In case users wants to merge the adapter weights that are in + # float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + # float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16. + cast_to_fp32 = device.type == "cpu" and dtype == torch.float16 + + lambda_d = self.vera_lambda_d[adapter] + lambda_b = self.vera_lambda_b[adapter] + + if cast_to_fp32: + vera_A = vera_A.float() + vera_B = vera_B.float() + lambda_d = lambda_d.float() + lambda_b = lambda_b.float() + + lambda_b = lambda_b.unsqueeze(-1) + lambda_d = lambda_d.unsqueeze(-1) + output_tensor = transpose((lambda_b * vera_B) @ (lambda_d * vera_A), self.fan_in_fan_out) + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + # TODO: why? + self.vera_lambda_d[adapter].data = lambda_d.to(dtype) + self.vera_lambda_b[adapter].data = lambda_b.to(dtype) + + return output_tensor + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + previous_dtype = x.dtype + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.vera_lambda_d.keys(): + continue + + lambda_d = self.vera_lambda_d[active_adapter] + lambda_b = self.vera_lambda_b[active_adapter] + + vera_A = self.vera_A[active_adapter] + vera_B = self.vera_B[active_adapter] + + dropout = self.vera_dropout[active_adapter] + x = x.to(lambda_d.dtype) + result = result + lambda_b * F.linear(lambda_d * F.linear(dropout(x), vera_A), vera_B) + + result = result.to(previous_dtype) + return result diff --git a/src/peft/tuners/vera/model.py b/src/peft/tuners/vera/model.py new file mode 100644 index 0000000000..2ecd1c9ab8 --- /dev/null +++ b/src/peft/tuners/vera/model.py @@ -0,0 +1,474 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
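+
+# Note: VeRA keeps a single pair of frozen random projections (vera_A / vera_B) that is shared across all adapted
+# layers; only the per-layer scaling vectors (vera_lambda_d / vera_lambda_b) are trained.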
+ +from __future__ import annotations + +import math +import warnings +from dataclasses import asdict +from enum import Enum +from typing import Optional, Union + +import torch +import torch.nn as nn +from torch.nn.init import _calculate_correct_fan +from tqdm import tqdm +from transformers.pytorch_utils import Conv1D + +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists +from peft.utils import ( + TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, + _get_submodules, +) + +from ..tuners_utils import _maybe_include_all_linear_layers +from .buffer_dict import BufferDict +from .config import VeraConfig +from .layer import Linear, VeraLayer + + +def _kaiming_init( + tensor_or_shape: Union[torch.Tensor, tuple[int, ...]], + generator: torch.Generator, +) -> torch.Tensor: + """ + Kaiming Uniform Initialisation adapted to accept a `torch.Generator` object for PRNG. + + Args: + tensor_or_shape (`Union[torch.Tensor, tuple[int, ...]]`): + Tensor to initialise, or shape of new tensor to create and then initialise. + generator: (`torch.Generator`): + Generator object that manages the state of the PRNG algorithm in use. + + Returns: + `torch.Tensor`: The initialised tensor. + """ + if isinstance(tensor_or_shape, tuple): + tensor = torch.empty(tensor_or_shape) + else: + tensor = tensor_or_shape + fan = _calculate_correct_fan(tensor, "fan_in") + gain = math.sqrt(2) + std = gain / math.sqrt(fan) + bound = math.sqrt(3.0) * std + + with torch.no_grad(): + return tensor.uniform_(-bound, bound, generator=generator) + + +class VeraModel(BaseTuner): + """ + Creates Vector-based Random Matrix Adaptation (Vera) model from a pretrained transformers model. + + Args: + model ([`~transformers.PreTrainedModel`]): The model to be adapted. + config ([`VeraConfig`]): The configuration of the Vera model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + + Returns: + `torch.nn.Module`: The Vera model. + + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import VeraConfig, get_peft_model + + >>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") + >>> config = VeraConfig(r=128) + >>> model = get_peft_model(base_model, config) + ``` + + **Attributes**: + - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`VeraConfig`]): The configuration of the Vera model. + """ + + prefix: str = "vera_lambda" + + def __init__(self, model, config, adapter_name) -> None: + super().__init__(model, config, adapter_name) + + def _find_first_dim(self, config) -> tuple[int, int]: + """ + Finds the first linear layer that has been wrapped with Vera, and extract the input and output dimension. + + This will be used for determining the size of the shared vera_A and vera_B matrices. + + This will throw an error if there are multiple layers of the same type with different shapes. 
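+        Returns a tuple of (out_features, in_features) taken from the first matching layer; this shape determines
+        the dimensions of the shared vera_A and vera_B projections.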
+ """ + model_config = getattr(self.model, "config", {"model_type": "custom"}) + if hasattr(model_config, "to_dict"): + model_config = model_config.to_dict() + + peft_config = self._prepare_adapter_config(config, model_config) + peft_config = _maybe_include_all_linear_layers(peft_config, self.model) + + first_shape = None + for key, module in self.model.named_modules(): + if not self._check_target_module_exists(peft_config, key): + continue + + if isinstance(module, (nn.Linear, Conv1D)): + module_shape = tuple(module.weight.shape) + if isinstance(module, Conv1D): + module_shape = module_shape[::-1] + else: + continue + + if first_shape is None: + first_shape = module_shape + continue + + if module_shape != first_shape: + raise ValueError( + "Multiple target layers with different dimensions were specified. VeRA only supports a " + f"single dimension size. Expected shape {first_shape}, got {module_shape}." + ) + + if first_shape is None: + msg = "No layers types compatible with VeRA were found. Please check `peft_config.target_modules`." + raise ValueError(msg) + + return first_shape + + def _init_vera_A_vera_B(self, config: VeraConfig, adapter_name: str) -> None: + first_linear_out_dim, first_linear_in_dim = self._find_first_dim(config) + + # use of persistent to exclude vera_A and vera_B from the state dict if we choose not to save them. + self.vera_A = BufferDict({}, persistent=config.save_projection) + self.vera_B = BufferDict({}, persistent=config.save_projection) + + # deterministic init of vera_A and vera_B if we know the key + generator = torch.Generator(device="cpu").manual_seed(config.projection_prng_key) + vera_A = _kaiming_init((config.r, first_linear_in_dim), generator=generator) + vera_B = _kaiming_init((first_linear_out_dim, config.r), generator=generator) + self.vera_A[adapter_name] = vera_A + self.vera_B[adapter_name] = vera_B + + def _pre_injection_hook(self, model: nn.Module, config: VeraConfig, adapter_name: str) -> None: + self._init_vera_A_vera_B(config, adapter_name) + + def _check_new_adapter_config(self, config: VeraConfig) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. + + """ + # the below todo is copied from LoRA + # TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check + # does not fully correspond to the error message. + if (len(self.peft_config) > 1) and (config.bias != "none"): + raise ValueError( + f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, " + "set bias to 'none' for all adapters." + ) + + for existing_config in self.peft_config.values(): + if existing_config is config: + # skip the current config + continue + + if existing_config.projection_prng_key != config.projection_prng_key: + raise ValueError( + f"Vera PRNG initialisation key must be the same for all adapters. Got {config.projection_prng_key=} but " + f"previous config had {existing_config.projection_prng_key}." 
+ ) + + save_project_unique_values = sorted({config.save_projection for config in self.peft_config.values()}) + if len(save_project_unique_values) > 1: + raise ValueError( + "VeRA projection weights must be saved for all adapters or none, but got multiple different values: " + f"{save_project_unique_values}" + ) + + @staticmethod + def _check_target_module_exists(vera_config, key): + return check_target_module_exists(vera_config, key) + + def _create_and_replace( + self, + vera_config, + adapter_name, + target, + target_name, + parent, + current_key, + **optional_kwargs, + ): + if current_key is None: + raise ValueError("Current Key shouldn't be `None`") + + r = vera_config.r + bias = hasattr(target, "bias") and target.bias is not None + kwargs = { + "r": r, + "vera_dropout": vera_config.vera_dropout, + "fan_in_fan_out": vera_config.fan_in_fan_out, + "init_weights": vera_config.init_weights, + } + kwargs["bias"] = bias + # TODO: add quantization support + + if isinstance(target, Linear): + target.update_layer( + adapter_name, + self.vera_A, + self.vera_B, + r, + vera_config.vera_dropout, + vera_config.init_weights, + d_initial=vera_config.d_initial, + ) + else: + new_module = self._create_new_module(vera_config, self.vera_A, self.vera_B, adapter_name, target, **kwargs) + if adapter_name not in self.active_adapter: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + + @staticmethod + def _replace_module(parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + if "vera_" in name: + module.to(child.weight.device) + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if self.prefix not in n: + p.requires_grad = False + + for active_adapter in self.active_adapters: + bias = self.peft_config[active_adapter].bias + if bias == "none": + continue + + if bias == "all": + for n, p in model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "vera_only": + for m in model.modules(): + if isinstance(m, VeraLayer) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") + + @staticmethod + def _create_new_module(vera_config, vera_A, vera_B, adapter_name, target, **kwargs): + bias = kwargs.pop("bias", False) + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, torch.nn.Linear): + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." 
+ ) + kwargs["fan_in_fan_out"] = vera_config.fan_in_fan_out = False + elif isinstance(target_base_layer, Conv1D): + kwargs["is_target_conv_1d_layer"] = True + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. " + "Setting fan_in_fan_out to True." + ) + kwargs["fan_in_fan_out"] = vera_config.fan_in_fan_out = True + else: + raise ValueError( + f"Target module {target} is not supported. Currently, only the following modules are supported: " + "`torch.nn.Linear`, `transformers.pytorch_utils.Conv1D`." + ) + new_module = Linear( + target, + vera_A, + vera_B, + adapter_name, + bias=bias, + d_initial=vera_config.d_initial, + **kwargs, + ) + + return new_module + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.model, name) + + def get_peft_config_as_dict(self, inference: bool = False): + config_dict = {} + for key, value in self.peft_config.items(): + config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} + if inference: + config["inference_mode"] = True + config_dict[key] = config + return config + + def _set_adapter_layers(self, enabled=True): + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self): + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self): + for active_adapter in self.active_adapters: + val = self.peft_config[active_adapter].bias + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the the base model would without adaption." + ) + warnings.warn(msg) + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name): + for module in self.model.modules(): + if isinstance(module, VeraLayer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. 
Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None, + ): + # we cannot use self.prefix as we want to include non-trainable vera parameters + key_list = [key for key, _ in self.model.named_modules() if "vera" not in key] + desc = "Unloading " + ("and merging " if merge else "") + "model" + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + + if hasattr(target, "base_layer"): + if merge: + target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + setattr(parent, target_name, target.modules_to_save[target.active_adapter]) + + return self.model + + def delete_adapter(self, adapter_name: str): + """ + Deletes an existing adapter. + + Args: + adapter_name (str): Name of the adapter to be deleted. + """ + if adapter_name not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + + # we cannot use self.prefix as we want to include non-trainable vera parameters + key_list = [key for key, _ in self.model.named_modules() if "vera" not in key] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, VeraLayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapter[:] + + self.active_adapter = new_adapter or [] + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ): + r""" + This method merges the Vera layers into the base model. This is needed if someone wants to use the base model + as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. 
+ + Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import PeftModel + + >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b") + >>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample" + >>> model = PeftModel.from_pretrained(base_model, peft_model_id) + >>> merged_model = model.merge_and_unload() + ``` + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self): + """ + Gets back the base model by removing all the Vera modules without merging. This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index d135524fc9..02b03cdd52 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -26,6 +26,7 @@ TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, CONFIG_NAME, WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, diff --git a/src/peft/utils/constants.py b/src/peft/utils/constants.py index 41047431ce..5dc7de326d 100644 --- a/src/peft/utils/constants.py +++ b/src/peft/utils/constants.py @@ -150,6 +150,41 @@ def starcoder_model_postprocess_past_key_value(past_key_values): # "layoutlm": ["query", "value"], } +TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING = { + "t5": ["q", "v"], + "mt5": ["q", "v"], + "bart": ["q_proj", "v_proj"], + "gpt2": ["c_attn"], + "bloom": ["query_key_value"], + "blip-2": ["q", "v", "q_proj", "v_proj"], + "opt": ["q_proj", "v_proj"], + "gptj": ["q_proj", "v_proj"], + "gpt_neox": ["query_key_value"], + "gpt_neo": ["q_proj", "v_proj"], + "bert": ["query", "value"], + "roberta": ["query", "value"], + "xlm-roberta": ["query", "value"], + "electra": ["query", "value"], + "deberta-v2": ["query_proj", "value_proj"], + "deberta": ["in_proj"], + "layoutlm": ["query", "value"], + "llama": ["q_proj", "v_proj"], + "chatglm": ["query_key_value"], + "gpt_bigcode": ["c_attn"], + "mpt": ["Wqkv"], + "RefinedWebModel": ["query_key_value"], + "RefinedWeb": ["query_key_value"], + "falcon": ["query_key_value"], + # "btlm": ["c_proj", "c_attn"], # tested, does not work because of different shapes + "codegen": ["qkv_proj"], + # "mistral": ["q_proj", "v_proj"], # tested, does not work because of different shapes + # "mixtral": ["q_proj", "v_proj"], # tested, does not work because of different shapes + "stablelm": ["q_proj", "v_proj"], + # "phi": ["q_proj", "v_proj", "fc1", "fc2"], # tested, does not work because of different shapes + "phi": ["q_proj", "v_proj"], + # "gemma": ["q_proj", "v_proj"], # tested, does not work because of different shapes +} + WEIGHTS_NAME = "adapter_model.bin" SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" CONFIG_NAME = "adapter_config.json" diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index df38a7a43c..503db7d9d3 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -37,6 +37,7 @@ TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, + TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, WEIGHTS_NAME, bloom_model_postprocess_past_key_value, starcoder_model_postprocess_past_key_value, @@ -52,6 +53,7 @@ "TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING", 
"TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING", + "TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING", "WEIGHTS_NAME", "INCLUDE_LINEAR_LAYERS_SHORTHAND", "bloom_model_postprocess_past_key_value", diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index a2d276c10b..32fcbb1c61 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -50,6 +50,7 @@ class PeftType(str, enum.Enum): LOKR = "LOKR" OFT = "OFT" POLY = "POLY" + VERA = "VERA" class TaskType(str, enum.Enum): diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py index 5b1662eb75..e8023fca6f 100644 --- a/src/peft/utils/save_and_load.py +++ b/src/peft/utils/save_and_load.py @@ -123,6 +123,7 @@ def get_peft_model_state_dict( elif config.peft_type == PeftType.ADAPTION_PROMPT: to_return = {k: state_dict[k] for k in state_dict if k.split(".")[-1].startswith("adaption_")} + elif config.is_prompt_learning: to_return = {} if config.peft_type == PeftType.MULTITASK_PROMPT_TUNING: @@ -135,14 +136,32 @@ def get_peft_model_state_dict( else: prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name) to_return["prompt_embeddings"] = prompt_embeddings + elif config.peft_type == PeftType.IA3: to_return = {k: state_dict[k] for k in state_dict if "ia3_" in k} + elif config.peft_type == PeftType.OFT: to_return = {k: state_dict[k] for k in state_dict if "oft_" in k} + elif config.peft_type == PeftType.POLY: to_return = {k: state_dict[k] for k in state_dict if "poly_" in k} + + elif config.peft_type == PeftType.VERA: + to_return = {k: state_dict[k] for k in state_dict if "vera_lambda_" in k} + if config.save_projection: + # TODO: adding vera_A and vera_B to `self.get_base_layer` would + # make name to match here difficult to predict. + if f"base_model.vera_A.{adapter_name}" not in state_dict: + raise ValueError( + "Model was initialised to not save vera_A and vera_B but config now specifies to save projection!" + " Set `config.save_projection` to `False`." + ) + to_return["base_model.vera_A." + adapter_name] = state_dict["base_model.vera_A." + adapter_name] + to_return["base_model.vera_B." + adapter_name] = state_dict["base_model.vera_B." + adapter_name] + else: - raise NotImplementedError + raise ValueError(f"Unknown PEFT type passed: {config.peft_type}") + if getattr(model, "modules_to_save", None) is not None: for key, value in state_dict.items(): if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save): @@ -271,6 +290,7 @@ def set_peft_model_state_dict( PeftType.OFT, PeftType.POLY, PeftType.BOFT, + PeftType.VERA, ): peft_model_state_dict = {} parameter_prefix = { @@ -282,6 +302,7 @@ def set_peft_model_state_dict( PeftType.OFT: "oft_", PeftType.POLY: "poly_", PeftType.BOFT: "boft_", + PeftType.VERA: "vera_lambda_", }[config.peft_type] for k, v in state_dict.items(): if parameter_prefix in k: @@ -294,10 +315,28 @@ def set_peft_model_state_dict( peft_model_state_dict[k] = v else: peft_model_state_dict[k] = v + if config.peft_type == PeftType.ADALORA: rank_pattern = config.rank_pattern if rank_pattern is not None: model.resize_modules_by_rank_pattern(rank_pattern, adapter_name) + elif config.peft_type == PeftType.VERA: + if config.save_projection and "base_model.vera_A" not in peft_model_state_dict: + raise ValueError( + "Specified to load vera_A and vera_B from state dictionary however they were not present!" 
+ ) + elif not config.save_projection and "base_model.vera_A" in peft_model_state_dict: + warnings.warn( + "Specified to not load vera_A and vera_B from state dictionary however they are present in state" + " dictionary! Consider using them to ensure checkpoint loading is correct on all platforms using" + " `peft_config.save_projection = True`" + ) + elif not config.save_projection: # and no vera_A in state dictionary + warnings.warn( + "Specified to not load vera_A and vera_B from state dictionary. This means we will be relying on" + " PRNG initialisation to restore these projections using `config.projection_prng_key`, which may" + " not be accurate on all system configurations." + ) elif config.is_prompt_learning or config.peft_type == PeftType.ADAPTION_PROMPT: peft_model_state_dict = state_dict else: diff --git a/tests/test_config.py b/tests/test_config.py index cea7e5efe7..a3dfa9d182 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -37,6 +37,7 @@ PromptEncoder, PromptEncoderConfig, PromptTuningConfig, + VeraConfig, ) @@ -55,6 +56,7 @@ OFTConfig, PolyConfig, BOFTConfig, + VeraConfig, ) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 9fcec72a4c..d2a007a936 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -41,6 +41,7 @@ OFTConfig, PeftModel, TaskType, + VeraConfig, get_peft_model, ) from peft.tuners.tuners_utils import BaseTunerLayer @@ -316,6 +317,24 @@ BOFTConfig, {"target_modules": ["conv2d"], "boft_block_size": 2, "boft_block_num": 0, "boft_n_butterfly_factor": 3}, ), + ######## + # VeRA # + ######## + ("Vanilla MLP 1 VeRA", "MLP", VeraConfig, {"target_modules": "lin0"}), + ("Vanilla MLP 2 VeRA", "MLP", VeraConfig, {"target_modules": ["lin0"]}), + ("Vanilla MLP 3 VeRA", "MLP", VeraConfig, {"target_modules": ["lin1"]}), + ( + "Vanilla MLP 5 VeRA", + "MLP", + VeraConfig, + {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}, + ), + ( + "Embedding + transformers Conv1D 1 VeRA", + "EmbConv1D", + VeraConfig, + {"target_modules": ["conv1d"]}, + ), ] MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES = [ @@ -385,6 +404,7 @@ LoKrConfig: "lokr_", OFTConfig: "oft_", BOFTConfig: "boft_", + VeraConfig: "vera_lambda_", } @@ -2124,7 +2144,7 @@ def test_requires_grad_loha_different_targets(self): config1 = LoHaConfig(target_modules=["lin1"], inference_mode=True) peft_model.add_adapter("adapter1", config1) - # active pter is still "default" + # active adapter is still "default" self.check_requires_grad( peft_model, "base_model.model.lin0.hada_w1_a.default", @@ -2225,7 +2245,7 @@ def test_requires_grad_lokr_different_targets(self): config1 = LoKrConfig(target_modules=["lin1"], inference_mode=True) peft_model.add_adapter("adapter1", config1) - # active pter is still "default" + # active adapter is still "default" self.check_requires_grad( peft_model, "base_model.model.lin0.lokr_w1.default", @@ -2310,7 +2330,7 @@ def test_requires_grad_oft_different_targets(self): config1 = OFTConfig(target_modules=["lin1"], inference_mode=True) peft_model.add_adapter("adapter1", config1) - # active pter is still "default" + # active adapter is still "default" self.check_requires_grad( peft_model, "base_model.model.lin0.oft_r.default", @@ -2464,6 +2484,148 @@ def test_requires_grad_boft_same_targets(self): "base_model.model.lin1.boft_s.adapter1", ) + def test_requires_grad_vera_different_targets(self): + # Test two different VeRA adapters that target different modules. Most notably, ensure that vera_A and vera_B + # don't require grads. 
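+        # (vera_A and vera_B are shared, frozen random projections stored as buffers, so they should never be
+        # trainable)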
+ + # requires a model with at least 2 layers with the same shapes + class MLP2(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.relu = nn.ReLU() + self.lin0 = nn.Linear(10, 20, bias=bias) + self.lin1 = nn.Linear(20, 20, bias=bias) # lin1 and lin2 have same shape + self.lin2 = nn.Linear(20, 20, bias=bias) + self.lin3 = nn.Linear(20, 2, bias=bias) + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, X): + X = X.float() + X = self.lin0(X) + X = self.relu(X) + X = self.lin1(X) + X = self.relu(X) + X = self.lin2(X) + X = self.relu(X) + X = self.lin3(X) + X = self.sm(X) + return X + + config0 = VeraConfig(target_modules=["lin1"]) + peft_model = get_peft_model(MLP2(), config0) + + config1 = VeraConfig(target_modules=["lin2"]) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin1.vera_lambda_b.default", + "base_model.model.lin1.vera_lambda_d.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.vera_lambda_b.default", + "base_model.model.lin1.vera_lambda_d.default", + ) + + # change activate adapter to adapter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin2.vera_lambda_b.adapter1", + "base_model.model.lin2.vera_lambda_d.adapter1", + ) + + # disable all adapters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + self.check_requires_grad( + peft_model, + "base_model.model.lin2.vera_lambda_b.adapter1", + "base_model.model.lin2.vera_lambda_d.adapter1", + ) + + def test_requires_grad_vera_same_targets(self): + # Test two different VeRA adapters that target the same module. Most notably, ensure that vera_A and vera_B + # don't require grads. 
+ + # requires a model with at least 2 layers with the same shapes + class MLP2(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.relu = nn.ReLU() + self.lin0 = nn.Linear(10, 20, bias=bias) + self.lin1 = nn.Linear(20, 20, bias=bias) # lin1 and lin2 have same shape + self.lin2 = nn.Linear(20, 20, bias=bias) + self.lin3 = nn.Linear(20, 2, bias=bias) + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, X): + X = X.float() + X = self.lin0(X) + X = self.relu(X) + X = self.lin1(X) + X = self.relu(X) + X = self.lin2(X) + X = self.relu(X) + X = self.lin3(X) + X = self.sm(X) + return X + + config0 = VeraConfig(target_modules=["lin1", "lin2"]) + peft_model = get_peft_model(MLP2(), config0) + + config1 = VeraConfig(target_modules=["lin1", "lin2"]) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin1.vera_lambda_b.default", + "base_model.model.lin1.vera_lambda_d.default", + "base_model.model.lin2.vera_lambda_b.default", + "base_model.model.lin2.vera_lambda_d.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.vera_lambda_b.default", + "base_model.model.lin1.vera_lambda_d.default", + "base_model.model.lin2.vera_lambda_b.default", + "base_model.model.lin2.vera_lambda_d.default", + ) + + # change activate adapter to adapter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.vera_lambda_b.adapter1", + "base_model.model.lin1.vera_lambda_d.adapter1", + "base_model.model.lin2.vera_lambda_b.adapter1", + "base_model.model.lin2.vera_lambda_d.adapter1", + ) + + # disable all adapters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + self.check_requires_grad( + peft_model, + "base_model.model.lin1.vera_lambda_b.adapter1", + "base_model.model.lin1.vera_lambda_d.adapter1", + "base_model.model.lin2.vera_lambda_b.adapter1", + "base_model.model.lin2.vera_lambda_d.adapter1", + ) + class TestMixedAdapterBatches: torch_device = infer_device() diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index 642cc99e6a..00e0462537 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -171,6 +171,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c "lora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, ) @@ -185,6 +186,7 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): "lora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, filter_params_func=skip_boft_and_gpt2, @@ -280,6 +282,7 @@ def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, co "adalora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, filter_params_func=skip_adalora_or_boft_and_gpt2, @@ -313,6 +316,7 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c "ia3_kwargs": 
{"init_ia3_weights": [False]}, "adalora_kwargs": {"init_lora_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, filter_params_func=skip_boft_and_gpt2, diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py index 1dbb689a7f..c77f46e9e9 100644 --- a/tests/test_encoder_decoder_models.py +++ b/tests/test_encoder_decoder_models.py @@ -92,6 +92,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c "model_ids": PEFT_ENCODER_DECODER_MODELS_TO_TEST, "lora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, "task_type": "SEQ_2_SEQ_LM", }, ) @@ -170,6 +171,7 @@ def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, co "adalora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, "task_type": "SEQ_2_SEQ_LM", }, ) @@ -201,6 +203,7 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c "adalora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, "task_type": "SEQ_2_SEQ_LM", }, ) diff --git a/tests/test_feature_extraction_models.py b/tests/test_feature_extraction_models.py index a0cf85f2c2..04102655dd 100644 --- a/tests/test_feature_extraction_models.py +++ b/tests/test_feature_extraction_models.py @@ -44,12 +44,10 @@ def skip_non_prompt_tuning(test_list): def skip_deberta_lora_tests(test_list): r""" - Skip tests that are checkpointing with lora/ia3/boft tests for Deberta models (couldn't find much info on the - error) + Skip tests that are checkpointing with lora/ia3/boft/vera for Deberta models (couldn't find much info on the error) """ - return [ - test for test in test_list if not (any(k in test[0] for k in ["lora", "ia3", "boft"]) and "Deberta" in test[0]) - ] + to_skip = ["lora", "ia3", "boft", "vera"] + return [test for test in test_list if not (any(k in test[0] for k in to_skip) and "Deberta" in test[0])] def skip_deberta_pt_tests(test_list): @@ -111,6 +109,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c "lora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", }, ) @@ -162,6 +161,7 @@ def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_k "adalora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", }, ) diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 56ea47af63..44ba12fbc7 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import re + import pytest import torch from scipy import stats from torch import nn -from peft import LoraConfig, PromptTuningConfig, get_peft_model +from peft import LoraConfig, PromptTuningConfig, VeraConfig, get_peft_model from peft.utils import infer_device @@ -365,3 +367,18 @@ def test_use_prompt_tuning_init_text_raises(self): PromptTuningConfig(prompt_tuning_init="TEXT", prompt_tuning_init_text="prompt tuning init text") with pytest.raises(ValueError, match="When prompt_tuning_init='TEXT', prompt_tuning_init_text can't be None"): PromptTuningConfig(prompt_tuning_init="TEXT", tokenizer_name_or_path="t5-base") + + def test_vera_mixing_save_projection_raises(self): + # it is unclear what the right thing to do would be if some adapters save the projection weights and some don't + # so we better raise an error + + config0 = VeraConfig(target_modules="linear", init_weights=False, save_projection=True) + model = self.get_model() + model = get_peft_model(model, config0) + config1 = VeraConfig(target_modules="linear", init_weights=False, save_projection=False) + msg = re.escape( + "VeRA projection weights must be saved for all adapters or none, but got multiple different values: " + "[False, True]" + ) + with pytest.raises(ValueError, match=msg): + model.add_adapter("other", config1) diff --git a/tests/test_vera.py b/tests/test_vera.py new file mode 100644 index 0000000000..9fd3eca71c --- /dev/null +++ b/tests/test_vera.py @@ -0,0 +1,277 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This test file is for tests specific to VeRA, since VeRA has some specific challenges due to the shared weights. 
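# Editor's note — illustrative aside, not part of this patch. The tests in this new file revolve
# around the fact that VeRA's random projections vera_A / vera_B are frozen and shared across all
# targeted layers (and across adapters created with the same PRNG key), while only the per-layer
# vera_lambda_b / vera_lambda_d vectors are trained. A minimal sketch of that behaviour, using only
# the public API exercised by the tests below; the TinyMLP class is a made-up toy model:

import torch
from torch import nn
from peft import VeraConfig, get_peft_model


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        # VeRA requires all targeted layers to have the same shape
        self.lin1 = nn.Linear(20, 20)
        self.lin2 = nn.Linear(20, 20)

    def forward(self, X):
        return self.lin2(torch.relu(self.lin1(X)))


model = get_peft_model(TinyMLP(), VeraConfig(target_modules=["lin1", "lin2"]))
lin1 = model.base_model.model.lin1
lin2 = model.base_model.model.lin2

# the frozen projections are literally the same tensors in every targeted layer ...
assert lin1.vera_A["default"].data_ptr() == lin2.vera_A["default"].data_ptr()
assert not lin1.vera_A["default"].requires_grad

# ... while the scaling vectors are separate, trainable parameters per layer
assert lin1.vera_lambda_d["default"].data_ptr() != lin2.vera_lambda_d["default"].data_ptr()
assert lin1.vera_lambda_b["default"].requires_grad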
+ +import os +import re + +import pytest +import torch +from safetensors import safe_open +from torch import nn + +from peft import PeftModel, VeraConfig, get_peft_model + + +class MLP(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.relu = nn.ReLU() + self.lin0 = nn.Linear(10, 20, bias=bias) + self.lin1 = nn.Linear(20, 20, bias=bias) # lin1 and lin2 have same shape + self.lin2 = nn.Linear(20, 20, bias=bias) + self.lin3 = nn.Linear(20, 2, bias=bias) + self.sm = nn.LogSoftmax(dim=-1) + + def forward(self, X): + X = X.float() + X = self.lin0(X) + X = self.relu(X) + X = self.lin1(X) + X = self.relu(X) + X = self.lin2(X) + X = self.relu(X) + X = self.lin3(X) + X = self.sm(X) + return X + + +class TestVera: + @pytest.fixture + def mlp(self): + torch.manual_seed(0) + model = MLP() + return model + + @pytest.fixture + def mlp_same_prng(self, mlp): + torch.manual_seed(0) + + config = VeraConfig(target_modules=["lin1", "lin2"], init_weights=False) + # creates a default VeRA adapter + peft_model = get_peft_model(mlp, config) + config2 = VeraConfig(target_modules=["lin1", "lin2"], init_weights=False) + peft_model.add_adapter("other", config2) + return peft_model + + def test_multiple_adapters_same_prng_weights(self, mlp_same_prng): + # we can have multiple adapters with the same prng key, in which case the weights should be shared + assert ( + mlp_same_prng.base_model.model.lin1.vera_A["default"] + is mlp_same_prng.base_model.model.lin1.vera_A["other"] + ) + assert ( + mlp_same_prng.base_model.model.lin1.vera_B["default"] + is mlp_same_prng.base_model.model.lin1.vera_B["other"] + ) + assert ( + mlp_same_prng.base_model.model.lin2.vera_A["default"] + is mlp_same_prng.base_model.model.lin2.vera_A["other"] + ) + assert ( + mlp_same_prng.base_model.model.lin2.vera_B["default"] + is mlp_same_prng.base_model.model.lin2.vera_B["other"] + ) + + input = torch.randn(5, 10) + mlp_same_prng.set_adapter("default") + output_default = mlp_same_prng(input) + mlp_same_prng.set_adapter("other") + output_other = mlp_same_prng(input) + assert not torch.allclose(output_default, output_other, atol=1e-3, rtol=1e-3) + + def test_multiple_adapters_different_prng_raises(self): + # we cannot have multiple adapters with different prng keys + model = MLP() + config = VeraConfig(target_modules=["lin1", "lin2"], init_weights=False) + # creates a default VeRA adapter + peft_model = get_peft_model(model, config) + config2 = VeraConfig(target_modules=["lin1", "lin2"], init_weights=False, projection_prng_key=123) + + msg = ( + r"Vera PRNG initialisation key must be the same for all adapters. 
Got config.projection_prng_key=123 but "
+            r"previous config had 0"
+        )
+        with pytest.raises(ValueError, match=msg):
+            peft_model.add_adapter("other", config2)
+
+    def test_multiple_adapters_save_load_save_projection_true(self, mlp_same_prng, tmp_path):
+        # check saving and loading works with multiple adapters and saved projection weights
+        torch.manual_seed(0)
+        input = torch.randn(5, 10)
+        mlp_same_prng.set_adapter("default")
+        output_default = mlp_same_prng(input)
+        mlp_same_prng.set_adapter("other")
+        output_other = mlp_same_prng(input)
+
+        # sanity check
+        assert not torch.allclose(output_default, output_other, atol=1e-3, rtol=1e-3)
+
+        save_path = tmp_path / "vera"
+        mlp_same_prng.save_pretrained(save_path)
+        assert os.path.exists(save_path / "adapter_config.json")
+        assert os.path.exists(save_path / "other" / "adapter_config.json")
+
+        torch.manual_seed(0)
+        mlp = MLP()
+        peft_model = PeftModel.from_pretrained(mlp, save_path)
+        peft_model.load_adapter(save_path / "other", "other")
+
+        peft_model.set_adapter("default")
+        output_default_loaded = peft_model(input)
+        peft_model.set_adapter("other")
+        output_other_loaded = peft_model(input)
+
+        assert torch.allclose(output_default, output_default_loaded, atol=1e-3, rtol=1e-3)
+        assert torch.allclose(output_other, output_other_loaded, atol=1e-3, rtol=1e-3)
+
+    def test_multiple_adapters_save_load_save_projection_false(self, mlp, tmp_path):
+        # check saving and loading works with multiple adapters without saved projection weights
+        torch.manual_seed(1)
+        config = VeraConfig(target_modules=["lin1", "lin2"], init_weights=False, save_projection=False)
+        # creates a default VeRA adapter
+        peft_model = get_peft_model(mlp, config, adapter_name="first")
+        config2 = VeraConfig(target_modules=["lin1", "lin2"], init_weights=False, save_projection=False)
+        peft_model.add_adapter("second", config2)
+
+        input = torch.randn(5, 10)
+        peft_model.set_adapter("first")
+        output_first = peft_model(input)
+        peft_model.set_adapter("second")
+        output_second = peft_model(input)
+
+        # sanity check
+        assert not torch.allclose(output_first, output_second, atol=1e-3, rtol=1e-3)
+
+        save_path = tmp_path / "vera"
+        peft_model.save_pretrained(save_path)
+        assert os.path.exists(save_path / "first" / "adapter_config.json")
+        assert os.path.exists(save_path / "second" / "adapter_config.json")
+
+        torch.manual_seed(0)
+        mlp = MLP()
+        peft_model = PeftModel.from_pretrained(mlp, save_path / "first", adapter_name="first")
+        peft_model.load_adapter(save_path / "second", "second")
+
+        peft_model.set_adapter("first")
+        output_first_loaded = peft_model(input)
+        peft_model.set_adapter("second")
+        output_second_loaded = peft_model(input)
+
+        assert torch.allclose(output_first, output_first_loaded, atol=1e-3, rtol=1e-3)
+        assert torch.allclose(output_second, output_second_loaded, atol=1e-3, rtol=1e-3)
+
+    def test_multiple_adapters_save_projection_true_contains_vera_A_vera_B(self, mlp_same_prng, tmp_path):
+        # check that the state_dicts do contain the projection weights
+        save_path = tmp_path / "vera"
+        mlp_same_prng.save_pretrained(save_path)
+
+        sd_default = {}
+        with safe_open(save_path / "adapter_model.safetensors", framework="pt", device="cpu") as f:
+            for key in f.keys():
+                sd_default[key] = f.get_tensor(key)
+
+        assert any("vera_A" in key for key in sd_default)
+        assert any("vera_B" in key for key in sd_default)
+        # default rank for VeRA is 256
+        assert sd_default["base_model.vera_A"].shape == (256, 20)
+        assert sd_default["base_model.vera_B"].shape == (20, 256)
+
+        
sd_other = {} + with safe_open(save_path / "other" / "adapter_model.safetensors", framework="pt", device="cpu") as f: + for key in f.keys(): + sd_other[key] = f.get_tensor(key) + + assert any("vera_A" in key for key in sd_other) + assert any("vera_B" in key for key in sd_other) + assert sd_other["base_model.vera_A"].shape == (256, 20) + assert sd_other["base_model.vera_B"].shape == (20, 256) + + def test_multiple_adapters_save_projection_false_contains_no_vera_A_vera_B(self, mlp, tmp_path): + torch.manual_seed(1) + config = VeraConfig(target_modules=["lin1", "lin2"], init_weights=False, save_projection=False) + # creates a default VeRA adapter + peft_model = get_peft_model(mlp, config, adapter_name="first") + config2 = VeraConfig(target_modules=["lin1", "lin2"], init_weights=False, save_projection=False) + peft_model.add_adapter("second", config2) + + save_path = tmp_path / "vera" + peft_model.save_pretrained(save_path) + + sd_default = {} + with safe_open(save_path / "first" / "adapter_model.safetensors", framework="pt", device="cpu") as f: + for key in f.keys(): + sd_default[key] = f.get_tensor(key) + + assert not any("vera_A" in key for key in sd_default) + assert not any("vera_B" in key for key in sd_default) + + sd_other = {} + with safe_open(save_path / "second" / "adapter_model.safetensors", framework="pt", device="cpu") as f: + for key in f.keys(): + sd_other[key] = f.get_tensor(key) + + assert not any("vera_A" in key for key in sd_other) + assert not any("vera_B" in key for key in sd_other) + + def test_vera_A_vera_B_share_memory(self, mlp_same_prng): + vera_A = mlp_same_prng.vera_A["default"] + vera_B = mlp_same_prng.vera_B["default"] + + # these tensors should share the same data + assert vera_A.data_ptr() == mlp_same_prng.base_model.model.lin1.vera_A["default"].data_ptr() + assert vera_B.data_ptr() == mlp_same_prng.base_model.model.lin1.vera_B["default"].data_ptr() + assert vera_A.data_ptr() == mlp_same_prng.base_model.model.lin2.vera_A["default"].data_ptr() + assert vera_B.data_ptr() == mlp_same_prng.base_model.model.lin2.vera_B["default"].data_ptr() + # sanity check: these tensors shouldn't share the same data + assert vera_A.data_ptr() != vera_B.data_ptr() + + def test_vera_lambda_dont_share_memory(self, mlp_same_prng): + # sanity check: these tensors shouldn't share the same data + assert ( + mlp_same_prng.base_model.model.lin1.vera_lambda_b["default"].data_ptr() + != mlp_same_prng.base_model.model.lin1.vera_lambda_b["other"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.vera_lambda_b["default"].data_ptr() + != mlp_same_prng.base_model.model.lin2.vera_lambda_b["default"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.vera_lambda_b["other"].data_ptr() + != mlp_same_prng.base_model.model.lin2.vera_lambda_b["other"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.vera_lambda_d["default"].data_ptr() + != mlp_same_prng.base_model.model.lin1.vera_lambda_d["other"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.vera_lambda_d["default"].data_ptr() + != mlp_same_prng.base_model.model.lin2.vera_lambda_d["default"].data_ptr() + ) + assert ( + mlp_same_prng.base_model.model.lin1.vera_lambda_d["other"].data_ptr() + != mlp_same_prng.base_model.model.lin2.vera_lambda_d["other"].data_ptr() + ) + + def test_vera_different_shapes_raises(self, mlp): + # It is not possible (currently) to have vera_A and vera_B for different shapes, as they cannot be shared if + # their shapes are not identical. 
lin0 and lin1 have different shapes. + config = VeraConfig(target_modules=["lin0", "lin1"], init_weights=False) + msg = re.escape( + "Multiple target layers with different dimensions were specified. VeRA only supports a single dimension " + "size. Expected shape (20, 10), got (20, 20)." + ) + with pytest.raises(ValueError, match=msg): + get_peft_model(mlp, config) diff --git a/tests/testing_common.py b/tests/testing_common.py index 8632603083..96547bbf52 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -38,6 +38,7 @@ PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig, + VeraConfig, get_peft_model, get_peft_model_state_dict, prepare_model_for_kbit_training, @@ -83,6 +84,16 @@ { "target_modules": None, }, + # VeRA + { + "r": 8, + "target_modules": None, + "vera_dropout": 0.05, + "projection_prng_key": 0xFF, + "d_initial": 0.1, + "save_projection": True, + "bias": "none", + }, ) CLASSES_MAPPING = { @@ -93,6 +104,7 @@ "prompt_tuning": (PromptTuningConfig, CONFIG_TESTING_KWARGS[4]), "adalora": (AdaLoraConfig, CONFIG_TESTING_KWARGS[5]), "boft": (BOFTConfig, CONFIG_TESTING_KWARGS[6]), + "vera": (VeraConfig, CONFIG_TESTING_KWARGS[6]), } @@ -279,6 +291,9 @@ def _test_save_pretrained(self, model_id, config_cls, config_kwargs, safe_serial if issubclass(config_cls, IA3Config): config_kwargs = config_kwargs.copy() config_kwargs["init_ia3_weights"] = False + if issubclass(config_cls, VeraConfig): + config_kwargs = config_kwargs.copy() + config_kwargs["init_weights"] = False model = self.transformers_class.from_pretrained(model_id) config = config_cls( @@ -340,9 +355,11 @@ def _test_save_pretrained_selected_adapters(self, model_id, config_cls, config_k if issubclass(config_cls, LoraConfig): config_kwargs = config_kwargs.copy() config_kwargs["init_lora_weights"] = False - if issubclass(config_cls, IA3Config): + elif issubclass(config_cls, IA3Config): config_kwargs = config_kwargs.copy() config_kwargs["init_ia3_weights"] = False + elif hasattr(config_cls, "init_weights"): + config_kwargs["init_weights"] = False model = self.transformers_class.from_pretrained(model_id) config = config_cls( @@ -460,7 +477,7 @@ def _test_merge_layers_fp16(self, model_id, config_cls, config_kwargs): _ = model.merge_and_unload() def _test_merge_layers_nan(self, model_id, config_cls, config_kwargs): - if config_cls not in (LoraConfig, IA3Config, AdaLoraConfig, LoHaConfig, LoKrConfig): + if config_cls not in (LoraConfig, IA3Config, AdaLoraConfig, LoHaConfig, LoKrConfig, VeraConfig): # Merge layers only supported for LoRA and IA³ return if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig): @@ -495,7 +512,7 @@ def _test_merge_layers_nan(self, model_id, config_cls, config_kwargs): model = model.to(self.torch_device) for name, module in model.named_parameters(): - if "lora_A" in name or "ia3" in name or "lora_E" in name or "lora_B" in name: + if "lora_A" in name or "ia3" in name or "lora_E" in name or "lora_B" in name or "vera_lambda" in name: module.data[0] = torch.nan with pytest.raises( @@ -504,7 +521,7 @@ def _test_merge_layers_nan(self, model_id, config_cls, config_kwargs): model = model.merge_and_unload(safe_merge=True) for name, module in model.named_parameters(): - if "lora_A" in name or "ia3" in name or "lora_E" in name or "lora_B" in name: + if "lora_A" in name or "ia3" in name or "lora_E" in name or "lora_B" in name or "vera_lambda" in name: module.data[0] = torch.inf with pytest.raises( @@ -958,16 +975,9 @@ def _test_training_gradient_checkpointing(self, model_id, config_cls, 
config_kwa loss = output.sum() loss.backward() - # parameter_prefix = "ia3" if config_cls == IA3Config else "lora" - if config_cls == IA3Config: - parameter_prefix = "ia3" - elif config_cls == BOFTConfig: - parameter_prefix = "boft" - else: - parameter_prefix = "lora" for n, param in model.named_parameters(): - if parameter_prefix in n: + if model.prefix in n: assert param.grad is not None else: assert param.grad is None @@ -1018,7 +1028,15 @@ def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwar assert param.grad is not None def _test_delete_adapter(self, model_id, config_cls, config_kwargs): - supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3, PeftType.OFT, PeftType.BOFT] + supported_peft_types = [ + PeftType.LORA, + PeftType.LOHA, + PeftType.LOKR, + PeftType.IA3, + PeftType.OFT, + PeftType.BOFT, + PeftType.VERA, + ] # IA3 does not support deleting adapters yet, but it just needs to be added # AdaLora does not support multiple adapters config = config_cls( @@ -1101,7 +1119,7 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs): model = get_peft_model(model, config) model = model.to(self.torch_device) - if config.peft_type not in ("LORA", "ADALORA", "IA3", "BOFT"): + if config.peft_type not in ("LORA", "ADALORA", "IA3", "BOFT", "VERA"): with pytest.raises(AttributeError): model = model.unload() else: