Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scores for generated text in inference mode #164

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 52 additions & 11 deletions mesh_tensorflow/transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,16 +491,9 @@ def _maybe_detokenize(ids, vocab):
compute_loss=False,
mode=mode,
variable_dtype=get_variable_dtype())
batch_dim, length_dim, vocab_dim = logits.shape.dims
cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
logits, mtf_features["targets"], vocab_dim)
cross_entropy *= mtf.cast(
mtf.not_equal(targets, 0), cross_entropy.dtype)
if model_type == "delimited_lm":
cross_entropy *= mtf.cast(mtf.logical_not(
transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
scores = mtf.anonymize(scores)

# calculate log likelihood
scores = compute_scores(logits, targets, model_type)
targets = mtf.anonymize(targets)
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
targets = clean_decodes(lowering.export_to_tf_tensor(targets))
Expand Down Expand Up @@ -531,18 +524,39 @@ def _maybe_detokenize(ids, vocab):
inputs, variable_dtype=get_variable_dtype())
else:
raise ValueError("unrecognized class")

# calculate log-likelihood scores for the generated output texts
# Replaces everything after EOS with 0 (along last dim).
eos_and_after = mtf.cumsum(mtf.cast(mtf.equal(mtf_samples, 1), tf.int32),
exclusive=True, dim=mtf_samples.shape[1])
valid_ids = mtf.equal(eos_and_after, 0)
targets_for_score = mtf.where(valid_ids, mtf_samples, 0)

logits, _ = transformer_model.call_simple(
inputs=inputs,
targets=targets_for_score,
compute_loss=False,
mode='score',
variable_dtype=get_variable_dtype())

# calculate log likelihood
scores = compute_scores(logits, targets_for_score, model_type)

mtf_samples = mtf.anonymize(mtf_samples)
inputs = mtf.anonymize(inputs)
lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
inputs = clean_decodes(lowering.export_to_tf_tensor(inputs))
outputs = clean_decodes(lowering.export_to_tf_tensor(mtf_samples))
scores = lowering.export_to_tf_tensor(scores)

inputs = _maybe_detokenize(inputs, inputs_vocabulary(vocabulary))
outputs = _maybe_detokenize(outputs, targets_vocabulary(vocabulary))

predictions = {
"inputs": inputs,
"outputs": outputs}
"outputs": outputs,
"scores": scores
}

if mode in ["score", tf.estimator.ModeKeys.PREDICT]:
# When exporting a model, we need to communicate to TF-Serving that
Expand Down Expand Up @@ -1210,6 +1224,33 @@ def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
return tf.where_v2(valid_ids, ids, pad_id)


def compute_scores(logits, targets, model_type):
  """Compute the per-sequence log likelihood given logits and targets.

  Positions where the target id is 0 (padding) are excluded from the sum.
  For "delimited_lm" models, positions covered by
  transformer.delimited_lm_inputs_mask(targets) are also excluded, so the
  score reflects only the generated-target portion of each sequence.

  Args:
    logits: a mtf.Tensor with floating-point dtype and shape
      [batch, length, vocab], containing the predicted relative log
      probabilities of the classes.
    targets: a mtf.Tensor with integer dtype whose values are in the range
      [0, vocab_dim.size).
    model_type: a string. One of "bitransformer", "lm", "delimited_lm",
      "aligned", or "bi_teacher_student"

  Returns:
    a float mtf.Tensor with one log-likelihood value per batch element
    (anonymized for export).
  """
  # Only the length and vocab dims are needed; batch is left implicit.
  _, length_dim, vocab_dim = logits.shape.dims
  cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
      logits, targets, vocab_dim)
  # Zero out padding positions (pad id 0) so they do not affect the score.
  cross_entropy *= mtf.cast(
      mtf.not_equal(targets, 0), cross_entropy.dtype)
  if model_type == "delimited_lm":
    # For delimited LMs, also zero out the input portion of the sequence.
    cross_entropy *= mtf.cast(mtf.logical_not(
        transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
  # Log likelihood is the negated total cross entropy over the sequence.
  scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
  scores = mtf.anonymize(scores)
  return scores


def _score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
scores_filename, num_examples=None):
"""For each example returned by input_fn, compute log likelihood.
Expand Down