Commit

Fix compatibility of adapters with HF Accelerate auto device-mapping (#678)

Adapters currently does not work correctly when `device_map="auto"` is passed to a model's `from_pretrained()`. Device auto-mapping is handled by HF Accelerate, which wraps each module's original forward method.

This PR fixes the compatibility of Adapters' post-hoc model wrapping with Accelerate's device auto-mapping, which is implemented by wrapping the forward pass.

Fixing this is required to enable quantized training of adapters (bottleneck & prefix-tuning) in #663.
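
As a rough usage sketch of the setup this fix targets (the checkpoint and adapter names below are illustrative placeholders, not part of this change):

```python
# Hypothetical usage sketch; checkpoint and adapter names are placeholders.
import adapters
from transformers import AutoModelForCausalLM

# Accelerate's auto device-mapping wraps each submodule's forward method.
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")

# Adapters' post-hoc wrapping must remain compatible with those wrapped forwards.
adapters.init(model)
model.add_adapter("bottleneck_adapter", config="seq_bn")
model.train_adapter("bottleneck_adapter")
```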
calpt authored Apr 22, 2024
1 parent a15ac45 commit 233db31
Showing 16 changed files with 73 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/contributing/adding_adapters_to_a_model.md
@@ -36,6 +36,7 @@ Now that we have discussed the purpose of every file in `src/adapters/models/<mo
- Create a new class in `src/adapters/models/<model_type>/modeling_<model_type>.py` with the name `<class>WithAdapters`. This class should derive from the corresponding mixin and HF class.
- Copy the function you want to change into this class and modify it.
- e.g., the `forward` method of the `BertSelfAttention` class must be adapted to support prefix tuning. We therefore create a class `BertSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, BertSelfAttention)`, copy the forward method into it and modify it.
- if the `forward` method of a module is copied and modified, make sure to call `adapters.utils.patch_forward()` in the module's `init_adapters()` method. This ensures adapters work correctly with the `accelerate` package.
4. **Modify MODEL_MIXIN_MAPPING**
- For each mixin whose class was not copied into `modeling_<model_type>.py`, add the mixin/class combination into `MODEL_MIXIN_MAPPING` in the file `src/adapters/models/__init__.py`.
5. **Create the adapter model:**
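
A minimal sketch of the pattern this documentation change describes, using the BERT example above (class bodies abbreviated; the real classes live in `src/adapters/models/bert/`, so treat this purely as an illustration):

```python
# Illustration only: mirrors the documented pattern, not the actual implementation.
from transformers.models.bert.modeling_bert import BertSelfAttention

from adapters.utils import patch_forward


class BertSelfAttentionAdaptersMixin:
    def init_adapters(self, model_config, adapters_config):
        # ... set up LoRA wrappers, prefix tuning, etc. ...
        # Re-bind the (possibly Accelerate-wrapped) forward to the copied method below.
        patch_forward(self)


class BertSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, BertSelfAttention):
    # Copied from BertSelfAttention.forward and modified to support prefix tuning.
    def forward(self, hidden_states, *args, **kwargs):
        ...
```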
7 changes: 6 additions & 1 deletion src/adapters/model_mixin.py
@@ -23,7 +23,7 @@
from .methods.modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock, init_shared_parameters
from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool
from .methods.prompt_tuning import PromptTuningLayer
from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc
from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc, patch_forward
from .wrappers.configuration import SUBMODEL_NAMES, init_adapters_config


@@ -1258,6 +1258,11 @@ def save_pretrained(
class ModelBaseAdaptersMixin(ModelAdaptersMixin):
add_base_adapters = True

def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=True):
super().init_adapters(model_config, adapters_config, add_prefix_tuning_pool)

patch_forward(self)

def post_embedding_forward(self, module, args, embedding_output):
if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin):
embedding_output = self.invertible_adapters_forward(embedding_output)
4 changes: 4 additions & 0 deletions src/adapters/models/albert/mixin_albert.py
@@ -7,6 +7,7 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
from ...utils import patch_forward


class AlbertAttentionAdaptersMixin:
@@ -23,6 +24,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


class AlbertEncoderLayerAdaptersMixin:
@@ -40,6 +42,8 @@ def init_adapters(self, model_config, adapters_config):

self.attention.location_key = "self"

patch_forward(self)


class AlbertModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
"""Adds adapters to the AlbertModel module."""
7 changes: 7 additions & 0 deletions src/adapters/models/bart/mixin_bart.py
@@ -14,6 +14,7 @@
InvertibleAdaptersWrapperMixin,
ModelBaseAdaptersMixin,
)
from ...utils import patch_forward


class BartAttentionAdaptersMixin:
@@ -28,6 +29,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


class BartEncoderLayerAdaptersMixin:
@@ -44,6 +46,8 @@ def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")

patch_forward(self)


class BartDecoderLayerAdaptersMixin(BartEncoderLayerAdaptersMixin):
"""Adds adapters to the BartDecoderLayer module of BART."""
@@ -65,6 +69,9 @@ class BartEncoderAdaptersMixin(InvertibleAdaptersMixin):
class BartDecoderAdaptersMixin:
"""Adds adapters to the BartDecoder module of BART."""

def init_adapters(self, model_config, adapters_config):
patch_forward(self)

def forward(
self, input_ids: torch.LongTensor = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, **kwargs
):
3 changes: 3 additions & 0 deletions src/adapters/models/beit/mixin_beit.py
@@ -6,6 +6,7 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import ModelBaseAdaptersMixin
from ...utils import patch_forward


class BeitSelfAttentionAdaptersMixin:
@@ -20,6 +21,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


class BeitIntermediateAdaptersMixin:
@@ -40,6 +42,7 @@ class BeitLayerAdaptersMixin:
def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")
patch_forward(self)


class BeitModelAdaptersMixin(ModelBaseAdaptersMixin):
4 changes: 4 additions & 0 deletions src/adapters/models/bert/mixin_bert.py
@@ -8,6 +8,7 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
from ...utils import patch_forward


logger = logging.getLogger(__name__)
@@ -25,6 +26,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


# For backwards compatibility, BertSelfOutput inherits directly from BottleneckLayer
@@ -37,6 +39,7 @@ def __init__(self):
def init_adapters(self, model_config, adapters_config):
self.location_key = "mh_adapter"
super().init_adapters(model_config, adapters_config)
patch_forward(self)


# For backwards compatibility, BertOutput inherits directly from BottleneckLayer
@@ -49,6 +52,7 @@ def __init__(self):
def init_adapters(self, model_config, adapters_config):
self.location_key = "output_adapter"
super().init_adapters(model_config, adapters_config)
patch_forward(self)


class BertLayerAdaptersMixin:
4 changes: 4 additions & 0 deletions src/adapters/models/clip/mixin_clip.py
@@ -13,6 +13,7 @@
InvertibleAdaptersWrapperMixin,
ModelBaseAdaptersMixin,
)
from ...utils import patch_forward


class CLIPAttentionAdaptersMixin:
@@ -27,6 +28,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
"self_prefix", model_config, adapters_config, add_model_type_to_key=True
)
patch_forward(self)


class CLIPEncoderLayerAdaptersMixin:
@@ -40,6 +42,8 @@ def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")

patch_forward(self)


class CLIPEncoderAdaptersMixin:
"""Adds adapters to the CLIPEncoder module of CLIP."""
2 changes: 2 additions & 0 deletions src/adapters/models/deberta/mixin_deberta.py
@@ -1,5 +1,6 @@
from ...methods.lora import LoRAMergedLinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...utils import patch_forward


class DebertaSelfAttentionAdaptersMixin:
@@ -12,3 +13,4 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)
2 changes: 2 additions & 0 deletions src/adapters/models/deberta_v2/mixin_deberta_v2.py
@@ -1,5 +1,6 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...utils import patch_forward


class DebertaV2SelfAttentionAdaptersMixin:
@@ -14,3 +15,4 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)
7 changes: 7 additions & 0 deletions src/adapters/models/distilbert/mixin_distilbert.py
@@ -6,6 +6,7 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
from ...utils import patch_forward


class DistilBertMultiHeadSelfAttentionMixin:
@@ -18,6 +19,7 @@ def init_adapters(self, model_config, adapters_config):
self.v_lin = LoRALinear.wrap(self.v_lin, "selfattn", model_config, adapters_config, attn_key="v")

self.prefix_tuning = PrefixTuningLayer("self", model_config, adapters_config)
patch_forward(self)


class DistilBertTransfomerBlockAdaptersMixin:
@@ -31,10 +33,15 @@ def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")

patch_forward(self)


class DistilBertTransformerAdaptersMixin:
"""Adds adapters to the Transformer module of DistilBert."""

def init_adapters(self, model_config, adapters_config):
patch_forward(self)

def forward(self, *args, **kwargs):
if hasattr(self, "pre_forward_fn"):
kwargs["x"] = self.pre_forward_fn(self, kwargs["x"])
5 changes: 5 additions & 0 deletions src/adapters/models/gpt2/mixin_gpt2.py
@@ -6,6 +6,7 @@
from ...methods.lora import LoRALinear, LoRAMergedLinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
from ...utils import patch_forward


class GPT2AttentionAdaptersMixin:
@@ -26,6 +27,8 @@ def init_adapters(self, model_config, adapters_config):
location_key = "cross_prefix" if self.is_cross_attention else "self_prefix"
self.prefix_tuning = PrefixTuningLayer(location_key, model_config, adapters_config)

patch_forward(self)


class GPT2DecoderBlockAdaptersMixin:
"""Adds adapters to the TransformerBlock module of DistilBert."""
@@ -52,6 +55,8 @@ def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")

patch_forward(self)


class GPT2ModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
support_prompt_tuning = False
4 changes: 4 additions & 0 deletions src/adapters/models/gptj/mixin_gptj.py
@@ -6,6 +6,7 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
from ...utils import patch_forward


class GPTJAttentionAdaptersMixin:
@@ -20,6 +21,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


class GPTJMLPAdaptersMixin:
@@ -36,6 +38,8 @@ def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")

patch_forward(self)


class GPTJModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
support_prompt_tuning = False
5 changes: 5 additions & 0 deletions src/adapters/models/llama/mixin_llama.py
@@ -6,6 +6,7 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin
from ...utils import patch_forward


class LlamaAttentionMixin:
@@ -16,6 +17,8 @@ def init_adapters(self, model_config, adapters_config):

self.prefix_tuning = PrefixTuningLayer("self_prefix", model_config, adapters_config)

patch_forward(self)


class LlamaDecoderLayerMixin:
def init_adapters(self, model_config, adapters_config):
@@ -26,6 +29,8 @@ def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
self.output_adapters = BottleneckLayer("output_adapter")

patch_forward(self)


class LlamaModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin):
support_prompt_tuning = False
5 changes: 5 additions & 0 deletions src/adapters/models/t5/mixin_t5.py
@@ -13,6 +13,7 @@
ModelBaseAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)
from ...utils import patch_forward


class T5AttentionAdaptersMixin:
@@ -27,6 +28,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


class T5SelfAttentionLayerAdaptersMixin(BottleneckLayer):
@@ -36,6 +38,7 @@ def __init__(self):
def init_adapters(self, model_config, adapters_config):
self.location_key = "mh_adapter"
super().init_adapters(model_config, adapters_config)
patch_forward(self)


class T5CrossAttentionLayerAdaptersMixin(BottleneckLayer):
@@ -46,6 +49,7 @@ def init_adapters(self, model_config, adapters_config):
self.location_key = "cross_adapter"
self.EncDecAttention.location_key = "cross"
super().init_adapters(model_config, adapters_config)
patch_forward(self)


class T5FFLayerAdaptersMixin(BottleneckLayer):
@@ -82,6 +86,7 @@ class T5StackAdaptersMixin(InvertibleAdaptersMixin):
def init_adapters(self, model_config, adapters_config):
if not self.is_decoder:
InvertibleAdaptersMixin.init_adapters(self, self.config, adapters_config)
patch_forward(self)

def post_embedding_forward(self, embedding_output):
embedding_output = self.invertible_adapters_forward(embedding_output)
5 changes: 5 additions & 0 deletions src/adapters/models/vit/mixin_vit.py
@@ -6,6 +6,7 @@
from ...methods.lora import LoRALinear
from ...methods.prefix_tuning import PrefixTuningLayer
from ...model_mixin import ModelBaseAdaptersMixin
from ...utils import patch_forward


class ViTSelfAttentionAdaptersMixin:
@@ -20,6 +21,7 @@ def init_adapters(self, model_config, adapters_config):
self.prefix_tuning = PrefixTuningLayer(
self.location_key + "_prefix" if self.location_key else None, model_config, adapters_config
)
patch_forward(self)


class ViTIntermediateAdaptersMixin:
@@ -37,13 +39,16 @@ def init_adapters(self, model_config, adapters_config):
# Wrap layers for LoRA
self.dense = LoRALinear.wrap(self.dense, "output", model_config, adapters_config)

patch_forward(self)


# Unlike BERT, self attention adapters are added to Layer module in ViT
class ViTLayerAdaptersMixin:
"""Adds adapters to the ViTSelfOutput module."""

def init_adapters(self, model_config, adapters_config):
self.attention_adapters = BottleneckLayer("mh_adapter")
patch_forward(self)


class ViTModelAdaptersMixin(ModelBaseAdaptersMixin):
9 changes: 9 additions & 0 deletions src/adapters/utils.py
@@ -864,3 +864,12 @@ def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0):
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=dim)

return attention_mask


def patch_forward(module: torch.nn.Module):
# HF Accelerate's `add_hook_to_module()` replaces a module's forward method with a wrapper
# and stores the original forward method in `_old_forward`. For this to work with Adapters' post-hoc wrapping,
# we need to explicitly re-bind the potentially overridden forward methods on adapter init.
# `add_hook_to_module()` is used, e.g., for `device_map="auto"` in `PreTrainedModel.from_pretrained()`.
if hasattr(module, "_old_forward"):
module._old_forward = module.__class__.forward.__get__(module, module.__class__)
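
As a rough illustration of the interaction this helper addresses (simplified: a plain `torch.nn.Linear` stands in for a real transformer submodule):

```python
# Simplified illustration; a plain Linear stands in for a transformer submodule.
import torch
from accelerate.hooks import ModelHook, add_hook_to_module

from adapters.utils import patch_forward

layer = torch.nn.Linear(4, 4)

# Accelerate wraps `forward` and stashes the original method in `_old_forward`.
add_hook_to_module(layer, ModelHook())
assert hasattr(layer, "_old_forward")

# Adapters later replaces the class-level forward (e.g. via a *WithAdapters class);
# patch_forward() re-binds `_old_forward` so Accelerate's wrapper dispatches to it.
patch_forward(layer)
```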
