[s2s] distill t5-large -> t5-small (#8376)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
sbhaktha and sshleifer authored Nov 11, 2020
1 parent a5b6823 commit 81ebd70
Showing 4 changed files with 108 additions and 67 deletions.
2 changes: 1 addition & 1 deletion examples/seq2seq/README.md
@@ -380,7 +380,7 @@ cp xsum/test* all_pl
then use `all_pl` as DATA in the command above.

#### Direct Knowledge Distillation (KD)
- In this method, we try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`.
+ In this method, we try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `SummarizationDistiller`.
This method was used to produce the `sshleifer/distilbart-xsum-12-6`, `6-6`, and `9-6` checkpoints.
You must use [`distillation.py`](./distillation.py). Note that this command initializes the student for you.
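The command itself is folded out of this hunk. As a rough, non-authoritative sketch of driving `distillation.py` programmatically for the `t5-large` -> `t5-small` case this commit enables (mirroring how `test_bash_script.py` builds the parser; `--data_dir`, `--output_dir`, `--model_name_or_path`, `--tokenizer_name`, and `--do_train` are assumed to come from the shared finetune/lightning_base parsers, not from this diff):

```python
import argparse
import os

import pytorch_lightning as pl

from distillation import SummarizationDistiller, distill_main

# Flags visible in add_distill_args: --teacher, --student, --student_*_layers, --alpha_*.
# The remaining flags are assumptions about the shared seq2seq argument parsers.
argv = [
    "--teacher", "t5-large",
    "--student", "t5-small",            # same base model family, so the logits are comparable
    "--model_name_or_path", "t5-small",
    "--tokenizer_name", "t5-large",      # teacher and student must share a tokenizer
    "--data_dir", "xsum",                # hypothetical dataset directory
    "--output_dir", "distilled_t5_small",
    "--do_train",
]

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args(argv)
distill_main(args)
```

Note that with two different base checkpoints the hidden-state loss is skipped automatically (`do_calc_hidden_loss` is False in the diff below), so only the soft cross-entropy and label losses are blended.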

157 changes: 93 additions & 64 deletions examples/seq2seq/distillation.py
@@ -25,8 +25,8 @@
from lightning_base import generic_train # noqa


class BartSummarizationDistiller(SummarizationModule):
"""Supports Bart, Pegasus and other models that inherit from Bart."""
class SummarizationDistiller(SummarizationModule):
"""Supports T5, Bart, Pegasus and other models that inherit from Bart."""

loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]

@@ -40,26 +40,38 @@ def __init__(self, hparams):
hparams.model_name_or_path = str(save_dir) # Tell lightning we are training the student
teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
use_task_specific_params(teacher, hparams.task) # We copy good generation parameters to student by default
student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir
)
if hparams.student is not None:
student = AutoModelForSeq2SeqLM.from_pretrained(hparams.student)
use_task_specific_params(student, hparams.task)
e_layer_ids, d_layer_ids = None, None
else:
student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir
)

if hparams.length_penalty != -1:
student.config.length_penalty = hparams.length_penalty
hparams.tokenizer_name = hparams.teacher # Use teacher's tokenizer
super().__init__(hparams, model=student, config=student.config)
model_type = student.config.model_type
self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]
assert (
student.config.model_type == teacher.config.model_type
), f"teacher, student model types should be the same, got {student.config.model_type} != {teacher.config.model_type}"

if model_type == "t5":
if student.config.model_type == "t5":
student_encoder_layers = len(student.get_encoder().block)
student_decoder_layers = len(student.get_decoder().block)
teacher_encoder_layers = len(teacher.get_encoder().block)
teacher_decoder_layers = len(teacher.get_decoder().block)
else:
student_encoder_layers = student.config.encoder_layers
student_decoder_layers = student.config.decoder_layers
teacher_encoder_layers = teacher.config.encoder_layers
teacher_decoder_layers = teacher.config.decoder_layers

self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers
self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers

self.different_base_models = not (hparams.student is None or hparams.teacher == hparams.student)
self.do_calc_hidden_loss = (not self.different_base_models) and hparams.alpha_hid > 0
self.different_encoder = self.different_base_models or (student_encoder_layers != teacher_encoder_layers)
# self.different_encoder determines whether we need to run the teacher encoder
self.teacher = teacher
freeze_params(self.teacher)

@@ -68,13 +80,28 @@ def __init__(self, hparams):
del self.teacher.model.encoder
except AttributeError: # T5
del self.teacher.encoder
# Intermediate supervision: Decide which layers to supervise
if hparams.supervise_forward:
self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers)
self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers)
else: # student layer should emulate hidden states of the teacher layer it was copied from
self.e_matches = self.e_layer_ids
self.d_matches = self.d_layer_ids

if e_layer_ids is None:
e_layer_ids = list(range(student_encoder_layers))
if d_layer_ids is None:
d_layer_ids = list(range(student_decoder_layers))

self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]

if self.do_calc_hidden_loss: # Intermediate supervision: Decide which layers to supervise
if hparams.supervise_forward:
self.e_matches = get_layers_to_supervise(
n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers
)
self.d_matches = get_layers_to_supervise(
n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers
)
else: # student layer should emulate hidden states of the teacher layer it was copied from
self.e_matches = self.e_layer_ids
self.d_matches = self.d_layer_ids
else:
self.e_matches = None
self.d_matches = None

self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
self.temperature = 2.0
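The `get_layers_to_supervise` helper used above lives in the folded `utils.py` import and is not part of this diff. As a purely illustrative sketch (a hypothetical helper, not the actual implementation) of what `--supervise_forward`-style matching amounts to — pairing each student layer with a roughly evenly spaced teacher layer rather than with the layer it was copied from:

```python
from typing import List


def evenly_spaced_teacher_layers(n_student: int, n_teacher: int) -> List[int]:
    """Hypothetical: spread the supervised teacher layers evenly over the teacher stack."""
    step = n_teacher / n_student
    return [round((i + 1) * step) - 1 for i in range(n_student)]


# A 3-layer student against a 12-layer teacher would supervise teacher layers [3, 7, 11];
# the copy-based fallback (e_matches = self.e_layer_ids) instead reuses whichever teacher
# layers the student was initialized from.
print(evenly_spaced_teacher_layers(3, 12))  # [3, 7, 11]
```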
@@ -84,22 +111,8 @@ def __init__(self, hparams):
gc.collect()
torch.cuda.empty_cache()

def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
"""Supervise MSE(teacher.encoder_outputs, student.encoder_outputs)."""
# raise NotImplementedError()
if mask is not None:
# mask has False at padding_idx
sel_mask = mask[:, :, None].expand_as(student_outputs).bool()
s_logits_slct = torch.masked_select(student_outputs, sel_mask)
t_logits_slct = torch.masked_select(teacher_outputs, sel_mask)
else:
t_logits_slct = teacher_outputs
s_logits_slct = student_outputs
return F.mse_loss(s_logits_slct, t_logits_slct)

def calc_ce_loss(self, mask, s_logits, t_logits):
"""Copy pasted from distillbert (transformers/examples/distillation/)"""

# mask has False at padding_idx
sel_mask = mask[:, :, None].expand_as(s_logits)
vocab_size = s_logits.size(-1)
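The rest of `calc_ce_loss` is folded above. Since its docstring points at the distilbert example and the class builds `nn.KLDivLoss(reduction="batchmean")` with `temperature = 2.0`, the core term is presumably the usual temperature-scaled KL; a standalone sketch under that assumption (not necessarily the exact code):

```python
import torch
import torch.nn.functional as F
from torch import nn


def soft_ce_loss_sketch(s_logits: torch.Tensor, t_logits: torch.Tensor, temperature: float = 2.0) -> torch.Tensor:
    # Same loss object and temperature as SummarizationDistiller.__init__ above.
    ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
    return (
        ce_loss_fct(
            F.log_softmax(s_logits / temperature, dim=-1),  # student log-probabilities
            F.softmax(t_logits / temperature, dim=-1),      # teacher probabilities
        )
        * temperature ** 2  # rescale so gradient magnitude is roughly temperature-independent
    )


# Toy usage on (n_tokens, vocab)-shaped logits, i.e. after masking out padding positions.
student_logits = torch.randn(6, 50)
teacher_logits = torch.randn(6, 50)
print(soft_ce_loss_sketch(student_logits, teacher_logits))
```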
@@ -123,8 +136,8 @@ def add_model_specific_args(parser, root_dir):
add_distill_args(parser)
return parser

def _step(self, batch):
# assert is_frozen(self.teacher) copied_decoder_layers
def _step(self, batch: dict) -> tuple:
"""Compute the loss for a batch"""
pad_token_id = self.tokenizer.pad_token_id
input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
if isinstance(self.model, T5ForConditionalGeneration):
@@ -133,14 +146,16 @@ def _step(self, batch):
decoder_input_ids = shift_tokens_right(labels, pad_token_id)

# noinspection PyCallingNonCallable
lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
student_outputs = self(
input_ids,
attention_mask=src_mask,
decoder_input_ids=decoder_input_ids,
output_hidden_states=True,
output_hidden_states=self.do_calc_hidden_loss,
output_attentions=False,
use_cache=False,
return_dict=True,
)
lm_logits = student_outputs.logits

# Same cross entropy vs. label smoothing logic as finetune.py
assert lm_logits.shape[-1] == self.model.config.vocab_size
@@ -149,45 +164,52 @@ def _step(self, batch):
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
else:
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
lprobs = F.log_softmax(lm_logits, dim=-1)
student_lm_loss, _ = label_smoothed_nll_loss(
lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
)

def zero_tensor():
return torch.tensor(0.0).type_as(student_lm_loss)

teacher_enc_outputs = student_outputs.encoder_last_hidden_state # use this unless self.different_base_models
hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
if self.different_encoder: # compute encoder hidden state loss
with torch.no_grad():
teacher_enc_hid = self.teacher.get_encoder()(
input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True
).hidden_states

hid_loss_enc = self.calc_hidden_loss(
src_mask,
enc_hidden_state,
teacher_enc_hid,
self.e_matches,
normalize_hidden=self.hparams.normalize_hidden,
)

with torch.no_grad():
outputs = self.teacher(
all_teacher_encoder_outputs = self.teacher.get_encoder()(
input_ids,
attention_mask=src_mask,
encoder_outputs=(enc_outputs,),
decoder_input_ids=decoder_input_ids,
lm_labels=labels,
output_hidden_states=True,
output_hidden_states=self.do_calc_hidden_loss,
return_dict=True,
)
tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
if self.different_base_models:
teacher_enc_outputs = all_teacher_encoder_outputs.last_hidden_state
elif self.do_calc_hidden_loss:
hid_loss_enc = self.calc_hidden_loss(
src_mask,
student_outputs.encoder_hidden_states,
all_teacher_encoder_outputs.hidden_states,
self.e_matches,
normalize_hidden=self.hparams.normalize_hidden,
)

teacher_outputs = self.teacher(
input_ids,
attention_mask=src_mask,
encoder_outputs=(teacher_enc_outputs,),
decoder_input_ids=decoder_input_ids,
output_hidden_states=self.do_calc_hidden_loss,
use_cache=False, # since we are not passing labels, never let this default to True
return_dict=True,
)
dec_mask = decoder_input_ids.ne(pad_token_id)
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
if self.alpha_hid > 0: # Intermediate supervision of decoder hidden states
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs.logits)
if self.do_calc_hidden_loss: # Intermediate supervision of decoder hidden states
hid_loss_dec = self.calc_hidden_loss(
dec_mask, dec_hidden, tdec_hidden, self.d_matches, normalize_hidden=self.hparams.normalize_hidden
dec_mask,
student_outputs.decoder_hidden_states,
teacher_outputs.decoder_hidden_states,
self.d_matches,
normalize_hidden=self.hparams.normalize_hidden,
)

blended_loss = (
@@ -207,6 +229,7 @@ def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, no
valid_count = mask.sum() * hidden_states[0].size(-1)
student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
teacher_states = torch.stack([hidden_states_T[j] for j in matches])
assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}"
if normalize_hidden:
student_states = F.layer_norm(student_states, student_states.shape[1:])
teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
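The tail of `calc_hidden_loss` is folded here. A plausible completion of the masked MSE it sets up — an assumption based on the `valid_count` normalizer and the stacked states above, not necessarily the exact code:

```python
import torch
import torch.nn.functional as F


def masked_hidden_mse_sketch(student_states: torch.Tensor, teacher_states: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """student_states/teacher_states: (n_layers, batch, seq, hidden); mask: (batch, seq), False at padding."""
    valid_count = mask.sum() * student_states.size(-1)
    sq_err = F.mse_loss(student_states, teacher_states, reduction="none")
    masked = sq_err * mask.unsqueeze(0).unsqueeze(-1)  # zero out padded positions
    return masked.sum() / valid_count


n_layers, bsz, seq_len, hidden = 2, 3, 5, 8
student = torch.randn(n_layers, bsz, seq_len, hidden)
teacher = torch.randn(n_layers, bsz, seq_len, hidden)
mask = torch.ones(bsz, seq_len, dtype=torch.bool)
print(masked_hidden_mse_sketch(student, teacher, mask))
```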
@@ -216,10 +239,16 @@ def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, no


def add_distill_args(parser):
# NOTE: if the --student argument is specified and the teacher and student base models
# are different, the two models still have to share the same tokenizer, specified by
# --tokenizer_name. So, for example, you can distill from t5-large to t5-small but not
# from bart to t5. This is because if the tokenizers are different, the output space
# for the two models is also different and their logits are not comparable.
parser.add_argument("--teacher", type=str)
parser.add_argument("--alpha_ce", default=0.8, type=float)
parser.add_argument("--alpha_mlm", default=0.2, type=float)
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
parser.add_argument("--student", type=str, required=False)
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
parser.add_argument("--no_teacher", action="store_true", default=False)
@@ -228,8 +257,8 @@ def add_distill_args(parser):
parser.add_argument("--normalize_hidden", action="store_true", default=False)


class BartTranslationDistiller(BartSummarizationDistiller):
"""Supports Mbart, Marian, other models that inherit from Bart."""
class TranslationDistiller(SummarizationDistiller):
"""Supports T5, mBART, Marian, other models that inherit from Bart."""

mode = "translation"
metric_names = ["bleu"]
@@ -258,7 +287,7 @@ def create_module(args):
if args.no_teacher:
module_cls = TranslationModule if "translation" in args.task else SummarizationModule
else: # DISTILL WITH TEACHER
module_cls = BartTranslationDistiller if "translation" in args.task else BartSummarizationDistiller
module_cls = TranslationDistiller if "translation" in args.task else SummarizationDistiller
args.setup_cls: str = module_cls.__name__
print(f"using module {args.setup_cls}")
model = module_cls(args)
@@ -276,7 +305,7 @@ def distill_main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()

distill_main(args)
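The `blended_loss` expression in `_step` is folded out of the hunk above. Given the `alpha_ce` / `alpha_mlm` / `alpha_hid` hyperparameters and the `loss_names` list, a plausible reconstruction of the weighting (an assumption; the defaults below mirror `add_distill_args`):

```python
import torch


def blend_distill_losses(
    loss_ce: torch.Tensor,
    student_lm_loss: torch.Tensor,
    hid_loss_enc: torch.Tensor,
    hid_loss_dec: torch.Tensor,
    alpha_ce: float = 0.8,   # soft cross-entropy against the teacher logits
    alpha_mlm: float = 0.2,  # ordinary (label-smoothed) LM loss on the labels
    alpha_hid: float = 0.0,  # intermediate hidden-state supervision
) -> torch.Tensor:
    return alpha_ce * loss_ce + alpha_mlm * student_lm_loss + alpha_hid * (hid_loss_enc + hid_loss_dec)


zero = torch.tensor(0.0)
print(blend_distill_losses(torch.tensor(1.2), torch.tensor(0.9), zero, zero))
```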
4 changes: 2 additions & 2 deletions examples/seq2seq/test_bash_script.py
@@ -9,7 +9,7 @@
import timeout_decorator
import torch

from distillation import BartSummarizationDistiller, distill_main
from distillation import SummarizationDistiller, distill_main
from finetune import SummarizationModule, main
from transformers import MarianMTModel
from transformers.file_utils import cached_path
@@ -170,7 +170,7 @@ def test_opus_mt_distill_script(self):
with patch.object(sys, "argv", testargs):
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
# assert args.gpus == gpus THIS BREAKS for multi_gpu

12 changes: 12 additions & 0 deletions examples/seq2seq/test_seq2seq_examples.py
@@ -96,6 +96,7 @@
"freeze_encoder": False,
"auto_scale_batch_size": False,
"overwrite_output_dir": False,
"student": None,
}


@@ -107,6 +108,7 @@ def _dump_articles(path: Path, articles: list):
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
T5_TINY = "patrickvonplaten/t5-tiny-random"
T5_TINIER = "sshleifer/t5-tinier-random"
BART_TINY = "sshleifer/bart-tiny-random"
MBART_TINY = "sshleifer/tiny-mbart"
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
@@ -239,6 +241,16 @@ def test_distill_t5(self):
)
self._test_distiller_cli(updates)

@require_torch_non_multi_gpu_but_fix_me
def test_distill_different_base_models(self):
updates = dict(
teacher=T5_TINY,
student=T5_TINIER,
model_name_or_path=T5_TINIER,
tokenizer_name=T5_TINIER,
)
self._test_distiller_cli(updates)

def _test_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
label_smoothing=0.0,
