Commit 086698e: Style

LysandreJik committed Feb 19, 2021
1 parent 24ced6a

Showing 2 changed files with 11 additions and 6 deletions.
8 changes: 3 additions & 5 deletions src/transformers/models/deberta/modeling_deberta.py
@@ -360,9 +360,7 @@ def __init__(self, config):
                 self.max_relative_positions = config.max_position_embeddings
             self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size)
 
-    def get_rel_embedding(self):
-        rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
-        return rel_embeddings
+        self.rel_embeddings_weights = self.rel_embeddings.weight if self.relative_attention else None
 
     def get_attention_mask(self, attention_mask):
         if attention_mask.dim() <= 2:
@@ -400,7 +398,7 @@ def forward(
             next_kv = hidden_states[0]
         else:
             next_kv = hidden_states
-        rel_embeddings = self.get_rel_embedding()
+        rel_embeddings = self.rel_embeddings_weights
         for i, layer_module in enumerate(self.layer):
 
             if output_hidden_states:
@@ -946,7 +944,7 @@ def forward(
             hidden_states = encoded_layers[-2]
             layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
             query_states = encoded_layers[-1]
-            rel_embeddings = self.encoder.get_rel_embedding()
+            rel_embeddings = self.encoder.rel_embeddings_weights
             attention_mask = self.encoder.get_attention_mask(attention_mask)
             rel_pos = self.encoder.get_rel_pos(embedding_output)
             for layer in layers[1:]:
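
For context, the modeling_deberta.py change above replaces a small getter with an attribute that is set once when the encoder is constructed and then read directly in forward. A minimal sketch of that pattern outside of DeBERTa (the class names and sizes here are invented for illustration, not part of the commit):

from torch import nn


class GetterVariant(nn.Module):
    # Pattern before the change: the weight reference is looked up on every call.
    def __init__(self, relative_attention=True, max_positions=512, hidden_size=16):
        super().__init__()
        self.relative_attention = relative_attention
        self.rel_embeddings = nn.Embedding(max_positions * 2, hidden_size)

    def get_rel_embedding(self):
        return self.rel_embeddings.weight if self.relative_attention else None


class AttributeVariant(nn.Module):
    # Pattern after the change: the same reference is stored once at construction
    # time and read as a plain attribute wherever it is needed.
    def __init__(self, relative_attention=True, max_positions=512, hidden_size=16):
        super().__init__()
        self.relative_attention = relative_attention
        self.rel_embeddings = nn.Embedding(max_positions * 2, hidden_size)
        self.rel_embeddings_weights = self.rel_embeddings.weight if self.relative_attention else None


module = AttributeVariant()
print(module.rel_embeddings_weights.shape)  # torch.Size([1024, 16])
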
9 changes: 8 additions & 1 deletion src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -890,13 +890,20 @@ def _init_weights(self, module):
             module.bias.data.zero_()
 
     def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        """
+        Removes the classifier if it doesn't have the correct number of labels.
+        """
         self_state = self.state_dict()
         if (
             ("classifier.weight" in self_state)
             and ("classifier.weight" in state_dict)
             and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size()
         ):
-            logger.warning("Ignore mismatched classifer head.")
+            logger.warning(
+                f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model "
+                f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint "
+                f"weights. You should train your model on new data."
+            )
             del state_dict["classifier.weight"]
             if "classifier.bias" in state_dict:
                 del state_dict["classifier.bias"]
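
The _pre_load_hook shown above runs before load_state_dict copies any weights, so a classifier head whose shape disagrees with the checkpoint can be dropped instead of triggering a size-mismatch error. Below is a minimal, self-contained sketch of the same idea, not the transformers implementation: it assumes the hook is registered through PyTorch's nn.Module._register_load_state_dict_pre_hook, which passes exactly this argument list, and the TinyClassifier class, sizes, and label counts are invented for illustration.

from torch import nn


class TinyClassifier(nn.Module):
    def __init__(self, hidden_size=8, num_labels=2):
        super().__init__()
        self.encoder = nn.Linear(hidden_size, hidden_size)
        self.classifier = nn.Linear(hidden_size, num_labels)
        # Run the hook before load_state_dict copies any tensors into this module.
        self._register_load_state_dict_pre_hook(self._pre_load_hook)

    def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        self_state = self.state_dict()
        key = prefix + "classifier.weight"
        if key in state_dict and state_dict[key].size() != self_state["classifier.weight"].size():
            # Shapes disagree: drop the checkpoint head so the remaining weights still load.
            del state_dict[key]
            state_dict.pop(prefix + "classifier.bias", None)


# Usage: load a 3-label checkpoint into a 2-label model; the mismatched head is skipped.
checkpoint = TinyClassifier(num_labels=3).state_dict()
model = TinyClassifier(num_labels=2)
model.load_state_dict(checkpoint, strict=False)
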
