[s2s] --eval_max_generate_length #7018

Merged
4 commits merged on Sep 10, 2020
2 changes: 1 addition & 1 deletion examples/seq2seq/distil_marian_enro_teacher.sh
@@ -16,5 +16,5 @@ python distillation.py \
--train_batch_size=$BS --eval_batch_size=$BS \
--tokenizer_name Helsinki-NLP/opus-mt-en-ro \
--warmup_steps 500 --logger_name wandb \
- --fp16_opt_level O1 --task translation --normalize_hidden \
+ --fp16_opt_level O1 --task translation --normalize_hidden --num_sanity_val_steps=0 \
"$@"
2 changes: 1 addition & 1 deletion examples/seq2seq/distil_marian_no_teacher.sh
@@ -13,5 +13,5 @@ python distillation.py \
--train_batch_size=$BS --eval_batch_size=$BS \
--tokenizer_name $m --model_name_or_path $m \
--warmup_steps 500 --sortish_sampler --logger_name wandb \
- --gpus 1 --fp16_opt_level=O1 --task translation \
+ --gpus 1 --fp16_opt_level=O1 --task translation --num_sanity_val_steps=0 \
"$@"
16 changes: 12 additions & 4 deletions examples/seq2seq/finetune.py
@@ -11,7 +11,6 @@
import numpy as np
import pytorch_lightning as pl
import torch
- from packaging import version
from torch.utils.data import DataLoader

from lightning_base import BaseTransformer, add_generic_args, generic_train
@@ -94,6 +93,9 @@ def __init__(self, hparams, **kwargs):
"val": self.hparams.val_max_target_length,
"test": self.hparams.test_max_target_length,
}
+ if self.hparams.sortish_sampler and self.hparams.gpus > 1:
+     self.hparams.sortish_sampler = False
+     warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

@@ -114,6 +116,10 @@ def __init__(self, hparams, **kwargs):
)
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
+ if self.hparams.eval_max_gen_length is not None:
+     self.eval_max_length = self.hparams.eval_max_gen_length
+ else:
+     self.eval_max_length = self.model.config.max_length
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

def freeze_embeds(self):
@@ -209,12 +215,15 @@ def calc_generative_metrics(self, preds, target) -> Dict:

def _generative_step(self, batch: dict) -> dict:
t0 = time.time()

+ # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
generated_ids = self.model.generate(
batch["input_ids"],
attention_mask=batch["attention_mask"],
use_cache=True,
decoder_start_token_id=self.decoder_start_token_id,
num_beams=self.eval_beams,
+ max_length=self.eval_max_length,
)
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
preds: List[str] = self.ids_to_clean_text(generated_ids)
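The only behavioural change in `_generative_step` is the new `max_length=self.eval_max_length` argument, which caps how many tokens beam search may emit during evaluation. For reference, a standalone sketch of an equivalently capped `generate` call, using the same en-ro Marian checkpoint as the scripts above (illustrative values, not part of the diff):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ro")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ro")

batch = tok(["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt", padding=True)
generated_ids = model.generate(
    batch["input_ids"],
    attention_mask=batch["attention_mask"],
    use_cache=True,
    num_beams=2,    # plays the role of self.eval_beams
    max_length=64,  # plays the role of self.eval_max_length: generation stops after 64 tokens
)
print(tok.batch_decode(generated_ids, skip_special_tokens=True))
```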
@@ -248,7 +257,7 @@ def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False)
dataset = self.get_dataset(type_path)
sampler = None
if self.hparams.sortish_sampler and type_path == "train":
- assert self.hparams.gpus <= 1 # TODO: assert earlier
+ assert self.hparams.gpus <= 1 # this should never break because of the assertion in __init__
sampler = dataset.make_sortish_sampler(batch_size)
shuffle = False

@@ -321,6 +330,7 @@ def add_model_specific_args(parser, root_dir):
parser.add_argument(
"--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
)
parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
parser.add_argument(
"--early_stopping_patience",
@@ -356,8 +366,6 @@ def main(args, model=None) -> SummarizationModule:
model: SummarizationModule = SummarizationModule(args)
else:
model: SummarizationModule = TranslationModule(args)
- if version.parse(torch.__version__) == version.parse("1.6") and args.fp16:
-     warnings.warn("FP16 only seems to work with torch 1.5+apex")
dataset = Path(args.data_dir).name
if (
args.logger_name == "default"
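Taken together, the finetune.py changes let the evaluation generation length be set from the command line: `--eval_max_gen_length` is parsed with a `None` default, `__init__` falls back to `model.config.max_length` when the flag is absent, and `_generative_step` passes the resolved value to `generate`. A condensed sketch of that resolution (names mirror the diff; the config value of 512 is an assumed stand-in):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")

args = parser.parse_args(["--eval_max_gen_length", "128"])  # use parse_args([]) to simulate omitting the flag

config_max_length = 512  # stand-in for self.model.config.max_length
eval_max_length = args.eval_max_gen_length if args.eval_max_gen_length is not None else config_max_length
print(eval_max_length)  # 128 here; falls back to 512 when the flag is omitted
```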
1 change: 1 addition & 0 deletions examples/seq2seq/test_seq2seq_examples.py
@@ -34,6 +34,7 @@
"supervise_forward": True,
"normalize_hidden": True,
"label_smoothing": 0.2,
"eval_max_gen_length": None,
"eval_beams": 1,
"val_metric": "loss",
"save_top_k": 1,