Addressed issue #9 "Installation check error: Expected to be a int64 tensor but is a int32."

Thanks to yongchanghao and nguyenquangminh for the suggestions.

PiperOrigin-RevId: 485159929
tsellam committed Oct 31, 2022
1 parent c6f2375 commit cebe7e6
Showing 13 changed files with 44 additions and 41 deletions.
16 changes: 13 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# BLEURT: a Transfer Learning-Based Metric for Natural Language Generation

BLEURT is an evaluation metric for Natural Language Generation. It takes a pair of sentences as input, a *reference* and a *candidate*, and it returns a score that indicates to what extent the candidate is fluent and conveys the mearning of the reference. It is comparable to [`sentence-BLEU`](https://en.wikipedia.org/wiki/BLEU), [`BERTscore`](https://arxiv.org/abs/1904.09675), and [`COMET`](https://github.com/Unbabel/COMET).
BLEURT is an evaluation metric for Natural Language Generation. It takes a pair of sentences as input, a *reference* and a *candidate*, and it returns a score that indicates to what extent the candidate is fluent and conveys the meaning of the reference. It is comparable to [`sentence-BLEU`](https://en.wikipedia.org/wiki/BLEU), [`BERTscore`](https://arxiv.org/abs/1904.09675), and [`COMET`](https://github.com/Unbabel/COMET).

BLEURT is a *trained metric*, that is, it is a regression model trained on ratings data. The model is based on [`BERT`](https://arxiv.org/abs/1810.04805) and [`RemBERT`](https://arxiv.org/pdf/2010.12821.pdf). This repository contains all the code necessary to use it and/or fine-tune it for your own applications. BLEURT uses Tensorflow, and it benefits greatly from modern GPUs (it runs on CPU too).

@@ -100,7 +100,7 @@ candidates = ["This is the test."]
scorer = score.BleurtScorer(checkpoint)
scores = scorer.score(references=references, candidates=candidates)
assert type(scores) == list and len(scores) == 1
assert isinstance(scores, list) and len(scores) == 1
print(scores)
```
Here again, BLEURT will default to `BERT-Tiny` if no checkpoint is specified.
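A note on the `assert` change in this hunk, which swaps `type(scores) == list` for `isinstance(scores, list)`. The sketch below illustrates the difference with a hypothetical `ScoreList` subclass that is not part of BLEURT: `isinstance` accepts subclasses of `list`, whereas an exact `type(...) ==` comparison rejects them.

```python
class ScoreList(list):
    """Hypothetical list subclass, used only to illustrate the two checks."""


plain_scores = [0.93]
custom_scores = ScoreList([0.93])

# Exact type comparison: True only for plain lists.
exact_plain = type(plain_scores) == list
exact_custom = type(custom_scores) == list

# isinstance: True for lists and for any subclass of list.
inst_plain = isinstance(plain_scores, list)
inst_custom = isinstance(custom_scores, list)

print(exact_plain, exact_custom, inst_plain, inst_custom)
# prints: True False True True
```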
@@ -206,7 +206,7 @@ python -m bleurt.score_files \
## Reproducibility

You may find information about how to work with ratings from the [WMT Metrics Shared Task](http://www.statmt.org/wmt19/metrics-task.html), reproduce results
from [our ACL paper](https://arxiv.org/abs/2004.04696), and a selection of models from [our EMNLP paper](http://arxiv.org/abs/2110.06341) [here](https://github.com/google-research/bleurt/blob/master/wmt_experiments.md).
from [our ACL paper](https://arxiv.org/abs/2004.04696), and a selection of models from [our EMNLP paper](http://arxiv.org/abs/2110.06341) on [this page](https://github.com/google-research/bleurt/blob/master/wmt_experiments.md).


## How to Cite
@@ -221,3 +221,13 @@ Please cite our ACL paper:
booktitle = {Proceedings of ACL}
}
```

The latest model, BLEURT-20, is based on work that led to this follow-up paper:
```
@inproceedings{pu2021learning,
title = {Learning compact metrics for MT},
author = {Pu, Amy and Chung, Hyung Won and Parikh, Ankur P and Gehrmann, Sebastian and Sellam, Thibault},
booktitle = {Proceedings of EMNLP},
year = {2021}
}
```
1 change: 0 additions & 1 deletion bleurt/checkpoint.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Utils to read and write from BLEURT checkpoints."""

import json
1 change: 0 additions & 1 deletion bleurt/encoding.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Data tokenization, encoding and serialization library."""
import collections

1 change: 0 additions & 1 deletion bleurt/finetune.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
r"""Fine-tunes a BERT/BLEURT checkpoint."""
import os

11 changes: 6 additions & 5 deletions bleurt/lib/experiment_utils.py
@@ -23,6 +23,7 @@
from absl import flags

import tensorflow.compat.v1 as tf
from tensorflow.compat.v1 import estimator as tf_estimator

flags.DEFINE_integer("batch_size", 16, "Batch size.")

@@ -59,18 +60,18 @@ def run_experiment(model_fn,
additional_eval_specs=None,
exporters=None):
"""Run experiment."""
run_config = tf.estimator.RunConfig(
run_config = tf_estimator.RunConfig(
model_dir=FLAGS.model_dir,
tf_random_seed=FLAGS.tf_random_seed,
save_checkpoints_steps=FLAGS.save_checkpoints_steps,
keep_checkpoint_max=FLAGS.keep_checkpoint_max)
estimator = tf.estimator.Estimator(
estimator = tf_estimator.Estimator(
config=run_config, model_fn=model_fn, model_dir=FLAGS.model_dir)
train_spec = tf.estimator.TrainSpec(
train_spec = tf_estimator.TrainSpec(
input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
assert not additional_eval_specs, (
"Multiple eval sets are not supported with default experiment runner.")
eval_spec = tf.estimator.EvalSpec(
eval_spec = tf_estimator.EvalSpec(
name="default",
input_fn=eval_input_fn,
exporters=exporters,
@@ -79,5 +80,5 @@ def run_experiment(model_fn,
steps=FLAGS.num_eval_steps)

tf.logging.set_verbosity(tf.logging.INFO)
tf.estimator.train_and_evaluate(
tf_estimator.train_and_evaluate(
estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
28 changes: 14 additions & 14 deletions bleurt/model.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""BLEURT's Tensorflow ops."""

from bleurt import checkpoint as checkpoint_lib
@@ -21,6 +20,7 @@
from scipy import stats
import tensorflow.compat.v1 as tf

from tensorflow.compat.v1 import estimator as tf_estimator
from tf_slim import metrics
from bleurt.lib import modeling

@@ -173,19 +173,19 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
input_mask = features["input_mask"]
segment_ids = features["segment_ids"]

if mode != tf.estimator.ModeKeys.PREDICT:
if mode != tf_estimator.ModeKeys.PREDICT:
scores = features["score"]
else:
scores = tf.zeros(tf.shape(input_ids)[0])

is_training = (mode == tf.estimator.ModeKeys.TRAIN)
is_training = (mode == tf_estimator.ModeKeys.TRAIN)
total_loss, per_example_loss, pred = create_model(
bert_config, is_training, input_ids, input_mask, segment_ids, scores,
use_one_hot_embeddings, n_hidden_layers, hidden_layers_width,
dropout_rate)

output_spec = None
if mode == tf.estimator.ModeKeys.TRAIN:
if mode == tf_estimator.ModeKeys.TRAIN:

# Loads pretrained model
logging.info("**** Initializing from {} ****".format(init_checkpoint))
@@ -216,30 +216,30 @@ def tpu_scaffold():
num_warmup_steps, use_tpu)

if use_tpu:
output_spec = tf.estimator.tpu.TPUEstimatorSpec(
output_spec = tf_estimator.tpu.TPUEstimatorSpec(
mode=mode,
loss=total_loss,
train_op=train_op,
scaffold_fn=scaffold_fn)

else:
output_spec = tf.estimator.EstimatorSpec(
output_spec = tf_estimator.EstimatorSpec(
mode=mode, loss=total_loss, train_op=train_op)

elif mode == tf.estimator.ModeKeys.EVAL:
elif mode == tf_estimator.ModeKeys.EVAL:

if use_tpu:
eval_metrics = (metric_fn, [per_example_loss, pred, scores])
output_spec = tf.estimator.TPUEstimatorSpec(
output_spec = tf_estimator.TPUEstimatorSpec(
mode=mode, loss=total_loss, eval_metric=eval_metrics)
else:
output_spec = tf.estimator.EstimatorSpec(
output_spec = tf_estimator.EstimatorSpec(
mode=mode,
loss=total_loss,
eval_metric_ops=metric_fn(per_example_loss, pred, scores))

elif mode == tf.estimator.ModeKeys.PREDICT:
output_spec = tf.estimator.EstimatorSpec(
elif mode == tf_estimator.ModeKeys.PREDICT:
output_spec = tf_estimator.EstimatorSpec(
mode=mode, predictions={"predictions": pred})

return output_spec
@@ -425,7 +425,7 @@ def _serving_input_fn_builder(seq_length):
"input_mask": tf.placeholder(tf.int64, shape=[None, seq_length]),
"segment_ids": tf.placeholder(tf.int64, shape=[None, seq_length])
}
return tf.estimator.export.build_raw_serving_input_receiver_fn(
return tf_estimator.export.build_raw_serving_input_receiver_fn(
name_to_features)


@@ -484,7 +484,7 @@ def run_finetuning(train_tfrecord,
eval_name = multi_eval_names[i] if multi_eval_names and len(
multi_eval_names) > i else "eval_%s" % i
additional_eval_specs.append(
tf.estimator.EvalSpec(
tf_estimator.EvalSpec(
name=eval_name,
input_fn=additional_dev_input_fn,
steps=FLAGS.num_eval_steps))
@@ -511,7 +511,7 @@ def run_finetuning(train_tfrecord,

logging.info("Creating TF Estimator.")
exporters = [
tf.estimator.BestExporter(
tf_estimator.BestExporter(
"bleurt_best",
serving_input_receiver_fn=_serving_input_fn_builder(max_seq_length),
event_file_pattern="eval_default/*.tfevents.*",
9 changes: 4 additions & 5 deletions bleurt/score.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""BLEURT scoring library."""

import os
@@ -65,10 +64,10 @@ def initialize(self):

def predict(self, input_dict):
predictions = self._bleurt_model_ops(
input_ids=tf.constant(input_dict["input_ids"]),
input_mask=tf.constant(input_dict["input_mask"]),
segment_ids=tf.constant(
input_dict["segment_ids"]))["predictions"].numpy()
input_ids=tf.constant(input_dict["input_ids"], dtype=tf.int64),
input_mask=tf.constant(input_dict["input_mask"], dtype=tf.int64),
segment_ids=tf.constant(input_dict["segment_ids"],
dtype=tf.int64))["predictions"].numpy()
return predictions
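
A note on the `predict` change above: it pins every input tensor to `tf.int64`, which matches the `tf.int64` placeholders declared in `_serving_input_fn_builder`. Without an explicit `dtype`, `tf.constant` infers the type from its argument, so `int32` data yields an `int32` tensor. The sketch below shows the same idea with NumPy only. `to_int64_features` is a hypothetical helper, not part of BLEURT, and the feature values are made up.

```python
import numpy as np


def to_int64_features(input_dict):
    """Hypothetical helper: cast each feature array to int64."""
    return {
        name: np.asarray(values, dtype=np.int64)
        for name, values in input_dict.items()
    }


# Made-up features stored as int32, as they might arrive on some platforms.
features = {
    "input_ids": np.array([[101, 2023, 102]], dtype=np.int32),
    "input_mask": np.array([[1, 1, 1]], dtype=np.int32),
    "segment_ids": np.array([[0, 0, 0]], dtype=np.int32),
}

casted = to_int64_features(features)
for name, array in casted.items():
    print(name, array.dtype)
```

Casting at the boundary keeps the caller free to pass arrays of any integer width, which is roughly what passing `dtype=tf.int64` to `tf.constant` achieves in the commit.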


1 change: 0 additions & 1 deletion bleurt/score_files.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""BLEURT scoring library."""

import itertools
1 change: 0 additions & 1 deletion bleurt/wmt/benchmark.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Downloads WMT data, runs BLEURT, and compute correlation with human ratings."""

import os
1 change: 0 additions & 1 deletion bleurt/wmt/db_builder.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
r"""Library to build download and aggregate WMT ratings data.
More info about the datasets: https://www.statmt.org/wmt19/metrics-task.html
11 changes: 5 additions & 6 deletions bleurt/wmt/downloaders.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Downloads ratings data from the WMT Metrics shared task.
More info about the datasets: https://www.statmt.org/wmt19/metrics-task.html
@@ -44,33 +43,33 @@
WMT_LOCATIONS = {
2015: {
"eval_data": ("DAseg-wmt-newstest2015", "DAseg-wmt-newstest2015.tar.gz",
"http://www.computing.dcu.ie/~ygraham/")
"https://www.scss.tcd.ie/~ygraham/")
},
2016: {
"eval_data": ("DAseg-wmt-newstest2016", "DAseg-wmt-newstest2016.tar.gz",
"http://www.computing.dcu.ie/~ygraham/")
"https://www.scss.tcd.ie/~ygraham/")
},
2017: {
"submissions":
("wmt17-metrics-task-no-hybrids", "wmt17-metrics-task-package.tgz",
"http://ufallab.ms.mff.cuni.cz/~bojar/"),
"eval_data": ("newstest2017-segment-level-human",
"newstest2017-segment-level-human.tar.gz",
"http://computing.dcu.ie/~ygraham/")
"https://www.scss.tcd.ie/~ygraham/")
},
2018: {
"submissions":
("wmt18-metrics-task-nohybrids", "wmt18-metrics-task-nohybrids.tgz",
"http://ufallab.ms.mff.cuni.cz/~bojar/wmt18/"),
"eval_data": ("newstest2018-humaneval", "newstest2018-humaneval.tar.gz",
"http://computing.dcu.ie/~ygraham/")
"https://www.scss.tcd.ie/~ygraham/")
},
2019: {
"submissions": ("wmt19-submitted-data-v3",
"wmt19-submitted-data-v3-txt-minimal.tgz",
"http://ufallab.ms.mff.cuni.cz/~bojar/wmt19/"),
"eval_data": ("newstest2019-humaneval", "newstest2019-humaneval.tar.gz",
"https://www.computing.dcu.ie/~ygraham/")
"https://www.scss.tcd.ie/~ygraham/")
}
}
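
Each `WMT_LOCATIONS` entry above is a tuple of three strings. A minimal sketch of how such an entry could be turned into a download URL, assuming the tuple is (extracted directory name, archive file name, URL prefix) and that the archive sits directly under the prefix. `archive_url` is a hypothetical helper; the actual download logic in `downloaders.py` is not shown in this diff.

```python
def archive_url(entry):
    """Hypothetical helper: join the URL prefix and the archive file name."""
    _, archive_name, url_prefix = entry
    # Normalize the prefix so exactly one slash separates it from the file name.
    return url_prefix.rstrip("/") + "/" + archive_name


# Entry copied from the updated 2015 "eval_data" record.
entry_2015 = ("DAseg-wmt-newstest2015", "DAseg-wmt-newstest2015.tar.gz",
              "https://www.scss.tcd.ie/~ygraham/")

url_2015 = archive_url(entry_2015)
print(url_2015)
```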

1 change: 0 additions & 1 deletion bleurt/wmt/evaluator.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Computes correlation betweem BLEURT and human ratings on a test file from WMT."""
import collections
import json
3 changes: 2 additions & 1 deletion checkpoints.md
@@ -56,6 +56,7 @@ The column `max #tokens` specifies the size of BLEURT's input. Internally, the m
#### From an existing BLEURT checkpoint

BLEURT offers a command-line tool to fine-tune checkpoints on a custom set of ratings.
Currently, we only support fine-tuning the previous generation of checkpoints, based on English BERT (discussed in the previous section).
To illustrate, the following command fine-tunes BERT-tiny on a toy set of examples:

```
@@ -66,7 +67,7 @@ python -m bleurt.finetune \
-dev_set=bleurt/test_data/ratings_dev.jsonl \
-num_train_steps=500
```
You may open the files `test_data/ratings_*.jsonl` for example of how the files should be formattted.
You may open the files `test_data/ratings_*.jsonl` for example of how the files should be formatted.
Internally, the script tokenizes the JSON sentences, it serializes them into TFRecord files,
and it runs a train/eval loop. It saves the best model and exports it as a BLEURT checkpoint.
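
As a rough illustration of the input format, the sketch below writes and re-reads a two-line JSONL ratings file in memory. The field names `candidate`, `reference`, and `score` are an assumption here; the files under `test_data/` are the authoritative example, and the sentences and scores are made up.

```python
import io
import json

# Made-up ratings; field names are assumed, see test_data/ratings_*.jsonl.
ratings = [
    {"candidate": "This is the test.", "reference": "This is a test.", "score": 0.8},
    {"candidate": "Completely unrelated.", "reference": "This is a test.", "score": 0.1},
]

# JSONL: one JSON object per line.
buffer = io.StringIO()
for record in ratings:
    buffer.write(json.dumps(record) + "\n")

# Read the records back, skipping blank lines.
buffer.seek(0)
loaded = [json.loads(line) for line in buffer if line.strip()]
print(len(loaded), loaded[0]["score"])
```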
