diff --git a/examples/seq2seq/distil_marian_enro_teacher.sh b/examples/seq2seq/distil_marian_enro_teacher.sh
index 75ef07bc06bdea..5c938a71604e3d 100755
--- a/examples/seq2seq/distil_marian_enro_teacher.sh
+++ b/examples/seq2seq/distil_marian_enro_teacher.sh
@@ -16,5 +16,5 @@ python distillation.py \
     --train_batch_size=$BS --eval_batch_size=$BS \
     --tokenizer_name Helsinki-NLP/opus-mt-en-ro \
     --warmup_steps 500 --logger_name wandb \
-    --fp16_opt_level O1 --task translation --normalize_hidden \
+    --fp16_opt_level O1 --task translation --normalize_hidden --num_sanity_val_steps=0 \
     "$@"
diff --git a/examples/seq2seq/distil_marian_no_teacher.sh b/examples/seq2seq/distil_marian_no_teacher.sh
index 66fdda1d17dad3..4a30628149dfed 100755
--- a/examples/seq2seq/distil_marian_no_teacher.sh
+++ b/examples/seq2seq/distil_marian_no_teacher.sh
@@ -13,5 +13,5 @@ python distillation.py \
     --train_batch_size=$BS --eval_batch_size=$BS \
     --tokenizer_name $m --model_name_or_path $m \
     --warmup_steps 500 --sortish_sampler --logger_name wandb \
-    --gpus 1 --fp16_opt_level=O1 --task translation \
+    --gpus 1 --fp16_opt_level=O1 --task translation --num_sanity_val_steps=0 \
     "$@"
diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py
index ef0445e9007d24..15e99bdb0c0b39 100644
--- a/examples/seq2seq/finetune.py
+++ b/examples/seq2seq/finetune.py
@@ -11,7 +11,6 @@
 import numpy as np
 import pytorch_lightning as pl
 import torch
-from packaging import version
 from torch.utils.data import DataLoader
 
 from lightning_base import BaseTransformer, add_generic_args, generic_train
@@ -94,6 +93,9 @@ def __init__(self, hparams, **kwargs):
             "val": self.hparams.val_max_target_length,
             "test": self.hparams.test_max_target_length,
         }
+        if self.hparams.sortish_sampler and self.hparams.gpus > 1:
+            self.hparams.sortish_sampler = False
+            warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
         assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
         assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"
 
@@ -114,6 +116,10 @@ def __init__(self, hparams, **kwargs):
         )
         self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
         assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
+        if self.hparams.eval_max_gen_length is not None:
+            self.eval_max_length = self.hparams.eval_max_gen_length
+        else:
+            self.eval_max_length = self.model.config.max_length
         self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
 
     def freeze_embeds(self):
@@ -209,12 +215,15 @@ def calc_generative_metrics(self, preds, target) -> Dict:
 
     def _generative_step(self, batch: dict) -> dict:
         t0 = time.time()
+
+        # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
         generated_ids = self.model.generate(
             batch["input_ids"],
             attention_mask=batch["attention_mask"],
             use_cache=True,
             decoder_start_token_id=self.decoder_start_token_id,
             num_beams=self.eval_beams,
+            max_length=self.eval_max_length,
         )
         gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
         preds: List[str] = self.ids_to_clean_text(generated_ids)
@@ -248,7 +257,7 @@ def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False)
         dataset = self.get_dataset(type_path)
         sampler = None
         if self.hparams.sortish_sampler and type_path == "train":
-            assert self.hparams.gpus <= 1  # TODO: assert earlier
+            assert self.hparams.gpus <= 1  # this should never break because of the assertion in __init__
             sampler = dataset.make_sortish_sampler(batch_size)
             shuffle = False
 
@@ -321,6 +330,7 @@ def add_model_specific_args(parser, root_dir):
         parser.add_argument(
             "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
         )
+        parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
         parser.add_argument(
             "--early_stopping_patience",
@@ -356,8 +366,6 @@ def main(args, model=None) -> SummarizationModule:
         model: SummarizationModule = SummarizationModule(args)
     else:
         model: SummarizationModule = TranslationModule(args)
-    if version.parse(torch.__version__) == version.parse("1.6") and args.fp16:
-        warnings.warn("FP16 only seems to work with torch 1.5+apex")
     dataset = Path(args.data_dir).name
     if (
         args.logger_name == "default"
diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py
index 3f4ff2a31d194f..7acbbd7b5e8f3a 100644
--- a/examples/seq2seq/test_seq2seq_examples.py
+++ b/examples/seq2seq/test_seq2seq_examples.py
@@ -34,6 +34,7 @@
     "supervise_forward": True,
     "normalize_hidden": True,
     "label_smoothing": 0.2,
+    "eval_max_gen_length": None,
     "eval_beams": 1,
     "val_metric": "loss",
     "save_top_k": 1,
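
Usage sketch (illustrative, not part of the patch): the new --eval_max_gen_length option caps how many tokens are generated during validation/test and falls back to model.config.max_length when left unset, while --num_sanity_val_steps=0 is the pytorch-lightning Trainer flag the distillation scripts now pass to skip the sanity validation pass. A hypothetical finetune.py invocation follows; only --eval_max_gen_length comes from this diff, and the data dir, output dir, and remaining flag values are placeholders borrowed from the surrounding example scripts:

python finetune.py \
    --data_dir $ENRO_DIR --output_dir marian_enro_eval \
    --model_name_or_path Helsinki-NLP/opus-mt-en-ro --tokenizer_name Helsinki-NLP/opus-mt-en-ro \
    --task translation --gpus 1 \
    --eval_beams 2 --eval_max_gen_length 128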