From de9863b50b2b7d1a1aa1c9a371902f787cc1f7a8 Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Wed, 26 Aug 2020 23:11:40 -0400
Subject: [PATCH 1/2] s2s distillation uses AutoModelForSeq2SeqLM

---
 examples/seq2seq/distillation.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py
index 9cf6cbd818ac45..67e695ef99dbd8 100644
--- a/examples/seq2seq/distillation.py
+++ b/examples/seq2seq/distillation.py
@@ -10,7 +10,7 @@
 from torch.nn import functional as F
 
 from lightning_base import generic_train
-from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration
+from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
 
 
 try:
@@ -74,22 +74,22 @@ def sanity_check_gradients(self):
     def pre_init(self, hparams):
         self.output_dir = Path(hparams.output_dir)
         self.output_dir.mkdir(exist_ok=True)
-        teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
+        teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
         student_updates = {
             "decoder_layers": hparams.student_decoder_layers,
             "encoder_layers": hparams.student_encoder_layers,
         }
         if hparams.length_penalty != -1:
             student_updates["length_penalty"] = hparams.length_penalty
-        d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
+        d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
         e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
         hparams.d_layer_to_copy = d_layers_to_copy
         hparams.e_layer_to_copy = e_layers_to_copy
         kw = teacher.config.to_diff_dict()
         kw.update(student_updates)
         # Copy weights
-        student_cfg = BartConfig(**kw)
-        student = BartForConditionalGeneration(student_cfg)
+        student_cfg = teacher.config_class(**kw)
+        student = type(teacher)(student_cfg)
         student, _ = init_student(student, teacher)
         save_dir = self.output_dir.joinpath("student")
         self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
@@ -252,7 +252,6 @@ class BartTranslationDistiller(BartSummarizationDistiller):
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
-        assert isinstance(self.tokenizer, MBartTokenizer)
         assert hparams.src_lang is not None
         assert hparams.tgt_lang is not None
         self.dataset_kwargs["src_lang"] = hparams.src_lang

From 88974cc4bb536ab0e0ec9dc4713d8a19b3320eee Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Wed, 26 Aug 2020 23:16:51 -0400
Subject: [PATCH 2/2] assert model_type

---
 examples/seq2seq/test_seq2seq_examples.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py
index 1f70cbd312ac57..2f397c7adcba08 100644
--- a/examples/seq2seq/test_seq2seq_examples.py
+++ b/examples/seq2seq/test_seq2seq_examples.py
@@ -186,6 +186,7 @@ def test_distill_mbart(self):
             tgt_lang="ro_RO",
         )
         model = self._test_distiller_cli(updates, check_contents=False)
+        assert model.model.config.model_type == "mbart"
 
         ckpts = list(Path(model.output_dir).glob("*.ckpt"))
         self.assertEqual(1, len(ckpts))
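
Note on PATCH 1/2: the point of the change is that `pre_init` no longer hard-codes `BartConfig`/`BartForConditionalGeneration`. It loads the teacher through `AutoModelForSeq2SeqLM` and then recovers the concrete classes from the instance via `teacher.config_class` and `type(teacher)`, so the same code builds a student for mbart, pegasus, etc. A minimal sketch of that pattern outside the distiller is below; the checkpoint name is only an example, and it assumes the teacher's config exposes `encoder_layers`/`decoder_layers` (true for BART-family models, but not T5, which uses `num_layers`):

```python
from transformers import AutoModelForSeq2SeqLM

# Any seq2seq checkpoint works here; "facebook/bart-large-cnn" is just an example.
teacher = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn").eval()

# Start from the teacher's non-default config values and shrink the layer counts.
kw = teacher.config.to_diff_dict()
kw.update({"encoder_layers": 6, "decoder_layers": 3})

# config_class / type(teacher) resolve to the concrete classes
# (BartConfig / BartForConditionalGeneration for this checkpoint),
# so no model family is hard-coded anywhere.
student_cfg = teacher.config_class(**kw)
student = type(teacher)(student_cfg)

print(type(student).__name__, student.config.encoder_layers, student.config.decoder_layers)
```

This is also why PATCH 2/2 can assert `model.model.config.model_type == "mbart"` in the test: the distiller now produces whatever architecture the teacher checkpoint actually is, rather than silently coercing everything to vanilla BART.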