From 086698efe18a50e863d43c2b3aeaae2e97e0ffe2 Mon Sep 17 00:00:00 2001
From: Lysandre
Date: Fri, 19 Feb 2021 17:08:39 -0500
Subject: [PATCH] Style

---
 src/transformers/models/deberta/modeling_deberta.py | 8 +++-----
 .../models/deberta_v2/modeling_deberta_v2.py        | 9 ++++++++-
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py
index d83a5959467c00..0f90ffa7b02fac 100644
--- a/src/transformers/models/deberta/modeling_deberta.py
+++ b/src/transformers/models/deberta/modeling_deberta.py
@@ -360,9 +360,7 @@ def __init__(self, config):
                 self.max_relative_positions = config.max_position_embeddings
             self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size)
 
-    def get_rel_embedding(self):
-        rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
-        return rel_embeddings
+        self.rel_embeddings_weights = self.rel_embeddings.weight if self.relative_attention else None
 
     def get_attention_mask(self, attention_mask):
         if attention_mask.dim() <= 2:
@@ -400,7 +398,7 @@ def forward(
             next_kv = hidden_states[0]
         else:
             next_kv = hidden_states
-        rel_embeddings = self.get_rel_embedding()
+        rel_embeddings = self.rel_embeddings_weights
         for i, layer_module in enumerate(self.layer):
 
             if output_hidden_states:
@@ -946,7 +944,7 @@ def forward(
             hidden_states = encoded_layers[-2]
             layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
             query_states = encoded_layers[-1]
-            rel_embeddings = self.encoder.get_rel_embedding()
+            rel_embeddings = self.encoder.rel_embeddings_weights
             attention_mask = self.encoder.get_attention_mask(attention_mask)
             rel_pos = self.encoder.get_rel_pos(embedding_output)
             for layer in layers[1:]:
diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
index 892b28daee4b39..0739c5212ff425 100644
--- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py
+++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -890,13 +890,20 @@ def _init_weights(self, module):
             module.bias.data.zero_()
 
     def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        """
+        Removes the classifier if it doesn't have the correct number of labels.
+        """
         self_state = self.state_dict()
         if (
             ("classifier.weight" in self_state)
             and ("classifier.weight" in state_dict)
             and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size()
         ):
-            logger.warning("Ignore mismatched classifer head.")
+            logger.warning(
+                f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model "
+                f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint "
+                f"weights. You should train your model on new data."
+            )
             del state_dict["classifier.weight"]
             if "classifier.bias" in state_dict:
                 del state_dict["classifier.bias"]
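
Note on the second hunk: `_pre_load_hook` is a load-state-dict pre-hook, so a classifier head whose shape disagrees with the model's is dropped from the checkpoint's state dict before the remaining weights are loaded. Below is a minimal, self-contained sketch of that pattern on a toy module; `TinyClassifier` and its layer names are made up for illustration, and registering the hook through PyTorch's private `_register_load_state_dict_pre_hook` is an assumption about how the model wires it up, not a claim about the library's exact code.

# Toy sketch (assumed names, not the library's code) of the pre-load-hook pattern:
# drop a mismatched classifier head before load_state_dict() applies the checkpoint.
from torch import nn


class TinyClassifier(nn.Module):  # hypothetical stand-in for the real model
    def __init__(self, num_labels):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.classifier = nn.Linear(8, num_labels)
        # PyTorch calls this hook with the same signature as _pre_load_hook in the patch.
        self._register_load_state_dict_pre_hook(self._pre_load_hook)

    def _pre_load_hook(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        self_state = self.state_dict()
        key = prefix + "classifier.weight"
        if (
            "classifier.weight" in self_state
            and key in state_dict
            and self_state["classifier.weight"].size() != state_dict[key].size()
        ):
            # Shapes disagree (e.g. a different number of labels): skip the head entirely.
            del state_dict[key]
            state_dict.pop(prefix + "classifier.bias", None)


# A checkpoint trained with 2 labels loaded into a 5-label model: the encoder
# weights load, while the classifier head keeps its fresh initialization.
checkpoint = TinyClassifier(num_labels=2).state_dict()
model = TinyClassifier(num_labels=5)
model.load_state_dict(checkpoint, strict=False)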