Fix SqueezeBERT for masked language model #8479

Merged
model_cards/squeezebert/squeezebert-uncased/README.md (1 addition & 1 deletion)
@@ -22,7 +22,7 @@ The authors found that SqueezeBERT is 4.3x faster than `bert-base-uncased` on a
The model is pretrained using the Masked Language Model (MLM) and Sentence Order Prediction (SOP) tasks.
(Author's note: If you decide to pretrain your own model, and you prefer to train with MLM only, that should work too.)

-The SqueezeBERT paper presents 2 approaches to finetuning the model:
+From the SqueezeBERT paper:
> We pretrain SqueezeBERT from scratch (without distillation) using the [LAMB](https://arxiv.org/abs/1904.00962) optimizer, and we employ the hyperparameters recommended by the LAMB authors: a global batch size of 8192, a learning rate of 2.5e-3, and a warmup proportion of 0.28. Following the LAMB paper's recommendations, we pretrain for 56k steps with a maximum sequence length of 128 and then for 6k steps with a maximum sequence length of 512.
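
For quick reference, the quoted recipe amounts to the settings below (an illustrative summary of the numbers above, not part of the model card; the LAMB implementation and the training loop are assumed to come from elsewhere):

```python
# Pretraining hyperparameters as quoted from the SqueezeBERT paper (sketch only).
squeezebert_pretraining = {
    "optimizer": "LAMB",  # https://arxiv.org/abs/1904.00962
    "global_batch_size": 8192,
    "learning_rate": 2.5e-3,
    "warmup_proportion": 0.28,
    "phases": [
        {"steps": 56_000, "max_seq_length": 128},  # phase 1
        {"steps": 6_000, "max_seq_length": 512},   # phase 2
    ],
}
```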

## Finetuning
src/transformers/modeling_squeezebert.py (53 additions & 3 deletions)
@@ -373,6 +373,53 @@ def forward(self, hidden_states):
        return pooled_output


class SqueezeBertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class SqueezeBertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = SqueezeBertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class SqueezeBertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = SqueezeBertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores
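
# Illustrative sketch, not part of this diff: because `SqueezeBertLMPredictionHead`
# assigns `self.decoder.bias = self.bias`, resizing the vocabulary keeps the decoder
# weight and its bias in sync (assuming the standard `PreTrainedModel` weight-tying
# behaviour that `resize_token_embeddings` relies on):
#
#     from transformers import SqueezeBertConfig, SqueezeBertForMaskedLM
#
#     model = SqueezeBertForMaskedLM(SqueezeBertConfig())
#     model.resize_token_embeddings(30528)
#     assert model.get_output_embeddings().weight.shape[0] == 30528
#     assert model.cls.predictions.bias.shape[0] == 30528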


class SqueezeBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -594,16 +641,19 @@ def forward(

@add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING)
class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):

    authorized_missing_keys = [r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.transformer = SqueezeBertModel(config)
-        self.lm_head = nn.Linear(config.embedding_size, config.vocab_size)
+        self.cls = SqueezeBertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
-        return self.lm_head
+        return self.cls.predictions.decoder

    @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
@@ -646,7 +696,7 @@ def forward(
        )

        sequence_output = outputs[0]
-        prediction_scores = self.lm_head(sequence_output)
+        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
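
A minimal end-to-end sketch of how the repaired masked-LM head can be exercised once this change is in place; the checkpoint name follows the model card above, and the rest is the standard `transformers` tokenizer/model API rather than anything specific to this PR:

```python
import torch
from transformers import SqueezeBertForMaskedLM, SqueezeBertTokenizer

# Sketch only: fill a [MASK] token using the head added in this PR (model.cls).
tokenizer = SqueezeBertTokenizer.from_pretrained("squeezebert/squeezebert-uncased")
model = SqueezeBertForMaskedLM.from_pretrained("squeezebert/squeezebert-uncased")
model.eval()

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs, return_dict=True).logits  # produced by model.cls(sequence_output)

mask_positions = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_ids = logits[0, mask_positions].argmax(dim=-1)
print(tokenizer.decode(predicted_ids.tolist()))
```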