Revert "[s2s] --eval_max_generate_length (huggingface#7018)"
This reverts commit 6090cc9.
fabiocapsouza authored Nov 15, 2020
1 parent bcb9439 commit 6f4ef33
Showing 4 changed files with 6 additions and 15 deletions.
2 changes: 1 addition & 1 deletion examples/seq2seq/distil_marian_enro_teacher.sh
@@ -16,5 +16,5 @@ python distillation.py \
--train_batch_size=$BS --eval_batch_size=$BS \
--tokenizer_name Helsinki-NLP/opus-mt-en-ro \
--warmup_steps 500 --logger_name wandb \
- --fp16_opt_level O1 --task translation --normalize_hidden --num_sanity_val_steps=0 \
+ --fp16_opt_level O1 --task translation --normalize_hidden \
"$@"
2 changes: 1 addition & 1 deletion examples/seq2seq/distil_marian_no_teacher.sh
@@ -13,5 +13,5 @@ python distillation.py \
--train_batch_size=$BS --eval_batch_size=$BS \
--tokenizer_name $m --model_name_or_path $m \
--warmup_steps 500 --sortish_sampler --logger_name wandb \
- --gpus 1 --fp16_opt_level=O1 --task translation --num_sanity_val_steps=0 \
+ --gpus 1 --fp16_opt_level=O1 --task translation \
"$@"
16 changes: 4 additions & 12 deletions examples/seq2seq/finetune.py
@@ -11,6 +11,7 @@
import numpy as np
import pytorch_lightning as pl
import torch
+ from packaging import version
from torch.utils.data import DataLoader

from lightning_base import BaseTransformer, add_generic_args, generic_train
@@ -93,9 +94,6 @@ def __init__(self, hparams, **kwargs):
"val": self.hparams.val_max_target_length,
"test": self.hparams.test_max_target_length,
}
if self.hparams.sortish_sampler and self.hparams.gpus > 1:
self.hparams.sortish_sampler = False
warnings.warn("ignoring sortish_sampler as it is unsupported on multiple GPUs")
assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

@@ -116,10 +114,6 @@ def __init__(self, hparams, **kwargs):
        )
        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
        assert self.eval_beams >= 1, f"got self.eval_beams={self.eval_beams}. Need an integer > 1"
-         if self.hparams.eval_max_gen_length is not None:
-             self.eval_max_length = self.hparams.eval_max_gen_length
-         else:
-             self.eval_max_length = self.model.config.max_length
        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

    def freeze_embeds(self):
@@ -215,15 +209,12 @@ def calc_generative_metrics(self, preds, target) -> Dict:

    def _generative_step(self, batch: dict) -> dict:
        t0 = time.time()
-
-         # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
            num_beams=self.eval_beams,
-             max_length=self.eval_max_length,
        )
        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
@@ -257,7 +248,7 @@ def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False)
        dataset = self.get_dataset(type_path)
        sampler = None
        if self.hparams.sortish_sampler and type_path == "train":
-             assert self.hparams.gpus <= 1  # this should never break because of the assertion in __init__
+             assert self.hparams.gpus <= 1  # TODO: assert earlier
            sampler = dataset.make_sortish_sampler(batch_size)
            shuffle = False

@@ -330,7 +321,6 @@ def add_model_specific_args(parser, root_dir):
        parser.add_argument(
            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
        )
-         parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
        parser.add_argument(
            "--early_stopping_patience",
@@ -366,6 +356,8 @@ def main(args, model=None) -> SummarizationModule:
            model: SummarizationModule = SummarizationModule(args)
        else:
            model: SummarizationModule = TranslationModule(args)
+     if version.parse(torch.__version__) == version.parse("1.6") and args.fp16:
+         warnings.warn("FP16 only seems to work with torch 1.5+apex")
    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
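Taken together, the finetune.py changes drop the --eval_max_gen_length override: after the revert, _generative_step calls generate() without an explicit max_length, so the generation cap comes from the model config again. The sketch below reconstructs the two behaviours from the removed lines above; it is a standalone illustration rather than the SummarizationModule code, and the opus-mt-en-ro checkpoint is simply the model the Marian scripts above target.

from transformers import MarianMTModel, MarianTokenizer

name = "Helsinki-NLP/opus-mt-en-ro"
tokenizer = MarianTokenizer.from_pretrained(name)
model = MarianMTModel.from_pretrained(name)
batch = tokenizer(["UN Chief Says There Is No Military Solution in Syria"], return_tensors="pt")

eval_max_gen_length = None  # the CLI flag removed by this revert; None meant "use the config"

# Pre-revert behaviour: explicit cap, defaulting to the model config.
eval_max_length = eval_max_gen_length if eval_max_gen_length is not None else model.config.max_length
pre = model.generate(**batch, num_beams=1, max_length=eval_max_length)

# Post-revert behaviour: no max_length passed; generate() falls back to
# model.config.max_length internally, so with the flag unset the two calls are equivalent.
post = model.generate(**batch, num_beams=1)

print(tokenizer.batch_decode(pre, skip_special_tokens=True))
print(tokenizer.batch_decode(post, skip_special_tokens=True))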
1 change: 0 additions & 1 deletion examples/seq2seq/test_seq2seq_examples.py
@@ -34,7 +34,6 @@
"supervise_forward": True,
"normalize_hidden": True,
"label_smoothing": 0.2,
"eval_max_gen_length": None,
"eval_beams": 1,
"val_metric": "loss",
"save_top_k": 1,
