diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py
index 786c500b59..5c4f1214f5 100644
--- a/src/transformers/adapters/model_mixin.py
+++ b/src/transformers/adapters/model_mixin.py
@@ -440,13 +440,14 @@ def freeze_model(self, freeze=True):
             param.requires_grad = not freeze
         self.model_freezed = freeze
 
-    def pre_transformer_forward(self):
+    def pre_transformer_forward(self, **kwargs):
         """
         This method should be called by every adapter-implementing model at the very beginning of the forward() method.
         """
         # some warnings if we don't use available adapters
-        if not self.active_adapters and self.has_adapters():
-            logger.warning("There are adapters available but none are passed to model.forward")
+        active_adapters = self.active_adapters or kwargs.get("adapter_names", None)
+        if not active_adapters and self.has_adapters():
+            logger.warning("There are adapters available but none are activated for the forward pass.")
 
         self.config.adapters.is_parallelized = False
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index b27b92eb5e..16dadb62aa 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -1174,7 +1174,7 @@ def forward(
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self.pre_transformer_forward()
+        self.pre_transformer_forward(**kwargs)
 
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
@@ -1714,7 +1714,7 @@ def __init__(self, config):
 
         self._init_adapter_modules()
 
     def forward(self, *args, **kwargs):
-        self.pre_transformer_forward()
+        self.pre_transformer_forward(**kwargs)
         return self.decoder(*args, **kwargs)
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 986d79b169..6311e4e470 100644
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -939,7 +939,7 @@ def forward(
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self.pre_transformer_forward()
+        self.pre_transformer_forward(**kwargs)
 
         if self.config.is_decoder:
             use_cache = use_cache if use_cache is not None else self.config.use_cache
diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index dde4324de9..e2e49b382f 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -483,7 +483,7 @@ def forward(
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self.pre_transformer_forward()
+        self.pre_transformer_forward(**kwargs)
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 29b69166b5..337a05669c 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -627,7 +627,7 @@ def forward(
         return_dict=None,
         **kwargs
     ):
-        self.pre_transformer_forward()
+        self.pre_transformer_forward(**kwargs)
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index d7e4abd690..1399e628d3 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -1182,7 +1182,7 @@ def forward(
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self.pre_transformer_forward()
+        self.pre_transformer_forward(**kwargs)
 
         # different to other models, MBart automatically creates decoder_input_ids from
         # input_ids if no decoder_input_ids are provided
@@ -1726,7 +1726,7 @@ def __init__(self, config):
 
         self._init_adapter_modules()
 
     def forward(self, *args, **kwargs):
-        self.pre_transformer_forward()
+        self.pre_transformer_forward(**kwargs)
         return self.decoder(*args, **kwargs)
diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py
index 2635738d9b..bfe7e34cd7 100644
--- a/src/transformers/models/roberta/modeling_roberta.py
+++ b/src/transformers/models/roberta/modeling_roberta.py
@@ -782,7 +782,7 @@ def forward(
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        self.pre_transformer_forward()
+        self.pre_transformer_forward(**kwargs)
 
         if self.config.is_decoder:
             use_cache = use_cache if use_cache is not None else self.config.use_cache
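
The snippet below is a minimal standalone sketch (not library code) of the warning check introduced in the model_mixin.py hunk: adapters activated on the model still take precedence, and adapter names passed per call via the adapter_names forward kwarg now also count as activated. The helper name should_warn_about_unused_adapters is illustrative; active_adapters, has_adapters, and adapter_names mirror the names used in the diff.

# Standalone sketch of the new check in pre_transformer_forward().
def should_warn_about_unused_adapters(active_adapters, has_adapters, **kwargs):
    # Fall back to adapter_names from the forward kwargs, as in the patched code.
    active = active_adapters or kwargs.get("adapter_names", None)
    return bool(has_adapters and not active)

assert should_warn_about_unused_adapters([], True)                                    # adapters exist, none activated
assert not should_warn_about_unused_adapters([], True, adapter_names=["my_adapter"])  # activated per forward call
assert not should_warn_about_unused_adapters(["my_adapter"], True)                    # activated on the model
assert not should_warn_about_unused_adapters([], False)                               # no adapters at all

In practice, a call that passes adapter_names directly to forward() should no longer log the "none are activated" warning even when active_adapters is unset on the model, which is why every patched forward() now forwards its **kwargs into pre_transformer_forward().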