From 921c87a09f66aed3c4deed2d2800a683b6089d26 Mon Sep 17 00:00:00 2001 From: cronoik Date: Sun, 25 Apr 2021 11:45:46 +0200 Subject: [PATCH] EncoderDecoderConfigs should not create new objects (#11300) * removes the creation of separate config objects and uses the existing ones instead+overwrite resize_token_embeddings from parent class because it is not working for the EncoderDecoderModel * rollback to current version of the huggingface master branch * reworked version that ties the encoder and decoder config of the parent encoderdecoder instance * overwrite of resize_token_embeddings throws an error now * review comment suggestion Co-authored-by: Suraj Patil * implemented warning in case encoderdecoder is created with differing configs of encoderdecoderconfig and decoderconfig or encoderconfig * added test to avoid diverging configs of wrapper class and wrapped classes * Update src/transformers/models/encoder_decoder/modeling_encoder_decoder.py * make style Co-authored-by: Suraj Patil Co-authored-by: Patrick von Platen --- .../modeling_encoder_decoder.py | 21 +++++++++++ tests/test_modeling_encoder_decoder.py | 36 +++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index bcb85df33528cd..3696cf9167b18d 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -175,6 +175,21 @@ def __init__( self.encoder = encoder self.decoder = decoder + + if self.encoder.config.to_dict() != self.config.encoder.to_dict(): + logger.warning( + f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config: {self.config.encoder}" + ) + if self.decoder.config.to_dict() != self.config.decoder.to_dict(): + logger.warning( + f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config: {self.config.decoder}" + ) + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + self.encoder.config = self.config.encoder + self.decoder.config = self.config.decoder + assert ( self.encoder.get_output_embeddings() is None ), "The encoder {} should not have a LM Head. Please use a model without LM Head" @@ -458,6 +473,12 @@ def prepare_inputs_for_generation( } return input_dict + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the EncoderDecoderModel directly is not supported." + "Please use the respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or model.decoder.resize_token_embeddings(...))" + ) + def _reorder_cache(self, past, beam_idx): # apply decoder cache reordering here return self.decoder._reorder_cache(past, beam_idx) diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index 338de796b91967..26e381180d9d0d 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -34,6 +34,7 @@ import torch from transformers import ( + AutoConfig, AutoTokenizer, BartForCausalLM, BertGenerationDecoder, @@ -884,3 +885,38 @@ def get_pretrained_model(self): def test_encoder_decoder_model_shared_weights(self): pass + + +@require_torch +class EncoderDecoderModelTest(unittest.TestCase): + def get_from_encoderdecoder_pretrained_model(self): + return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased") + + def get_decoder_config(self): + config = AutoConfig.from_pretrained("bert-base-uncased") + config.is_decoder = True + config.add_cross_attention = True + return config + + def get_encoderdecoder_model(self): + return EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") + + def get_encoder_decoder_models(self): + encoder_model = BertModel.from_pretrained("bert-base-uncased") + decoder_model = BertLMHeadModel.from_pretrained("bert-base-uncased", config=self.get_decoder_config()) + return {"encoder": encoder_model, "decoder": decoder_model} + + def _check_configuration_tie(self, model): + assert id(model.decoder.config) == id(model.config.decoder) + assert id(model.encoder.config) == id(model.config.encoder) + + @slow + def test_configuration_tie(self): + model = self.get_from_encoderdecoder_pretrained_model() + self._check_configuration_tie(model) + + model = EncoderDecoderModel(**self.get_encoder_decoder_models()) + self._check_configuration_tie(model) + + model = self.get_encoderdecoder_model() + self._check_configuration_tie(model)