Fix compatibility of adapters with HF Accelerate auto device-mapping #678

Merged (2 commits) · Apr 22, 2024

Changes from all commits
1 change: 1 addition & 0 deletions docs/contributing/adding_adapters_to_a_model.md
@@ -36,6 +36,7 @@ Now that we have discussed the purpose of every file in `src/adapters/models/<mo
- Create a new class in `src/adapters/models/<model_type>/modeling_<model_type>.py` with the name `<class>WithAdapters`. This class should derive from the corresponding mixin and HF class.
- Copy the function you want to change into this class and modify it.
- e.g., the `forward` method of the `BertSelfAttention` class must be adapted to support prefix tuning. We therefore create a class `BertSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, BertSelfAttention)`, copy the forward method into it and modify it.
- If the `forward` method of a module is copied and modified, make sure to call `adapters.utils.patch_forward()` in the module's `init_adapters()` method. This ensures adapters work correctly with the `accelerate` package.
4. **Modify MODEL_MIXIN_MAPPING**
- For each mixin whose class was not copied into `modeling_<model_type>.py`, add the mixin/class combination into `MODEL_MIXIN_MAPPING` in the file `src/adapters/models/__init__.py`.
5. **Create the adapter model:**
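To illustrate the documented pattern, here is a condensed sketch (not the library's actual implementation; the adapter setup body is elided) of how a copied `forward` pairs with a `patch_forward()` call in `init_adapters()`, using the BERT example from the guide above:

```python
from transformers.models.bert.modeling_bert import BertSelfAttention

from adapters.utils import patch_forward


class BertSelfAttentionAdaptersMixin:
    """Per-module mixin sketch: set up adapter machinery, then patch the forward."""

    def init_adapters(self, model_config, adapters_config):
        # ... wrap projections with LoRA, create the prefix-tuning layer, etc. ...
        patch_forward(self)  # keep Accelerate's `_old_forward` pointing at the adapter-aware forward


class BertSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, BertSelfAttention):
    def forward(self, hidden_states, *args, **kwargs):
        # Copied from HF's `BertSelfAttention.forward` and modified,
        # e.g. to prepend prefix-tuning key/value states (elided here).
        ...
```

Without the `patch_forward()` call, a model loaded with `device_map="auto"` keeps dispatching to the pre-adapter `forward` that Accelerate stored when it attached its hooks.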
7 changes: 6 additions & 1 deletion src/adapters/model_mixin.py
@@ -23,7 +23,7 @@
from .methods.modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters
from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool
from .methods.prompt_tuning import PromptTuningLayer
- from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc
+ from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc, patch_forward
from .wrappers.configuration import SUBMODEL_NAMES, init_adapters_config


@@ -1258,6 +1258,11 @@ def save_pretrained(
class ModelBaseAdaptersMixin(ModelAdaptersMixin):
add_base_adapters = True

def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=True):
super().init_adapters(model_config, adapters_config, add_prefix_tuning_pool)

patch_forward(self)

def post_embedding_forward(self, module, args, embedding_output):
if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin):
embedding_output = self.invertible_adapters_forward(embedding_output)
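For context, the scenario this `ModelBaseAdaptersMixin.init_adapters()` hook fixes looks roughly like the following usage sketch (the adapter name and the `"seq_bn"` config are arbitrary examples):

```python
import adapters
from transformers import AutoModelForCausalLM

# device_map="auto" makes Accelerate attach device-alignment hooks, each of which
# stashes the module's original forward method in `_old_forward`.
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

# init_adapters() now calls patch_forward(), so the hook wrappers dispatch to the
# adapter-aware forward methods instead of the stale, pre-adapter ones.
adapters.init(model)
model.add_adapter("demo_adapter", config="seq_bn")  # bottleneck adapter config
model.set_active_adapters("demo_adapter")
```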
4 changes: 4 additions & 0 deletions src/adapters/models/albert/mixin_albert.py
@@ -7,6 +7,7 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
from ...utils import patch_forward


class AlbertAttentionAdaptersMixin:
@@ -23,6 +24,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


class AlbertEncoderLayerAdaptersMixin:
@@ -40,6 +42,8 @@ def init_adapters(self, model_config, adapters_config):

self.attention.location_key = "self"

patch_forward(self)


class AlbertModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
"""Adds adapters to the AlbertModel module."""
7 changes: 7 additions & 0 deletions src/adapters/models/bart/mixin_bart.py
@@ -14,6 +14,7 @@
InvertibleAdaptersWrapperMixin,
ModelBaseAdaptersMixin,
)
from ...utils import patch_forward


class BartAttentionAdaptersMixin:
@@ -28,6 +29,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


class BartEncoderLayerAdaptersMixin:
@@ -44,6 +46,8 @@ def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")

patch_forward(self)


class BartDecoderLayerAdaptersMixin(BartEncoderLayerAdaptersMixin):
"""Adds adapters to the BartDecoderLayer module of BART."""
@@ -65,6 +69,9 @@ class BartEncoderAdaptersMixin(InvertibleAdaptersMixin):
class BartDecoderAdaptersMixin:
"""Adds adapters to the BartDecoder module of BART."""

def init_adapters(self, model_config, adapters_config):
patch_forward(self)

def forward(
self, input_ids: torch.LongTensor = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, **kwargs
):
3 changes: 3 additions & 0 deletions src/adapters/models/beit/mixin_beit.py
@@ -6,6 +6,7 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import ModelBaseAdaptersMixin
from ...utils import patch_forward


class BeitSelfAttentionAdaptersMixin:
@@ -20,6 +21,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


class BeitIntermediateAdaptersMixin:
@@ -40,6 +42,7 @@ class BeitLayerAdaptersMixin:
def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")
patch_forward(self)


class BeitModelAdaptersMixin(ModelBaseAdaptersMixin):
4 changes: 4 additions & 0 deletions src/adapters/models/bert/mixin_bert.py
@@ -8,6 +8,7 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
from ...utils import patch_forward


logger = logging.getLogger(__name__)
@@ -25,6 +26,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


# For backwards compatibility, BertSelfOutput inherits directly from BottleneckLayer
@@ -37,6 +39,7 @@ def __init__(self):
def init_adapters(self, model_config, adapters_config):
self.location_key = "mh_adapter"
super().init_adapters(model_config, adapters_config)
patch_forward(self)


# For backwards compatibility, BertOutput inherits directly from BottleneckLayer
@@ -49,6 +52,7 @@ def __init__(self):
def init_adapters(self, model_config, adapters_config):
self.location_key = "output_adapter"
super().init_adapters(model_config, adapters_config)
patch_forward(self)


class BertLayerAdaptersMixin:
4 changes: 4 additions & 0 deletions src/adapters/models/clip/mixin_clip.py
@@ -13,6 +13,7 @@
InvertibleAdaptersWrapperMixin,
ModelBaseAdaptersMixin,
)
from ...utils import patch_forward


class CLIPAttentionAdaptersMixin:
@@ -27,6 +28,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
"self_prefix", model_config, adapters_config, add_model_type_to_key=True
)
patch_forward(self)


class CLIPEncoderLayerAdaptersMixin:
@@ -40,6 +42,8 @@ def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")

patch_forward(self)


class CLIPEncoderAdaptersMixin:
"""Adds adapters to the CLIPEncoder module of CLIP."""
2 changes: 2 additions & 0 deletions src/adapters/models/deberta/mixin_deberta.py
@@ -1,5 +1,6 @@
from ...methods.lora import LoRAMergedLinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...utils import patch_forward


class DebertaSelfAttentionAdaptersMixin:
@@ -12,3 +13,4 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)
2 changes: 2 additions & 0 deletions src/adapters/models/deberta_v2/mixin_deberta_v2.py
@@ -1,5 +1,6 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...utils import patch_forward


class DebertaV2SelfAttentionAdaptersMixin:
@@ -14,3 +15,4 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)
7 changes: 7 additions & 0 deletions src/adapters/models/distilbert/mixin_distilbert.py
@@ -6,6 +6,7 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
from ...utils import patch_forward


class DistilBertMultiHeadSelfAttentionMixin:
@@ -18,6 +19,7 @@ def init_adapters(self, model_config, adapters_config):
self.v_lin = LoRALinear.wrap(self.v_lin, "selfattn", model_config, adapters_config, attn_key="v")

self.prefix_tuning = PrefixTuningLayer("self", model_config, adapters_config)
patch_forward(self)


class DistilBertTransfomerBlockAdaptersMixin:
@@ -31,10 +33,15 @@ def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")

patch_forward(self)


class DistilBertTransformerAdaptersMixin:
"""Adds adapters to the Transformer module of DistilBert."""

def init_adapters(self, model_config, adapters_config):
patch_forward(self)

def forward(self, *args, **kwargs):
if hasattr(self, "pre_forward_fn"):
kwargs["x"] = self.pre_forward_fn(self, kwargs["x"])
5 changes: 5 additions & 0 deletions src/adapters/models/gpt2/mixin_gpt2.py
@@ -6,6 +6,7 @@
from ...methods.lora import LoRALinear, LoRAMergedLinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
from ...utils import patch_forward


class GPT2AttentionAdaptersMixin:
@@ -26,6 +27,8 @@ def init_adapters(self, model_config, adapters_config):
location_key = "cross_prefix" if self.is_cross_attention else "self_prefix"
self.prefix_tuning = PrefixTuningLayer(location_key, model_config, adapters_config)

patch_forward(self)


class GPT2DecoderBlockAdaptersMixin:
"""Adds adapters to the TransformerBlock module of DistilBert."""
@@ -52,6 +55,8 @@ def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")

patch_forward(self)


class GPT2ModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
support_prompt_tuning = False
4 changes: 4 additions & 0 deletions src/adapters/models/gptj/mixin_gptj.py
@@ -6,6 +6,7 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
from ...utils import patch_forward


class GPTJAttentionAdaptersMixin:
@@ -20,6 +21,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


class GPTJMLPAdaptersMixin:
@@ -36,6 +38,8 @@ def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")

patch_forward(self)


class GPTJModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
support_prompt_tuning = False
5 changes: 5 additions & 0 deletions src/adapters/models/llama/mixin_llama.py
@@ -6,6 +6,7 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
from ...utils import patch_forward


class LlamaAttentionMixin:
@@ -16,6 +17,8 @@ def init_adapters(self, model_config, adapters_config):

self.prefix_tuning = PrefixTuningLayer("self_prefix", model_config, adapters_config)

patch_forward(self)


class LlamaDecoderLayerMixin:
def init_adapters(self, model_config, adapters_config):
@@ -26,6 +29,8 @@ def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")

patch_forward(self)


class LlamaModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
support_prompt_tuning = False
5 changes: 5 additions & 0 deletions src/adapters/models/t5/mixin_t5.py
@@ -13,6 +13,7 @@
ModelBaseAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)
from ...utils import patch_forward


class T5AttentionAdaptersMixin:
@@ -27,6 +28,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


class T5SelfAttentionLayerAdaptersMixin(BottleneckLayer):
@@ -36,6 +38,7 @@ def __init__(self):
def init_adapters(self, model_config, adapters_config):
self.location_key = "mh_adapter"
super().init_adapters(model_config, adapters_config)
patch_forward(self)


class T5CrossAttentionLayerAdaptersMixin(BottleneckLayer):
@@ -46,6 +49,7 @@ def init_adapters(self, model_config, adapters_config):
self.location_key = "cross_adapter"
self.EncDecAttention.location_key = "cross"
super().init_adapters(model_config, adapters_config)
patch_forward(self)


class T5FFLayerAdaptersMixin(BottleneckLayer):
@@ -82,6 +86,7 @@ class T5StackAdaptersMixin(InvertibleAdaptersMixin):
def init_adapters(self, model_config, adapters_config):
if not self.is_decoder:
InvertibleAdaptersMixin.init_adapters(self, self.config, adapters_config)
patch_forward(self)

def post_embedding_forward(self, embedding_output):
embedding_output = self.invertible_adapters_forward(embedding_output)
5 changes: 5 additions & 0 deletions src/adapters/models/vit/mixin_vit.py
@@ -6,6 +6,7 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import ModelBaseAdaptersMixin
from ...utils import patch_forward


class ViTSelfAttentionAdaptersMixin:
@@ -20,6 +21,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


class ViTIntermediateAdaptersMixin:
@@ -37,13 +39,16 @@ def init_adapters(self, model_config, adapters_config):
# Wrap layers for LoRA
self.dense = LoRALinear.wrap(self.dense, "output", model_config, adapters_config)

patch_forward(self)


# Unlike BERT, self attention adapters are added to Layer module in ViT
class ViTLayerAdaptersMixin:
"""Adds adapters to the ViTSelfOutput module."""

def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
patch_forward(self)


class ViTModelAdaptersMixin(ModelBaseAdaptersMixin):
9 changes: 9 additions & 0 deletions src/adapters/utils.py
@@ -862,3 +862,12 @@ def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0):
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=dim)

return attention_mask


def patch_forward(module: torch.nn.Module):
# HF Accelerate's `add_hook_to_module()` replaces the module forward method with a wrapper
# and stores the original forward method in `_old_forward`. For this to work with Adapters' post-hook wrapping,
# we need to explicitly re-bind `_old_forward` to the (potentially overridden) forward method on adapter init.
# `add_hook_to_module()` is used, e.g., for `device_map="auto"` in the `PreTrainedModel.from_pretrained()` method.
if hasattr(module, "_old_forward"):
module._old_forward = module.__class__.forward.__get__(module, module.__class__)
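
The interaction can be reproduced with a toy module. A minimal sketch, assuming `accelerate` is installed and `patch_forward` is exposed as `adapters.utils.patch_forward` (`ModelHook` is Accelerate's no-op base hook):

```python
import torch
from accelerate.hooks import ModelHook, add_hook_to_module

from adapters.utils import patch_forward


class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1


m = Toy()
add_hook_to_module(m, ModelHook())  # wraps m.forward, saves the original in m._old_forward

# Simulate adapters overriding the class-level forward later on
# (as the `<class>WithAdapters` classes do for real transformer modules):
Toy.forward = lambda self, x: x + 100

x = torch.zeros(1)
print(m(x))       # tensor([1.])   -- the hook wrapper still calls the stale `_old_forward`
patch_forward(m)  # re-bind `_old_forward` to the current class-level forward
print(m(x))       # tensor([100.]) -- the override is now in effect
```

Because the wrapper installed by `add_hook_to_module()` looks up `module._old_forward` at call time, re-binding that single attribute is sufficient; no re-hooking is needed.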