From 4ceaddaaf0f16b3f6c83327444992c0cc385f42a Mon Sep 17 00:00:00 2001 From: Pengcheng He <38195654+BigBird01@users.noreply.github.com> Date: Mon, 15 Feb 2021 02:13:04 -0800 Subject: [PATCH] Fix v2 model loading issue (#10129) --- .../models/auto/configuration_auto.py | 5 ++-- .../models/deberta/modeling_deberta.py | 13 ++++++++++ .../deberta_v2/configuration_deberta_v2.py | 10 ++++---- .../models/deberta_v2/modeling_deberta_v2.py | 25 ++++++++++++++----- .../deberta_v2/tokenization_deberta_v2.py | 24 +++++++++--------- 5 files changed, 51 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index afb02fc36aaa0d..a5dcff3b3aab7c 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -140,7 +140,7 @@ ("reformer", ReformerConfig), ("longformer", LongformerConfig), ("roberta", RobertaConfig), - ("deberta_v2", DebertaV2Config), + ("deberta-v2", DebertaV2Config), ("deberta", DebertaConfig), ("flaubert", FlaubertConfig), ("fsmt", FSMTConfig), @@ -202,8 +202,8 @@ ("encoder-decoder", "Encoder decoder"), ("funnel", "Funnel Transformer"), ("lxmert", "LXMERT"), + ("deberta-v2", "DeBERTa-v2"), ("deberta", "DeBERTa"), - ("deberta_v2", "DeBERTa-v2"), ("layoutlm", "LayoutLM"), ("dpr", "DPR"), ("rag", "RAG"), @@ -370,7 +370,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): {'foo': False} """ config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) - if "model_type" in config_dict: config_class = CONFIG_MAPPING[config_dict["model_type"]] return config_class.from_dict(config_dict, **kwargs) diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 60b9546379a2f1..2b2644a4b8a562 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -762,6 +762,10 @@ class DebertaPreTrainedModel(PreTrainedModel): _keys_to_ignore_on_load_missing = ["position_ids"] _keys_to_ignore_on_load_unexpected = ["position_embeddings"] + def __init__(self, config): + super().__init__(config) + self._register_load_state_dict_pre_hook(self._pre_load_hook) + def _init_weights(self, module): """ Initialize the weights """ if isinstance(module, (nn.Linear, nn.Embedding)): @@ -771,6 +775,15 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() + def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + self_state = self.state_dict() + if ('classifier.weight' in self_state) and ('classifier.weight' in state_dict) and \ + self_state['classifier.weight'].size() != state_dict['classifier.weight'].size(): + logger.warning('Ignore mismatched classifer head.') + del state_dict['classifier.weight'] + if 'classifier.bias' in state_dict: + del state_dict['classifier.bias'] DEBERTA_START_DOCSTRING = r""" The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention diff --git a/src/transformers/models/deberta_v2/configuration_deberta_v2.py b/src/transformers/models/deberta_v2/configuration_deberta_v2.py index 128a8701352266..87f439e5b50aae 100644 --- a/src/transformers/models/deberta_v2/configuration_deberta_v2.py +++ b/src/transformers/models/deberta_v2/configuration_deberta_v2.py @@ -21,10 +21,10 @@ logger = logging.get_logger(__name__) DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "microsoft/deberta-xlarge-v2": "https://huggingface.co/microsoft/deberta-xlarge-v2/resolve/main/config.json", - "microsoft/deberta-xxlarge-v2": "https://huggingface.co/microsoft/deberta-xxlarge-v2/resolve/main/config.json", - "microsoft/deberta-xlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xlarge-v2-mnli/resolve/main/config.json", - "microsoft/deberta-xxlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xxlarge-v2-mnli/resolve/main/config.json", + "microsoft/deberta-v2-xlarge": "https://huggingface.co/microsoft/deberta-v2-xlarge/resolve/main/config.json", + "microsoft/deberta-v2-xxlarge": "https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/config.json", + "microsoft/deberta-v2-xlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xlarge-mnli/resolve/main/config.json", + "microsoft/deberta-v2-xxlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli/resolve/main/config.json", } @@ -83,7 +83,7 @@ class DebertaV2Config(PretrainedConfig): layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): The epsilon used by the layer normalization layers. """ - model_type = "deberta" + model_type = "deberta-v2" def __init__( self, diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 834b542ba505cf..ca466c888fc6e3 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -40,13 +40,13 @@ _CONFIG_FOR_DOC = "DebertaV2Config" _TOKENIZER_FOR_DOC = "DebertaV2Tokenizer" -_CHECKPOINT_FOR_DOC = "microsoft/deberta-xlarge-v2" +_CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge" DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "microsoft/deberta-xlarge-v2", - "microsoft/deberta-xxlarge-v2", - "microsoft/deberta-xlarge-v2-mnli", - "microsoft/deberta-xxlarge-v2-mnli", + "microsoft/deberta-v2-xlarge", + "microsoft/deberta-v2-xxlarge", + "microsoft/deberta-v2-xlarge-mnli", + "microsoft/deberta-v2-xxlarge-mnli", ] @@ -897,10 +897,14 @@ class DebertaV2PreTrainedModel(PreTrainedModel): """ config_class = DebertaV2Config - base_model_prefix = "deberta" + base_model_prefix = "deberta-v2" _keys_to_ignore_on_load_missing = ["position_ids"] _keys_to_ignore_on_load_unexpected = ["position_embeddings"] + def __init__(self, config): + super().__init__(config) + self._register_load_state_dict_pre_hook(self._pre_load_hook) + def _init_weights(self, module): """ Initialize the weights """ if isinstance(module, (nn.Linear, nn.Embedding)): @@ -910,6 +914,15 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() + def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + self_state = self.state_dict() + if ('classifier.weight' in self_state) and ('classifier.weight' in state_dict) and \ + self_state['classifier.weight'].size() != state_dict['classifier.weight'].size(): + logger.warning('Ignore mismatched classifer head.') + del state_dict['classifier.weight'] + if 'classifier.bias' in state_dict: + del state_dict['classifier.bias'] DEBERTA_START_DOCSTRING = r""" The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention diff --git a/src/transformers/models/deberta_v2/tokenization_deberta_v2.py b/src/transformers/models/deberta_v2/tokenization_deberta_v2.py index 97b9c04150eea9..564705fe5264f2 100644 --- a/src/transformers/models/deberta_v2/tokenization_deberta_v2.py +++ b/src/transformers/models/deberta_v2/tokenization_deberta_v2.py @@ -26,25 +26,25 @@ PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { - "microsoft/deberta-xlarge-v2": "https://huggingface.co/microsoft/deberta-xlarge-v2/resolve/main/spm.model", - "microsoft/deberta-xxlarge-v2": "https://huggingface.co/microsoft/deberta-xxlarge-v2/resolve/main/spm.model", - "microsoft/deberta-xlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xlarge-v2-mnli/resolve/main/spm.model", - "microsoft/deberta-xxlarge-v2-mnli": "https://huggingface.co/microsoft/deberta-xxlarge-v2-mnli/resolve/main/spm.model", + "microsoft/deberta-v2-xlarge": "https://huggingface.co/microsoft/deberta-v2-xlarge/resolve/main/spm.model", + "microsoft/deberta-v2-xxlarge": "https://huggingface.co/microsoft/deberta-v2-xxlarge/resolve/main/spm.model", + "microsoft/deberta-v2-xlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xlarge-mnli/resolve/main/spm.model", + "microsoft/deberta-v2-xxlarge-mnli": "https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli/resolve/main/spm.model", } } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "microsoft/deberta-xlarge-v2": 512, - "microsoft/deberta-xxlarge-v2": 512, - "microsoft/deberta-xlarge-v2-mnli": 512, - "microsoft/deberta-xxlarge-v2-mnli": 512, + "microsoft/deberta-v2-xlarge": 512, + "microsoft/deberta-v2-xxlarge": 512, + "microsoft/deberta-v2-xlarge-mnli": 512, + "microsoft/deberta-v2-xxlarge-mnli": 512, } PRETRAINED_INIT_CONFIGURATION = { - "microsoft/deberta-xlarge-v2": {"do_lower_case": False}, - "microsoft/deberta-xxlarge-v2": {"do_lower_case": False}, - "microsoft/deberta-xlarge-v2-mnli": {"do_lower_case": False}, - "microsoft/deberta-xxlarge-v2-mnli": {"do_lower_case": False}, + "microsoft/deberta-v2-xlarge": {"do_lower_case": False}, + "microsoft/deberta-v2-xxlarge": {"do_lower_case": False}, + "microsoft/deberta-v2-xlarge-mnli": {"do_lower_case": False}, + "microsoft/deberta-v2-xxlarge-mnli": {"do_lower_case": False}, } VOCAB_FILES_NAMES = {"vocab_file": "spm.model"}