Pluggable Model Integration Interface #738

Draft · wants to merge 6 commits into base: main
2 changes: 2 additions & 0 deletions src/adapters/__init__.py
@@ -83,6 +83,7 @@
"Seq2SeqLMHead",
"TaggingHead",
],
"interface": ["AdapterModelInterface"],
"methods.adapter_layer_base": ["AdapterLayerBase", "ComposableAdapterLayerBase"],
"model_mixin": [
"EmbeddingAdaptersMixin",
@@ -196,6 +197,7 @@
Seq2SeqLMHead,
TaggingHead,
)
from .interface import AdapterModelInterface
from .methods.adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase
from .model_mixin import (
EmbeddingAdaptersMixin,
48 changes: 48 additions & 0 deletions src/adapters/interface.py
@@ -0,0 +1,48 @@
from dataclasses import dataclass
from typing import List


class AdapterType:
"""
Constants for the supported adapter method types.
"""

bottleneck = "bottleneck"
prefix_tuning = "prefix_tuning"
lora = "lora"
prompt_tuning = "prompt_tuning"
reft = "reft"


@dataclass
class AdapterModelInterface:
"""
Defines the main interface for integrating adapter methods into a model class.
This interface translates generic accessor names to model-specific attribute names.

Args:
adapter_types (List[str]): List of adapter types that are supported by the model.
model_embeddings (str): Name of the model's embedding layer.
model_layers (str): Name of the model's layer list.
layer_self_attn (str): Name of the self-attention layer in a transformer layer.
layer_cross_attn (str): Name of the cross-attention layer in a transformer layer.
attn_k_proj (str): Name of the key projection layer in an attention layer.
attn_q_proj (str): Name of the query projection layer in an attention layer.
attn_v_proj (str): Name of the value projection layer in an attention layer.
layer_intermediate_proj (str): Name of the intermediate projection layer in a transformer layer.
layer_output_proj (str): Name of the output projection layer in a transformer layer.
"""

adapter_types: List[str]

model_embeddings: str
model_layers: str

layer_self_attn: str
layer_cross_attn: str
attn_k_proj: str
attn_q_proj: str
attn_v_proj: str

layer_intermediate_proj: str
layer_output_proj: str
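
For orientation, a minimal sketch of how a user might describe a custom model with this interface and plug it in via `adapters.init` (the `interface` argument is added in `wrappers/model.py` below). The checkpoint name and all module attribute names are illustrative assumptions for a hypothetical Llama-style decoder-only model, not values defined in this PR:

```python
import adapters
from adapters import AdapterModelInterface
from transformers import AutoModelForCausalLM

# Hypothetical attribute names for a Llama-style decoder-only model.
interface = AdapterModelInterface(
    adapter_types=["lora", "reft"],
    model_embeddings="embed_tokens",
    model_layers="layers",
    layer_self_attn="self_attn",
    layer_cross_attn=None,  # decoder-only model: no cross-attention
    attn_k_proj="k_proj",
    attn_q_proj="q_proj",
    attn_v_proj="v_proj",
    layer_intermediate_proj="mlp.up_proj",
    layer_output_proj="mlp.down_proj",
)

model = AutoModelForCausalLM.from_pretrained("org/custom-model")  # placeholder checkpoint
adapters.init(model, interface=interface)
model.add_adapter("my_lora", config="lora")
```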
8 changes: 8 additions & 0 deletions src/adapters/methods/__init__.py
@@ -0,0 +1,8 @@
from .lora import init_lora
from .reft import init_reft


METHOD_INIT_MAPPING = {
"lora": init_lora,
"reft": init_reft,
}
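
For orientation, a condensed sketch of how this mapping is consumed; it mirrors the dispatch added to `ModelAdaptersMixin.init_adapters` in `model_mixin.py` further down:

```python
# Inside init_adapters (sketch): dispatch each interface-declared adapter type
# to its initializer, which wraps the relevant modules on the base model.
for adapter_type in model.adapter_interface.adapter_types:
    init_func = METHOD_INIT_MAPPING[adapter_type]  # e.g. "lora" -> init_lora
    init_func(model)
```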
22 changes: 22 additions & 0 deletions src/adapters/methods/lora.py
@@ -16,6 +16,7 @@

from ..composition import Average, BatchSplit, Parallel, Stack
from ..configuration import LoRAConfig, ModelAdaptersConfig
from ..utils import multigetattr, multisetattr
from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase
from .utils import dequantize_bnb_weight

@@ -803,3 +804,24 @@ def T(w):
raise ValueError(f"Invalid adapter setup. Cannot use {adapter_setup} with LoRA.")

return F.linear(x, T(self.weight), bias=self.bias)


def init_lora(model):
for _, _, attention in model.iter_attentions():
if q_proj := getattr(attention, model.adapter_interface.attn_q_proj, None):
lora_proj = LoRALinear.wrap(q_proj, "selfattn", model.config, model.adapters_config, attn_key="q")
setattr(attention, model.adapter_interface.attn_q_proj, lora_proj)
if k_proj := getattr(attention, model.adapter_interface.attn_k_proj, None):
lora_proj = LoRALinear.wrap(k_proj, "selfattn", model.config, model.adapters_config, attn_key="k")
setattr(attention, model.adapter_interface.attn_k_proj, lora_proj)
if v_proj := getattr(attention, model.adapter_interface.attn_v_proj, None):
lora_proj = LoRALinear.wrap(v_proj, "selfattn", model.config, model.adapters_config, attn_key="v")
setattr(attention, model.adapter_interface.attn_v_proj, lora_proj)

for _, layer in model.iter_layers():
if intermediate_proj := multigetattr(layer, model.adapter_interface.layer_intermediate_proj):
lora_proj = LoRALinear.wrap(intermediate_proj, "intermediate", model.config, model.adapters_config)
multisetattr(layer, model.adapter_interface.layer_intermediate_proj, lora_proj)
if output_proj := multigetattr(layer, model.adapter_interface.layer_output_proj):
lora_proj = LoRALinear.wrap(output_proj, "output", model.config, model.adapters_config)
multisetattr(layer, model.adapter_interface.layer_output_proj, lora_proj)
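
In effect, `init_lora` replaces the interface-declared attention and FFN projections in place with `LoRALinear` wrappers, so the model's unchanged forward pass now routes through LoRA-capable layers. A quick sanity check one might run after `adapters.init` (a sketch that assumes the model was initialized with an interface as sketched earlier and that its first layer exposes a query projection):

```python
# Sketch: confirm that the first self-attention query projection was wrapped.
_, _, first_attn = next(model.iter_attentions())
q_proj = getattr(first_attn, model.adapter_interface.attn_q_proj)
assert isinstance(q_proj, LoRALinear)
```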
50 changes: 45 additions & 5 deletions src/adapters/model_mixin.py
@@ -5,7 +5,7 @@
from collections import defaultdict
from copy import deepcopy
from os.path import join
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union

import torch
from torch import nn
@@ -19,15 +19,25 @@
from .configuration import ADAPTER_CONFIG_MAP, AdapterConfig, AdapterFusionConfig, BnConfig
from .context import AdapterSetup, ForwardContext
from .hub_mixin import PushAdapterToHubMixin
from .interface import AdapterModelInterface
from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader
from .methods import METHOD_INIT_MAPPING
from .methods.adapter_layer_base import AdapterLayerBase
from .methods.bottleneck import BottleneckLayer
from .methods.lora import LoRALayer
from .methods.modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters
from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool
from .methods.prompt_tuning import PromptTuningLayer
from .methods.reft import init_reft
from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc, patch_forward
from .utils import (
EMBEDDING_FILE,
TOKENIZER_PATH,
get_adapter_config_hash,
inherit_doc,
multigetattr,
multihasattr,
patch_forward,
)
from .wrappers.configuration import SUBMODEL_NAMES, init_adapters_config


@@ -418,9 +428,6 @@ def _init_adapters_submodules(self, model_config, adapters_config):
if hasattr(module, "init_adapters"):
module.init_adapters(model_config, adapters_config)

# Initialize reft modules
init_reft(self)

def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=True):
"""
This method initializes adapter modules and fusion modules from the model config.
@@ -429,6 +436,15 @@ def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=Tr

# Initialize adapters config
init_adapters_config(self, model_config, adapters_config)

# Initialize adapter types defined in interface
if self.base_model.adapter_interface is not None:
for adapter_type in self.base_model.adapter_interface.adapter_types:
init_func = METHOD_INIT_MAPPING[adapter_type]
init_func(self.base_model)
else:
init_reft(self.base_model)

# Initialize adapters in all submodules
self._init_adapters_submodules(self.config, self.adapters_config)

@@ -1450,13 +1466,37 @@ def save_pretrained(

@inherit_doc
class ModelBaseAdaptersMixin(ModelAdaptersMixin):
adapter_interface: AdapterModelInterface = None
add_base_adapters = True

def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=True):
super().init_adapters(model_config, adapters_config, add_prefix_tuning_pool)

patch_forward(self)

# Adapter Interface Methods

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(multigetattr(self, self.adapter_interface.model_layers)):
yield i, layer

def get_layer(self, idx: int) -> nn.Module:
return multigetattr(self, self.adapter_interface.model_layers)[idx]

def iter_attentions(self) -> Iterable[Tuple[int, Literal["self", "cross"], nn.Module]]:
for i, layer in self.iter_layers():
if multihasattr(layer, self.adapter_interface.layer_self_attn or ""):
yield i, "self", multigetattr(layer, self.adapter_interface.layer_self_attn)
if multihasattr(layer, self.adapter_interface.layer_cross_attn or ""):
yield i, "cross", multigetattr(layer, self.adapter_interface.layer_cross_attn)

def iter_layer_ffns(self) -> Iterable[Tuple[int, Literal["intermediate", "output"], nn.Module]]:
for i, layer in self.iter_layers():
if intermediate_proj := multigetattr(layer, self.adapter_interface.layer_intermediate_proj):
yield i, "intermediate", intermediate_proj
if output_proj := multigetattr(layer, self.adapter_interface.layer_output_proj):
yield i, "output", output_proj

def post_embedding_forward(self, module, args, embedding_output):
if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin):
embedding_output = self.invertible_adapters_forward(embedding_output)
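
The accessor methods above give adapter implementations a model-agnostic way to walk the network. A short sketch of how they might be used to inspect a model initialized with an interface (assuming `model` was set up via `adapters.init(model, interface=...)`):

```python
# Sketch: list the modules the adapter methods will hook into.
for idx, layer in model.iter_layers():
    print(f"layer {idx}: {type(layer).__name__}")

for idx, kind, attn in model.iter_attentions():
    print(f"layer {idx} ({kind}-attention): {type(attn).__name__}")

for idx, loc, proj in model.iter_layer_ffns():
    print(f"layer {idx} (FFN {loc}): {type(proj).__name__}")
```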
29 changes: 29 additions & 0 deletions src/adapters/utils.py
@@ -171,6 +171,35 @@ def inherit_doc(cls):
return cls


def multigetattr(o: object, name: str, default=None) -> Optional[object]:
for n in name.split("."):
if hasattr(o, n):
o = getattr(o, n)
else:
return default
return o


def multihasattr(o: object, name: str) -> bool:
parts = name.split(".")
for n in parts:
if hasattr(o, n):
o = getattr(o, n)
else:
return False
return True


def multisetattr(o: object, name: str, value: object):
parts = name.split(".")
for n in parts[:-1]:
if hasattr(o, n):
o = getattr(o, n)
else:
return
setattr(o, parts[-1], value)


def urljoin(*args):
return "/".join([s.strip("/") for s in args])

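
The `multigetattr`, `multihasattr`, and `multisetattr` helpers above resolve dotted attribute paths, which is what lets interface fields such as `layer_intermediate_proj` point at nested modules (e.g. `"mlp.up_proj"`). A small illustration on a toy module tree (all names are made up):

```python
import torch.nn as nn

from adapters.utils import multigetattr, multihasattr, multisetattr

block = nn.Module()
block.mlp = nn.Module()
block.mlp.up_proj = nn.Linear(8, 32)

print(multigetattr(block, "mlp.up_proj"))              # the nested Linear module
print(multihasattr(block, "mlp.gate_proj"))            # False
multisetattr(block, "mlp.up_proj", nn.Linear(8, 64))   # replaces the nested module
```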
2 changes: 2 additions & 0 deletions src/adapters/wrappers/configuration.py
@@ -94,6 +94,8 @@ def init_adapters_config(
model.adapters_config = ModelAdaptersConfig()
elif model_config.adapters is not None and not isinstance(model_config.adapters, ModelAdaptersConfig):
model.adapters_config = ModelAdaptersConfig(**model_config.adapters)
if hasattr(model, "base_model") and model.base_model is not model:
model.base_model.adapters_config = model.adapters_config

# Convert AdapterFusions from old format for backwards compatibility
fusion_models = getattr(model_config, "adapter_fusion_models", [])
64 changes: 43 additions & 21 deletions src/adapters/wrappers/model.py
@@ -9,9 +9,11 @@
from transformers.models.auto.configuration_auto import model_type_to_module_name

from ..configuration import ModelAdaptersConfig
from ..interface import AdapterModelInterface
from ..model_mixin import (
EmbeddingAdaptersWrapperMixin,
ModelAdaptersMixin,
ModelBaseAdaptersMixin,
ModelUsingSubmodelsAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)
@@ -48,30 +50,50 @@ def replace_with_adapter_class(module: nn.Module, modules_with_adapters) -> None
pass


def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig] = None) -> None:
def init(
model: PreTrainedModel,
adapters_config: Optional[ModelAdaptersConfig] = None,
interface: Optional[AdapterModelInterface] = None,
) -> None:
if isinstance(model, ModelAdaptersMixin):
return model

# First, replace original module classes with their adapters counterparts
model_name = get_module_name(model.config.model_type)
modules_with_adapters = importlib.import_module(f".{model_name}.modeling_{model_name}", "adapters.models")
submodules = list(model.modules())

# Replace the base model class
replace_with_adapter_class(submodules.pop(0), modules_with_adapters)

# Check if the base model class derives from ModelUsingSubmodelsAdaptersMixin
if isinstance(model, ModelUsingSubmodelsAdaptersMixin):
# Before initializing the submodels, make sure that adapters_config is set for the whole model.
# Otherwise, it would not be shared between the submodels.
init_adapters_config(model, model.config, adapters_config)
adapters_config = model.adapters_config
model.init_submodels()
submodules = []

# Change the class of all child modules to their adapters class
for module in submodules:
replace_with_adapter_class(module, modules_with_adapters)
if interface is not None:
base_model = model.base_model
model_class_name = base_model.__class__.__name__
model_class = type(
model_class_name,
(EmbeddingAdaptersWrapperMixin, ModelBaseAdaptersMixin, base_model.__class__),
{},
)
base_model.__class__ = model_class
base_model.adapter_interface = interface
else:
# First, replace original module classes with their adapters counterparts
model_name = get_module_name(model.config.model_type)
try:
modules_with_adapters = importlib.import_module(f".{model_name}.modeling_{model_name}", "adapters.models")
except ImportError:
raise ValueError(
f"Model {model_name} not pre-supported by adapters. Please specify and pass `interface` explicitly."
)
submodules = list(model.modules())

# Replace the base model class
replace_with_adapter_class(submodules.pop(0), modules_with_adapters)

# Check if the base model class derives from ModelUsingSubmodelsAdaptersMixin
if isinstance(model, ModelUsingSubmodelsAdaptersMixin):
# Before initializing the submodels, make sure that adapters_config is set for the whole model.
# Otherwise, it would not be shared between the submodels.
init_adapters_config(model, model.config, adapters_config)
adapters_config = model.adapters_config
model.init_submodels()
submodules = []

# Change the class of all child modules to their adapters class
for module in submodules:
replace_with_adapter_class(module, modules_with_adapters)

# Next, check if model class itself is not replaced and has an adapter-supporting base class
if not isinstance(model, ModelAdaptersMixin):
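
Taken together, `init` now has two paths: the pre-supported path that swaps in the generated `adapters.models.*` classes, and the interface path that dynamically mixes the adapter base classes into the existing model class. A sketch of both outcomes from the caller's side (the checkpoint name and all interface attribute names are placeholders):

```python
import adapters
from adapters import AdapterModelInterface
from transformers import AutoModel

model = AutoModel.from_pretrained("org/not-yet-supported-model")  # placeholder

# Without an interface, an architecture that has no prebuilt adapters
# integration now fails with an explicit ValueError instead of an ImportError:
try:
    adapters.init(model)
except ValueError as err:
    print(err)

# With an interface, the base model class is re-created with the adapter
# mixins and the interface is attached to it:
my_interface = AdapterModelInterface(
    adapter_types=["lora"],
    model_embeddings="embed_tokens",
    model_layers="layers",
    layer_self_attn="self_attn",
    layer_cross_attn=None,
    attn_k_proj="k_proj",
    attn_q_proj="q_proj",
    attn_v_proj="v_proj",
    layer_intermediate_proj="mlp.up_proj",
    layer_output_proj="mlp.down_proj",
)
adapters.init(model, interface=my_interface)
print(type(model.base_model).__bases__)      # includes ModelBaseAdaptersMixin
print(model.base_model.adapter_interface)    # the interface we passed
```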
9 changes: 6 additions & 3 deletions tests/methods/base.py
@@ -13,7 +13,7 @@
from transformers.testing_utils import require_torch, torch_device


def create_twin_models(model_class, config_creator=None):
def create_twin_models(model_class, config_creator=None, interface=None):
if config_creator and model_class.__name__.startswith("Auto"):
model_config = config_creator()
model1 = model_class.from_config(model_config)
@@ -23,7 +23,7 @@ def create_twin_models(model_class, config_creator=None):
else:
model_config = model_class.config_class()
model1 = model_class(model_config)
adapters.init(model1)
adapters.init(model1, interface=interface)
model1.eval()
# create a twin initialized with the same random weights
model2 = copy.deepcopy(model1)
@@ -186,8 +186,11 @@ def run_forward_test(self, model, adapter_config, dtype=torch.float32):
self.assertGreaterEqual(len(output_1), len(base_output))
self.assertFalse(torch.equal(output_1[0], base_output[0]))

def create_twin_models(self):
return create_twin_models(self.model_class, self.config)

def run_load_test(self, adapter_config):
model1, model2 = create_twin_models(self.model_class, self.config)
model1, model2 = self.create_twin_models()

name = "dummy_adapter"
model1.add_adapter(name, config=adapter_config)