Fix #21; add CLI option for --average #22

Open · wants to merge 3 commits into master (changes from all commits)
70 changes: 56 additions & 14 deletions bleurt/score.py
@@ -16,6 +16,7 @@
 """BLEURT scoring library."""
 import itertools
 import os
+from pathlib import Path

 from bleurt import checkpoint as checkpoint_lib
 from bleurt import encoding
@@ -47,6 +48,10 @@
 flags.DEFINE_integer("bleurt_batch_size", 100,
                      "Number of sentence pairs per batch.")

+flags.DEFINE_bool("average", False, "Output only the mean and median of the scores.")
+
+flags.DEFINE_bool("interactive", False, "Launch an interactive shell.")
+

 def _get_default_checkpoint():
   pkg = os.path.abspath(__file__)
@@ -284,11 +289,11 @@ def bleurt_ops(references, candidates):
   return bleurt_ops


-def score_files(reference_file, candidate_file, bleurt_checkpoint):
+def score_files(reference_file: Path, candidate_file: Path, bleurt_checkpoint, do_average=False):
   """Computes BLEURT scores from two files on disk."""
-  assert tf.io.gfile.exists(reference_file), \
+  assert reference_file.exists(), \
       "Reference file {} not found".format(reference_file)
-  assert tf.io.gfile.exists(candidate_file), \
+  assert candidate_file.exists(), \
       "Candidate file {} not found".format(candidate_file)

   ref_buffer = []
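
Reviewer note: a minimal sketch of driving the revised `score_files` signature from Python rather than the CLI. The checkpoint name and the two text files are placeholders, and absl flags are parsed up front because `score_files` consults module-level FLAGS internally.

```python
# Sketch only; "my_checkpoint", "references.txt" and "candidates.txt"
# are placeholders, not files from this diff.
from pathlib import Path

from absl import flags
from bleurt import score

flags.FLAGS(["score.py"])  # mark flags parsed so their defaults are readable

# With do_average=True this prints one "Mean: ... Median: ... Total: ..." line.
score.score_files(Path("references.txt"), Path("candidates.txt"),
                  "my_checkpoint", do_average=True)
```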
@@ -303,8 +308,9 @@ def _consume_buffer():
     scores_buffer.extend(scores)

   logging.info("Computing BLEURT scores...")
-  with tf.io.gfile.GFile(reference_file, "r") as ref_file:
-    with tf.io.gfile.GFile(candidate_file, "r") as cand_file:
+
+  with reference_file.open("r", encoding='utf-8', errors='ignore') as ref_file:
+    with candidate_file.open("r", encoding='utf-8', errors='ignore') as cand_file:
       for ref_sentence, cand_sentence in itertools.zip_longest(
           ref_file, cand_file, fillvalue=None):
         assert ref_sentence is not None, \
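
Reviewer note: switching from `tf.io.gfile.GFile` to `Path.open` with `errors='ignore'` presumably addresses the decoding failure behind #21, but it also drops gfile's transparent handling of remote filesystems. A small illustration (the bucket URI is a placeholder):

```python
# Illustration, not part of the diff: tf.io.gfile resolves gs:// URIs,
# while pathlib treats them as ordinary local paths.
from pathlib import Path

import tensorflow as tf

uri = "gs://my-bucket/references.txt"  # placeholder bucket
print(tf.io.gfile.exists(uri))  # queries GCS (needs network/credentials)
print(Path(uri).exists())       # False: plain local-path lookup
```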
@@ -321,22 +327,58 @@ def _consume_buffer():
   _consume_buffer()
   logging.info("BLEURT scores computed.")

-  if FLAGS.scores_file:
-    logging.info("Writing to disk.")
-    with tf.io.gfile.GFile(FLAGS.scores_file, "w+") as score_file:
-      for s in scores_buffer:
-        score_file.write("{}\n".format(str(s)))
+  if do_average:
+    import numpy as np
+    scores = np.array(scores_buffer)
+    print('Mean: %.6f Median: %.6f Total: %d' % (np.mean(scores), np.median(scores), len(scores)))
   else:
-    for s in scores_buffer:
-      print("{}".format(str(s)))
+    if FLAGS.scores_file:
+      logging.info("Writing to disk.")
+      with tf.io.gfile.GFile(FLAGS.scores_file, "w+") as score_file:
+        for s in scores_buffer:
+          score_file.write("{}\n".format(str(s)))
+    else:
+      for s in scores_buffer:
+        print(f"{s:.4f}")
   logging.info("Done.")

+def score_interactive(bleurt_checkpoint):
+  """Scores reference/candidate pairs entered in an interactive shell."""
+  import readline
+  import atexit
+  readline.parse_and_bind('set editing-mode emacs')
+  histfile = str(Path('~/.python_history').expanduser())
+  try:
+    readline.read_history_file(histfile)
+    # The default history length is -1 (infinite), which may grow unruly.
+    readline.set_history_length(1000)
+  except FileNotFoundError:
+    pass
+  atexit.register(readline.write_history_file, histfile)
+
+  scorer = BleurtScorer(bleurt_checkpoint)
+
+  while True:
+    try:
+      # Re-prompt until a non-empty line is entered; Ctrl-D (EOF) exits.
+      while True:
+        ref_line = input('Reference: ').strip()
+        if ref_line:
+          break
+      while True:
+        hyp_line = input('Hypothesis: ').strip()
+        if hyp_line:
+          break
+    except EOFError:
+      break
+    score = scorer.score([ref_line], [hyp_line], 1)[0]
+    print(f'Score: {score:.4f}\n')


 def main(_):
+  if FLAGS.interactive:
+    return score_interactive(FLAGS.bleurt_checkpoint)
   assert FLAGS.reference_file, "Please specify a reference sentences file."
   assert FLAGS.candidate_file, "Please specify a candidate sentences file."
-  score_files(FLAGS.reference_file, FLAGS.candidate_file,
-              FLAGS.bleurt_checkpoint)
+  score_files(Path(FLAGS.reference_file), Path(FLAGS.candidate_file),
+              FLAGS.bleurt_checkpoint, do_average=FLAGS.average)


 if __name__ == "__main__":
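
A quick smoke test for the two new flags, sketched under assumptions: the checkpoint directory and input files are placeholders, and `python -m bleurt.score` is assumed to invoke the `main` shown above.

```python
# Placeholder paths throughout; not commands from this PR.
import subprocess

# -average: print only "Mean: ... Median: ... Total: ..." over all pairs.
subprocess.run(
    ["python", "-m", "bleurt.score",
     "-reference_file=references.txt",
     "-candidate_file=candidates.txt",
     "-bleurt_checkpoint=bleurt/test_checkpoint",
     "-average"],
    check=True)

# -interactive: open the Reference:/Hypothesis: prompt loop.
subprocess.run(
    ["python", "-m", "bleurt.score",
     "-bleurt_checkpoint=bleurt/test_checkpoint",
     "-interactive"],
    check=True)
```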