Skip to content

Commit

Permalink
[s2s] run_eval.py QOL improvements and cleanup(#6746)
Browse files Browse the repository at this point in the history
  • Loading branch information
sshleifer authored Aug 26, 2020
1 parent 434936f commit 61518e2
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 22 deletions.
58 changes: 38 additions & 20 deletions examples/seq2seq/run_eval.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import argparse
import json
import time
import warnings
from logging import getLogger
from pathlib import Path
from typing import Dict, List

import torch
from tqdm import tqdm

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


logger = getLogger(__name__)

try:
from .utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params
from .utils import calculate_bleu, calculate_rouge, use_task_specific_params
except ImportError:
from utils import calculate_bleu, calculate_rouge, trim_batch, use_task_specific_params
from utils import calculate_bleu, calculate_rouge, use_task_specific_params

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

Expand All @@ -23,44 +29,47 @@ def chunks(lst, n):


def generate_summaries_or_translations(
examples: list,
examples: List[str],
out_file: str,
model_name: str,
batch_size: int = 8,
device: str = DEFAULT_DEVICE,
fp16=False,
task="summarization",
decoder_start_token_id=None,
**gen_kwargs,
) -> None:
**generate_kwargs,
) -> Dict:
"""Save model.generate results to <out_file>, and return how long it took."""
fout = Path(out_file).open("w", encoding="utf-8")
model_name = str(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
if fp16:
model = model.half()
if decoder_start_token_id is None:
decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)

tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type.

# update config with summarization specific params
start_time = time.time()
# update config with task specific params
use_task_specific_params(model, task)

for batch in tqdm(list(chunks(examples, batch_size))):
for examples_chunk in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name:
batch = [model.config.prefix + text for text in batch]
batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device)
input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
examples_chunk = [model.config.prefix + text for text in examples_chunk]
batch = tokenizer(examples_chunk, return_tensors="pt", truncation=True, padding="longest").to(device)
summaries = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
input_ids=batch.input_ids,
attention_mask=batch.attention_mask,
decoder_start_token_id=decoder_start_token_id,
**gen_kwargs,
**generate_kwargs,
)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for hypothesis in dec:
fout.write(hypothesis + "\n")
fout.flush()
fout.close()
runtime = time.time() - start_time
n_obs = len(examples)
return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))


def run_generate():
Expand All @@ -70,7 +79,13 @@ def run_generate():
parser.add_argument("save_path", type=str, help="where to save summaries")

parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test_reference_summaries.txt")
parser.add_argument("--score_path", type=str, required=False, help="where to save the rouge score in json format")
parser.add_argument(
"--score_path",
type=str,
required=False,
default="metrics.json",
help="where to save the rouge score in json format",
)
parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
Expand All @@ -79,7 +94,7 @@ def run_generate():
type=int,
default=None,
required=False,
help="decoder_start_token_id (otherwise will look at config)",
help="Defaults to using config",
)
parser.add_argument(
"--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
Expand All @@ -90,7 +105,9 @@ def run_generate():
if args.n_obs > 0:
examples = examples[: args.n_obs]
Path(args.save_path).parent.mkdir(exist_ok=True)
generate_summaries_or_translations(
if args.reference_path is None and Path(args.score_path).exists():
warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
runtime_metrics = generate_summaries_or_translations(
examples,
args.save_path,
args.model_name,
Expand All @@ -107,9 +124,10 @@ def run_generate():
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
scores: dict = score_fn(output_lns, reference_lns)
scores.update(runtime_metrics)
print(scores)
if args.score_path is not None:
json.dump(scores, open(args.score_path, "w+"))
json.dump(scores, open(args.score_path, "w"))
return scores


Expand Down
15 changes: 13 additions & 2 deletions examples/seq2seq/test_seq2seq_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,24 @@ def _test_distiller_cli(self, updates, check_contents=True):


@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
def test_run_eval_bart(model):
def test_run_eval(model):
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
output_file_name = input_file_name.parent / "utest_output.txt"
assert not output_file_name.exists()
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
_dump_articles(input_file_name, articles)
testargs = ["run_eval.py", model, str(input_file_name), str(output_file_name)] # TODO: test score_path
score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
task = "translation_en_to_de" if model == T5_TINY else "summarization"
testargs = [
"run_eval.py",
model,
str(input_file_name),
str(output_file_name),
"--score_path",
score_path,
"--task",
task,
]
with patch.object(sys, "argv", testargs):
run_generate()
assert Path(output_file_name).exists()
Expand Down

0 comments on commit 61518e2

Please sign in to comment.