Skip to content

Commit

Permalink
fix SqueezeBertForMaskedLM (huggingface#8479)
Browse files Browse the repository at this point in the history
  • Loading branch information
forresti authored and fabiocapsouza committed Nov 15, 2020
1 parent f100a75 commit f6daf46
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 4 deletions.
2 changes: 1 addition & 1 deletion model_cards/squeezebert/squeezebert-uncased/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ The authors found that SqueezeBERT is 4.3x faster than `bert-base-uncased` on a
The model is pretrained using the Masked Language Model (MLM) and Sentence Order Prediction (SOP) tasks.
(Author's note: If you decide to pretrain your own model, and you prefer to train with MLM only, that should work too.)

The SqueezeBERT paper presents 2 approaches to finetuning the model:
From the SqueezeBERT paper:
> We pretrain SqueezeBERT from scratch (without distillation) using the [LAMB](https://arxiv.org/abs/1904.00962) optimizer, and we employ the hyperparameters recommended by the LAMB authors: a global batch size of 8192, a learning rate of 2.5e-3, and a warmup proportion of 0.28. Following the LAMB paper's recommendations, we pretrain for 56k steps with a maximum sequence length of 128 and then for 6k steps with a maximum sequence length of 512.
## Finetuning
Expand Down
56 changes: 53 additions & 3 deletions src/transformers/modeling_squeezebert.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,53 @@ def forward(self, hidden_states):
return pooled_output


class SqueezeBertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states


class SqueezeBertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = SqueezeBertPredictionHeadTransform(config)

# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

self.bias = nn.Parameter(torch.zeros(config.vocab_size))

# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states


class SqueezeBertOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = SqueezeBertLMPredictionHead(config)

def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores


class SqueezeBertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
Expand Down Expand Up @@ -594,16 +641,19 @@ def forward(

@add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING)
class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):

authorized_missing_keys = [r"predictions.decoder.bias"]

def __init__(self, config):
super().__init__(config)

self.transformer = SqueezeBertModel(config)
self.lm_head = nn.Linear(config.embedding_size, config.vocab_size)
self.cls = SqueezeBertOnlyMLMHead(config)

self.init_weights()

def get_output_embeddings(self):
return self.lm_head
return self.cls.predictions.decoder

@add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(
Expand Down Expand Up @@ -646,7 +696,7 @@ def forward(
)

sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
prediction_scores = self.cls(sequence_output)

masked_lm_loss = None
if labels is not None:
Expand Down

0 comments on commit f6daf46

Please sign in to comment.