
Commit

Save scores to multiple shards.
PiperOrigin-RevId: 393882034
Hyperparticle authored and Mesh TensorFlow Team committed Sep 10, 2021
1 parent 739fb09 commit c640bd0
Showing 1 changed file with 135 additions and 8 deletions.
143 changes: 135 additions & 8 deletions mesh_tensorflow/transformer/utils.py
@@ -26,9 +26,11 @@

import functools
import itertools
import math
import os
import random
import re
import time

import gin
import gin.tf
@@ -38,6 +40,7 @@
from mesh_tensorflow.transformer import learning_rate_schedules
from mesh_tensorflow.transformer import transformer
import numpy as np
import pandas as pd
import pkg_resources
import six
import tensorflow.compat.v1 as tf
@@ -1654,6 +1657,52 @@ def get_sequence_length(tokens, pad_id=0):
return scores


@gin.configurable
def save_scores_to_tfrecords(
results, vocabulary, scores_filename, shard_idx=0, save_ids_only=False):
"""Processes results from scoring examples and saves them to tfrecords files.
Args:
results: list of dictionaries containing the results for each scored
example.
vocabulary: a function that that returns a tf.data.Dataset with examples
containing the string field 'targets' and optionally the field 'inputs'
scores_filename: a string (path of file to write scores to).
shard_idx: an integer indicating the current index of the file for sharding.
save_ids_only: if true, save the ID that is prepended to the inputs.
"""
results = _maybe_add_pretokenized_features(results, vocabulary)
scores = [r.get("scores", 0.0) for r in results]
targets = [r.get("targets_pretokenized", r["targets"]) for r in results]
inputs = [r.get("targets_neg_pretokenized", "") for r in results]

if save_ids_only:
inputs = [r.split(" ", 1)[0] for r in inputs]

table_path = "{}_{}.tfrecord".format(scores_filename, shard_idx)
tf.logging.info("Saving results to {}".format(table_path))

with tf.io.TFRecordWriter(table_path) as file_writer:
for input_, target, score in zip(inputs, targets, scores):
record_bytes = tf.train.Example(
features=tf.train.Features(
feature={
"input":
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[bytes(input_, "utf8")])),
"target":
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[bytes(target, "utf8")])),
"score":
tf.train.Feature(
float_list=tf.train.FloatList(value=[score])),
})).SerializeToString()
file_writer.write(record_bytes)
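
# --- Illustrative sketch (not part of this commit): reading back one shard
# written by save_scores_to_tfrecords above. The path assumes
# scores_filename="scores" and shard_idx=0; the feature names ("input",
# "target", "score") mirror the writer.
def _read_score_shard(path="scores_0.tfrecord"):
  for raw in tf.python_io.tf_record_iterator(path):
    ex = tf.train.Example.FromString(raw)
    yield (ex.features.feature["input"].bytes_list.value[0].decode("utf8"),
           ex.features.feature["target"].bytes_list.value[0].decode("utf8"),
           ex.features.feature["score"].float_list.value[0])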


@gin.configurable
def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
vocabulary, score_postprocess_fn=save_scores,
num_examples=None):
@@ -1691,6 +1740,70 @@ def score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
return score_postprocess_fn(results, vocabulary)


@gin.configurable
def score_with_estimator_lazy(
estimator, input_fn, eval_checkpoint_step, model_dir,
vocabulary, score_postprocess_fn=save_scores_to_tfrecords,
num_examples=None, num_examples_per_shard=10000):
"""Score each example returned by input_fn lazily.
Args:
estimator: a TPUEstimator
input_fn: a function that returns a tf.data.Dataset with examples
containing the string field 'targets' and optionally the field 'inputs'
eval_checkpoint_step: int, list of ints, or None, see `eval_model`
docstring.
model_dir: string, estimator model_dir
vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
targets_vocabulary) tuple
score_postprocess_fn: a function that takes in model outputs and
post-processes, saves, and returns them.
num_examples: int, the total # of examples being scored, None if unknown
num_examples_per_shard: int, the number of examples per file shard.
Returns:
None. Results are passed to `score_postprocess_fn` one shard at a time
rather than returned.
"""
if num_examples is not None:
num_shards = math.ceil(num_examples / num_examples_per_shard)
else:
num_shards = None
tf.logging.info(
"Scoring {} examples with {} shards at {} examples per shard".format(
num_examples, num_shards, num_examples_per_shard))

checkpoint_path, = get_checkpoint_iterator(
eval_checkpoint_step, model_dir)
result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)

start = time.time()
results = []
shard_idx = 0

for i, result in enumerate(result_iter):
results.append(result)
num_results = len(results)
exceeded_num_examples = num_examples is not None and i >= num_examples

if num_results >= num_examples_per_shard or exceeded_num_examples:
score_postprocess_fn(results, vocabulary, shard_idx=shard_idx)

elapsed = time.time() - start
tf.logging.info(
"Scored {} results in {} s, {} examples/s for shard {}".format(
num_results, elapsed, num_results / elapsed, shard_idx))

results = []
shard_idx += 1
start = time.time()

if exceeded_num_examples:
break

if results:
score_postprocess_fn(results, vocabulary, shard_idx=shard_idx)
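
# --- Illustrative sketch (not part of this commit): the sharding arithmetic
# used above. With a hypothetical num_examples=23000 and the default
# num_examples_per_shard=10000, three shards are written; with
# scores_filename="scores" the postprocessor produces scores_0.tfrecord,
# scores_1.tfrecord and scores_2.tfrecord (the last holding 3000 examples).
_num_shards = int(math.ceil(23000 / 10000))  # -> 3 (Python 3 division)
_shard_paths = ["scores_{}.tfrecord".format(i) for i in range(_num_shards)]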


def _maybe_add_pretokenized_features(examples, vocabulary):
"""Ensures decoded versions of "inputs" and "targets" exist in each example.
@@ -1712,9 +1825,17 @@ def _maybe_add_pretokenized_features(examples, vocabulary):
for example in examples:
for feature_name in ["inputs", "targets"]:
pretokenized_feature_name = feature_name + "_pretokenized"
neg_pretokenized_feature_name = feature_name + "_neg_pretokenized"
if feature_name in example and pretokenized_feature_name not in example:
s = vocabulary[feature_name].decode(example[feature_name].tolist())
ids = example[feature_name].tolist()

neg_ids = [abs(i) for i in ids if i < 0]
ids = [i for i in ids if i > 0]

s = vocabulary[feature_name].decode(ids)
example[pretokenized_feature_name] = s
neg_s = vocabulary[feature_name].decode(neg_ids)
example[neg_pretokenized_feature_name] = neg_s

if not added_pretokenized[feature_name]:
added_pretokenized[feature_name] = True
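
# --- Illustrative sketch (not part of this commit): the negative-ID
# convention handled above. Token IDs < 0 mark a segment that is decoded
# separately into the *_neg_pretokenized feature; positive IDs are decoded
# into the usual *_pretokenized feature. Hypothetical IDs:
_ids = [-37, -12, 8, 15, 1]
_neg_ids = [abs(i) for i in _ids if i < 0]  # [37, 12] -> targets_neg_pretokenized
_pos_ids = [i for i in _ids if i > 0]       # [8, 15, 1] -> targets_pretokenized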
@@ -1730,7 +1851,8 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
sequence_length, model_dir, eval_checkpoint_step,
inputs=gin.REQUIRED, targets=gin.REQUIRED,
score_postprocess_fn=gin.REQUIRED, eos_id=1,
score_eos=True):
score_eos=True,
score_with_estimator_fn=score_with_estimator):
"""Compute log likelihoods per example and write to a text file.
inputs & targets must either be the same length (in lines) or have inputs
@@ -1761,6 +1883,7 @@ def score_from_strings(estimator, vocabulary, model_type, batch_size,
score_eos: a boolean - whether to score the final eos token of each line.
If this is set to false, the scores can be interpreted as prefix
log-likelihoods
score_with_estimator_fn: a function to run scoring with the estimator.
Returns:
a list of floats
"""
@@ -1806,7 +1929,7 @@ def input_fn(params):
dataset = dataset.batch(batch_size, drop_remainder=True)
return dataset.prefetch(tf.data.experimental.AUTOTUNE)

return score_with_estimator(
return score_with_estimator_fn(
estimator, input_fn, eval_checkpoint_step, model_dir,
vocabulary, score_postprocess_fn, len(targets))

@@ -1815,7 +1938,8 @@ def input_fn(params):
def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
model_dir, eval_checkpoint_step, dataset_split,
score_dataset_fn=None,
score_postprocess_fn=gin.REQUIRED):
score_postprocess_fn=gin.REQUIRED,
score_with_estimator_fn=score_with_estimator):
"""Compute log likelihoods per example and write to a text file.
The function returns a list of floats representing the log-likelihood of the
@@ -1837,6 +1961,7 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
See `eval_dataset_fn` argument to `eval_model` for details.
score_postprocess_fn: Function that takes in model outputs and
post-processes, then returns them.
score_with_estimator_fn: a function to run scoring with the estimator.
Returns:
scores: a list of floats, the log likelihood scores
@@ -1850,9 +1975,9 @@ def score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
input_fn = _get_combined_dataset_input_fn(
scoring_datasets, batch_size, sequence_length)

return score_with_estimator(
return score_with_estimator_fn(
estimator, input_fn, eval_checkpoint_step, model_dir,
vocabulary, score_postprocess_fn, None)
vocabulary, score_postprocess_fn)


def get_estimator(model_type, vocabulary, mesh_shape,
@@ -2093,7 +2218,8 @@ def eval_model(estimator,
eval_checkpoint_step,
eval_with_score=False,
output_eval_examples=True,
eval_dir_suffix=None):
eval_dir_suffix=None,
score_with_estimator_fn=score_with_estimator):
"""Eval a Mesh-TF model.
Args:
@@ -2137,6 +2263,7 @@ def eval_model(estimator,
of the eval examples in plaintext to eval_summary_dir.
eval_dir_suffix: string, if not None then will be appended to the
eval_summary_dir.
score_with_estimator_fn: a function to run scoring with the estimator.
"""
if eval_dataset_fn is None:
raise ValueError("Must provide eval_dataset_fn through gin for eval.")
@@ -2248,7 +2375,7 @@ def eval_model(estimator,
tf.logging.info("Checkpoint path %s" % checkpoint_path)
global_step = int(get_step_from_checkpoint_path(checkpoint_path))
if eval_with_score:
outputs, _ = score_with_estimator(
outputs, _ = score_with_estimator_fn(
estimator, input_fn, global_step, model_dir, vocabulary,
num_examples=sum(len(cex) for cex in cached_examples.values()))
else:
