Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix_mbart_tied_weights #26422

Merged
merged 2 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/transformers/models/mbart/modeling_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,11 @@ def get_encoder(self):
def get_decoder(self):
return self.decoder

def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.encoder.embed_tokens, self.get_input_embeddings())
self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())

@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
Expand Down
37 changes: 37 additions & 0 deletions tests/models/mbart/test_modeling_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,43 @@ def test_generate_fp16(self):
model.generate(input_ids, attention_mask=attention_mask)
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

def test_ensure_weights_are_shared(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()

config.tie_word_embeddings = True
model = MBartForConditionalGeneration(config)

# MBart shares four weights.
# Not an issue to not have these correctly tied for torch.load, but it is an issue for safetensors.
self.assertEqual(
len(
{
model.get_output_embeddings().weight.data_ptr(),
model.get_input_embeddings().weight.data_ptr(),
model.base_model.decoder.embed_tokens.weight.data_ptr(),
model.base_model.encoder.embed_tokens.weight.data_ptr(),
}
),
1,
)

config.tie_word_embeddings = False
model = MBartForConditionalGeneration(config)

# MBart shares four weights.
# Not an issue to not have these correctly tied for torch.load, but it is an issue for safetensors.
self.assertEqual(
len(
{
model.get_output_embeddings().weight.data_ptr(),
model.get_input_embeddings().weight.data_ptr(),
model.base_model.decoder.embed_tokens.weight.data_ptr(),
model.base_model.encoder.embed_tokens.weight.data_ptr(),
}
),
2,
)


def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
Expand Down