Move InvertibleAdaptersMixin to separate module
(fixes docs building issue)
calpt committed Dec 15, 2021
1 parent 2ae7845 commit 6f1713a
Showing 9 changed files with 82 additions and 78 deletions.
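
For downstream code, the practical change is the mixin's internal import path. A minimal migration sketch (assuming code imported the class from its old internal location):

# Before this commit:
# from transformers.adapters.model_mixin import InvertibleAdaptersMixin
# After this commit:
from transformers.adapters.invertible_adapters_mixin import InvertibleAdaptersMixin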
10 changes: 3 additions & 7 deletions src/transformers/__init__.py
@@ -1384,6 +1384,7 @@
         "StaticAdapterFusionConfig",
     ]
     _import_structure["adapters.heads"] = ["ModelWithFlexibleHeadsAdaptersMixin"]
+    _import_structure["adapters.invertible_adapters_mixin"] = ["InvertibleAdaptersMixin"]
     _import_structure["adapters.layer"] = ["AdapterLayerBaseMixin"]
     _import_structure["adapters.loading"] = [
         "AdapterFusionLoader",
@@ -1393,7 +1394,6 @@
         "WeightsLoaderHelper",
     ]
     _import_structure["adapters.model_mixin"] = [
-        "InvertibleAdaptersMixin",
         "ModelAdaptersMixin",
         "ModelConfigAdaptersMixin",
         "ModelWithHeadsAdaptersMixin",
@@ -3170,6 +3170,7 @@
         StaticAdapterFusionConfig,
     )
     from .adapters.heads import ModelWithFlexibleHeadsAdaptersMixin
+    from .adapters.invertible_adapters_mixin import InvertibleAdaptersMixin
     from .adapters.layer import AdapterLayerBaseMixin
     from .adapters.loading import (
         AdapterFusionLoader,
@@ -3178,12 +3179,7 @@
         WeightsLoader,
         WeightsLoaderHelper,
     )
-    from .adapters.model_mixin import (
-        InvertibleAdaptersMixin,
-        ModelAdaptersMixin,
-        ModelConfigAdaptersMixin,
-        ModelWithHeadsAdaptersMixin,
-    )
+    from .adapters.model_mixin import ModelAdaptersMixin, ModelConfigAdaptersMixin, ModelWithHeadsAdaptersMixin
     from .adapters.trainer import AdapterTrainer, Seq2SeqAdapterTrainer
     from .adapters.training import AdapterArguments, MultiLingAdapterArguments
     from .adapters.utils import (
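
Note that the public, top-level import path is unaffected: the class stays registered in _import_structure (now under the new module key), so lazy top-level imports keep resolving as before:

# Unaffected by this commit; resolved lazily via _import_structure:
from transformers import InvertibleAdaptersMixin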
67 changes: 67 additions & 0 deletions src/transformers/adapters/invertible_adapters_mixin.py
@@ -0,0 +1,67 @@
+from torch import nn
+
+from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock
+
+
+class InvertibleAdaptersMixin:
+    """Mixin for Transformer models adding invertible adapters."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.invertible_adapters = nn.ModuleDict(dict())
+
+    def add_invertible_adapter(self, adapter_name: str):
+        """
+        Adds an invertible adapter module for the adapter with the given name. If the given adapter does not specify an
+        invertible adapter config, this method does nothing.
+
+        Args:
+            adapter_name (str): The name of the adapter for which to add an invertible adapter module.
+        """
+        if adapter_name in self.invertible_adapters:
+            raise ValueError(f"Model already contains an adapter module for '{adapter_name}'.")
+        adapter_config = self.config.adapters.get(adapter_name)
+        if adapter_config and adapter_config["inv_adapter"]:
+            if adapter_config["inv_adapter"] == "nice":
+                inv_adap = NICECouplingBlock(
+                    [[self.config.hidden_size]],
+                    non_linearity=adapter_config["non_linearity"],
+                    reduction_factor=adapter_config["inv_adapter_reduction_factor"],
+                )
+            elif adapter_config["inv_adapter"] == "glow":
+                inv_adap = GLOWCouplingBlock(
+                    [[self.config.hidden_size]],
+                    non_linearity=adapter_config["non_linearity"],
+                    reduction_factor=adapter_config["inv_adapter_reduction_factor"],
+                )
+            else:
+                raise ValueError(f"Invalid invertible adapter type '{adapter_config['inv_adapter']}'.")
+            self.invertible_adapters[adapter_name] = inv_adap
+            self.invertible_adapters[adapter_name].apply(Adapter.init_bert_weights)
+
+    def delete_invertible_adapter(self, adapter_name: str):
+        if adapter_name in self.invertible_adapters:
+            del self.invertible_adapters[adapter_name]
+
+    def get_invertible_adapter(self):
+        # TODO: Currently no fusion over invertible adapters, takes only very first language adapter position
+        if self.config.adapters.active_setup is not None and len(self.config.adapters.active_setup) > 0:
+            first_adapter = self.config.adapters.active_setup.first()
+            if first_adapter in self.invertible_adapters:
+                return self.invertible_adapters[first_adapter]
+        return None
+
+    def enable_invertible_adapters(self, adapter_names):
+        for adapter_name in adapter_names:
+            if adapter_name in self.invertible_adapters:
+                for param in self.invertible_adapters[adapter_name].parameters():
+                    param.requires_grad = True
+
+    def invertible_adapters_forward(self, hidden_states, rev=False):
+        # TODO: Currently no fusion over invertible adapters, takes only very first language adapter position
+        if self.config.adapters.active_setup is not None and len(self.config.adapters.active_setup) > 0:
+            first_adapter = self.config.adapters.active_setup.first()
+            if first_adapter in self.invertible_adapters:
+                hidden_states = self.invertible_adapters[first_adapter](hidden_states, rev=rev)
+
+        return hidden_states
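
For orientation, a brief usage sketch of the mixin's surface as driven through the adapters API. This is not part of the commit; PfeifferInvConfig (which selects the "nice" inv_adapter) and the method names below are assumed from the library's adapter API of this period:

from transformers import AutoModel, PfeifferInvConfig

model = AutoModel.from_pretrained("bert-base-uncased")
# add_adapter() is expected to call add_invertible_adapter() internally
# when the adapter config specifies an inv_adapter type.
model.add_adapter("lang_en", config=PfeifferInvConfig())
model.set_active_adapters("lang_en")
inv = model.get_invertible_adapter()  # coupling block of the first active adapter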
65 changes: 0 additions & 65 deletions src/transformers/adapters/model_mixin.py
@@ -12,77 +12,12 @@
 from .configuration import AdapterConfig, AdapterFusionConfig, ModelAdaptersConfig, get_adapter_config_hash
 from .hub_mixin import PushAdapterToHubMixin
 from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader
-from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock
 from .utils import EMBEDDING_FILE, TOKENIZER_PATH, inherit_doc
 
 
 logger = logging.getLogger(__name__)
 
 
-class InvertibleAdaptersMixin:
-    """Mixin for Transformer models adding invertible adapters."""
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.invertible_adapters = nn.ModuleDict(dict())
-
-    def add_invertible_adapter(self, adapter_name: str):
-        """
-        Adds an invertible adapter module for the adapter with the given name. If the given adapter does not specify an
-        invertible adapter config, this method does nothing.
-
-        Args:
-            adapter_name (str): The name of the adapter for which to add an invertible adapter module.
-        """
-        if adapter_name in self.invertible_adapters:
-            raise ValueError(f"Model already contains an adapter module for '{adapter_name}'.")
-        adapter_config = self.config.adapters.get(adapter_name)
-        if adapter_config and adapter_config["inv_adapter"]:
-            if adapter_config["inv_adapter"] == "nice":
-                inv_adap = NICECouplingBlock(
-                    [[self.config.hidden_size]],
-                    non_linearity=adapter_config["non_linearity"],
-                    reduction_factor=adapter_config["inv_adapter_reduction_factor"],
-                )
-            elif adapter_config["inv_adapter"] == "glow":
-                inv_adap = GLOWCouplingBlock(
-                    [[self.config.hidden_size]],
-                    non_linearity=adapter_config["non_linearity"],
-                    reduction_factor=adapter_config["inv_adapter_reduction_factor"],
-                )
-            else:
-                raise ValueError(f"Invalid invertible adapter type '{adapter_config['inv_adapter']}'.")
-            self.invertible_adapters[adapter_name] = inv_adap
-            self.invertible_adapters[adapter_name].apply(Adapter.init_bert_weights)
-
-    def delete_invertible_adapter(self, adapter_name: str):
-        if adapter_name in self.invertible_adapters:
-            del self.invertible_adapters[adapter_name]
-
-    def get_invertible_adapter(self):
-        # TODO: Currently no fusion over invertible adapters, takes only very first language adapter position
-        if self.config.adapters.active_setup is not None and len(self.config.adapters.active_setup) > 0:
-            first_adapter = self.config.adapters.active_setup.first()
-            if first_adapter in self.invertible_adapters:
-                return self.invertible_adapters[first_adapter]
-        return None
-
-    def enable_invertible_adapters(self, adapter_names):
-        for adapter_name in adapter_names:
-            if adapter_name in self.invertible_adapters:
-                for param in self.invertible_adapters[adapter_name].parameters():
-                    param.requires_grad = True
-
-    def invertible_adapters_forward(self, hidden_states, rev=False):
-        # TODO: Currently no fusion over invertible adapters, takes only very first language adapter position
-        if self.config.adapters.active_setup is not None and len(self.config.adapters.active_setup) > 0:
-            first_adapter = self.config.adapters.active_setup.first()
-            if first_adapter in self.invertible_adapters:
-                hidden_states = self.invertible_adapters[first_adapter](hidden_states, rev=rev)
-
-        return hidden_states
-
-
 class ModelConfigAdaptersMixin(ABC):
     """
     Mixin for model config classes, adding support for adapters.
3 changes: 2 additions & 1 deletion src/transformers/adapters/models/bert.py
@@ -15,8 +15,9 @@
     QuestionAnsweringHead,
     TaggingHead,
 )
+from ..invertible_adapters_mixin import InvertibleAdaptersMixin
 from ..layer import AdapterLayerBaseMixin
-from ..model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin
+from ..model_mixin import ModelAdaptersMixin
 
 
 logger = logging.getLogger(__name__)
3 changes: 2 additions & 1 deletion src/transformers/adapters/models/distilbert.py
@@ -4,7 +4,8 @@
 from torch import nn
 
 from ..composition import AdapterCompositionBlock, parse_composition
-from ..model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin
+from ..invertible_adapters_mixin import InvertibleAdaptersMixin
+from ..model_mixin import ModelAdaptersMixin
 from .bert import BertEncoderAdaptersMixin, BertModelHeadsMixin, BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin
 
 
3 changes: 2 additions & 1 deletion src/transformers/adapters/models/gpt2.py
@@ -5,7 +5,8 @@
 
 from ..composition import AdapterCompositionBlock, parse_composition
 from ..heads import CausalLMHead, ClassificationHead, MultiLabelClassificationHead, TaggingHead
-from ..model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin
+from ..invertible_adapters_mixin import InvertibleAdaptersMixin
+from ..model_mixin import ModelAdaptersMixin
 from .bert import (
     BertEncoderAdaptersMixin,
     BertOutputAdaptersMixin,
3 changes: 2 additions & 1 deletion src/transformers/models/bart/modeling_bart.py
@@ -25,7 +25,8 @@
 from torch.nn import CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin
+from ...adapters.invertible_adapters_mixin import InvertibleAdaptersMixin
+from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin
 from ...adapters.models.bart import (
     BartDecoderLayerAdaptersMixin,
     BartEncoderDecoderAdaptersMixin,
3 changes: 2 additions & 1 deletion src/transformers/models/mbart/modeling_mbart.py
@@ -24,7 +24,8 @@
 from torch.nn import CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin
+from ...adapters.invertible_adapters_mixin import InvertibleAdaptersMixin
+from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin
 from ...adapters.models.bart import (
     BartDecoderLayerAdaptersMixin,
     BartEncoderDecoderAdaptersMixin,
3 changes: 2 additions & 1 deletion src/transformers/models/t5/modeling_t5.py
@@ -25,7 +25,8 @@
 from torch.utils.checkpoint import checkpoint
 
 from ...activations import ACT2FN
-from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin
+from ...adapters.invertible_adapters_mixin import InvertibleAdaptersMixin
+from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin
 from ...adapters.models.t5 import (
     T5BlockAdaptersMixin,
     T5CrossAttentionLayerAdaptersMixin,
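
Finally, a self-contained toy sketch (an assumption about how the consuming models use the mixin, not code from this commit) of where invertible_adapters_forward hooks in: directly after the embedding layer, with rev=True available to undo the transformation again before an output head:

import torch
from torch import nn


class ToyInvertibleMixin:
    # Stand-in with the same call surface as InvertibleAdaptersMixin;
    # the real mixin applies a NICE/GLOW coupling block instead of identity.
    def invertible_adapters_forward(self, hidden_states, rev=False):
        return hidden_states


class ToyEncoder(ToyInvertibleMixin, nn.Module):
    def __init__(self, vocab_size=100, hidden_size=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids):
        hidden_states = self.embed(input_ids)
        # Hook point: invertible adapter right after the embeddings.
        return self.invertible_adapters_forward(hidden_states)


print(ToyEncoder()(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 16])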
