s2s distillation uses AutoModelForSeqToSeqLM (huggingface#6761)
sshleifer authored and Zigur committed Oct 26, 2020
1 parent 1de7321 commit 66eb9db
Showing 2 changed files with 6 additions and 6 deletions.
11 changes: 5 additions & 6 deletions examples/seq2seq/distillation.py
@@ -10,7 +10,7 @@
 from torch.nn import functional as F
 
 from lightning_base import generic_train
-from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration
+from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
 
 
 try:
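
The import swap above replaces the hard-coded BART class with AutoModelForSeq2SeqLM. As a minimal sketch of what the auto class buys here (the checkpoint names are illustrative, not from this commit):

    from transformers import AutoModelForSeq2SeqLM

    # The auto class reads each checkpoint's config.json and dispatches to the
    # matching architecture, so one distillation script can serve BART, mBART, T5, ...
    bart = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
    t5 = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    print(type(bart).__name__)  # BartForConditionalGeneration
    print(type(t5).__name__)    # T5ForConditionalGeneration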
@@ -74,22 +74,22 @@ def sanity_check_gradients(self):
     def pre_init(self, hparams):
         self.output_dir = Path(hparams.output_dir)
         self.output_dir.mkdir(exist_ok=True)
-        teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
+        teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
         student_updates = {
             "decoder_layers": hparams.student_decoder_layers,
             "encoder_layers": hparams.student_encoder_layers,
         }
         if hparams.length_penalty != -1:
             student_updates["length_penalty"] = hparams.length_penalty
-        d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
+        d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
         e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
         hparams.d_layer_to_copy = d_layers_to_copy
         hparams.e_layer_to_copy = e_layers_to_copy
         kw = teacher.config.to_diff_dict()
         kw.update(student_updates)
         # Copy weights
-        student_cfg = BartConfig(**kw)
-        student = BartForConditionalGeneration(student_cfg)
+        student_cfg = teacher.config_class(**kw)
+        student = type(teacher)(student_cfg)
         student, _ = init_student(student, teacher)
         save_dir = self.output_dir.joinpath("student")
         self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
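
With the teacher loaded through the auto class, the student's config class and model class now come from the teacher itself instead of being hard-coded to BART. A standalone sketch of that pattern, assuming a BART-style config that exposes encoder_layers/decoder_layers (the checkpoint and layer counts below are arbitrary):

    from transformers import AutoModelForSeq2SeqLM

    teacher = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn").eval()
    kw = teacher.config.to_diff_dict()            # teacher's config as a plain dict
    kw.update({"encoder_layers": 12, "decoder_layers": 3})  # shrink the student's decoder
    student_cfg = teacher.config_class(**kw)      # e.g. BartConfig for a BART teacher
    student = type(teacher)(student_cfg)          # same class as the teacher, randomly initialized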
@@ -252,7 +252,6 @@ class BartTranslationDistiller(BartSummarizationDistiller):
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
-        assert isinstance(self.tokenizer, MBartTokenizer)
         assert hparams.src_lang is not None
         assert hparams.tgt_lang is not None
         self.dataset_kwargs["src_lang"] = hparams.src_lang
1 change: 1 addition & 0 deletions examples/seq2seq/test_seq2seq_examples.py
@@ -186,6 +186,7 @@ def test_distill_mbart(self):
             tgt_lang="ro_RO",
         )
         model = self._test_distiller_cli(updates, check_contents=False)
+        assert model.model.config.model_type == "mbart"
 
         ckpts = list(Path(model.output_dir).glob("*.ckpt"))
         self.assertEqual(1, len(ckpts))
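
The new assertion keys off model_type, the architecture-family string every transformers config records. A quick sketch of where that value comes from (checkpoint name illustrative):

    from transformers import AutoConfig

    # model_type is stored in the checkpoint's config.json; the test above uses it
    # to confirm the distiller actually built an mBART student rather than plain BART.
    cfg = AutoConfig.from_pretrained("facebook/mbart-large-en-ro")
    print(cfg.model_type)  # expected: "mbart"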