Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

s2s distillation uses AutoModelForSeqToSeqLM #6761

Merged
merged 2 commits into from
Aug 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions examples/seq2seq/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch.nn import functional as F

from lightning_base import generic_train
from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration


try:
Expand Down Expand Up @@ -74,22 +74,22 @@ def sanity_check_gradients(self):
def pre_init(self, hparams):
self.output_dir = Path(hparams.output_dir)
self.output_dir.mkdir(exist_ok=True)
teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
student_updates = {
"decoder_layers": hparams.student_decoder_layers,
"encoder_layers": hparams.student_encoder_layers,
}
if hparams.length_penalty != -1:
student_updates["length_penalty"] = hparams.length_penalty
d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
hparams.d_layer_to_copy = d_layers_to_copy
hparams.e_layer_to_copy = e_layers_to_copy
kw = teacher.config.to_diff_dict()
kw.update(student_updates)
# Copy weights
student_cfg = BartConfig(**kw)
student = BartForConditionalGeneration(student_cfg)
student_cfg = teacher.config_class(**kw)
student = type(teacher)(student_cfg)
student, _ = init_student(student, teacher)
save_dir = self.output_dir.joinpath("student")
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
Expand Down Expand Up @@ -252,7 +252,6 @@ class BartTranslationDistiller(BartSummarizationDistiller):

def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
assert isinstance(self.tokenizer, MBartTokenizer)
assert hparams.src_lang is not None
assert hparams.tgt_lang is not None
self.dataset_kwargs["src_lang"] = hparams.src_lang
Expand Down
1 change: 1 addition & 0 deletions examples/seq2seq/test_seq2seq_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def test_distill_mbart(self):
tgt_lang="ro_RO",
)
model = self._test_distiller_cli(updates, check_contents=False)
assert model.model.config.model_type == "mbart"

ckpts = list(Path(model.output_dir).glob("*.ckpt"))
self.assertEqual(1, len(ckpts))
Expand Down