How to get a probability for the generated text? #311

Closed
allen-q opened this issue Jul 23, 2020 · 22 comments

allen-q commented Jul 23, 2020

Hi there,

Thanks for releasing the code and examples for such a great model.

I've followed the example and have been able to train a model on my own dataset, and everything is working now.

I have one question though: is it possible to get a probability score for the generated text? When I run the model in prediction mode, I can only see the inputs and outputs text.

imported.signatures["serving_default"](tf.constant(["trivia question: What's the highest mountain in the world?"]))
{'inputs': <tf.Tensor: shape=(10,), dtype=string, numpy=
 array([b"trivia question: What's the highest mountain in the world?", b'',
        b'', b'', b'', b'', b'', b'', b'', b''], dtype=object)>,
 'outputs': <tf.Tensor: shape=(10,), dtype=string, numpy=
 array([b'travel - national', b'', b'', b'', b'', b'', b'', b'', b'', b''],
       dtype=object)>}

Thanks,
Allen

craffel commented Jul 24, 2020

You can feed the generated text back into the model with the score function: https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/models/mtf_model.py#L338
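
A minimal sketch of that workflow (assuming a trained t5.models.MtfModel instance called model, and reusing the generated text from the example above):

# Hypothetical sketch: feed the generated text back in as the targets and let
# score() write one log-likelihood per (input, target) pair to a file.
input_text_list = ["trivia question: What's the highest mountain in the world?"]
generated_text_list = ["travel - national"]  # text produced in predict mode

model.score(input_text_list,
            generated_text_list,
            scores_file="./temp_scores.txt")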

We have been considering adding functionality to provide the log-likelihood alongside the generated samples but are not currently working on it. A PR would be welcome.

craffel closed this as completed Jul 24, 2020

allen-q commented Jul 24, 2020

Thanks for the reply, I will take a look to see if I can get it working.

allen-q commented Jul 25, 2020

Hi @craffel, I've got the scores by using the score function. However, I can only call this function on the model directly. Once the model is exported and loaded back using tf.saved_model.load, only the 'serving_default' signature is available, which just gives me the generated text.
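
For reference, this is roughly how the exported model is loaded and called (a sketch; the export directory below is just a placeholder):

import tensorflow as tf

# Sketch: load the exported SavedModel; "/path/to/export/1234567890" stands in
# for the directory produced by the export step.
imported = tf.saved_model.load("/path/to/export/1234567890")
print(list(imported.signatures.keys()))  # only 'serving_default' is listed

serve_fn = imported.signatures["serving_default"]
result = serve_fn(tf.constant(
    ["trivia question: What's the highest mountain in the world?"]))
# result only contains the 'inputs' and 'outputs' string tensors, no scores.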

I spent quite a bit of time understanding how the model export works, and I think the serving_default signature is linked to this concrete function: https://github.com/tensorflow/mesh/blob/efcd89f66d53ffddc5d13d37041d774d1bf7d210/mesh_tensorflow/transformer/utils.py#L500

I was thinking about changing the predict logic to add the scores to the predictions dict like below:

  predictions = {
      "inputs": inputs,
      "outputs": outputs,
      "scores": scores
  }

However, calculating the scores needs the 'targets' feature, which is not passed to the function when it's in predict mode.

https://github.com/tensorflow/mesh/blob/efcd89f66d53ffddc5d13d37041d774d1bf7d210/mesh_tensorflow/transformer/utils.py#L468

I feel the outputs generated by the predict logic could be used as the 'targets' feature, but I've been unable to make it work.

Can you please let me know if I'm heading in the right direction and, if so, how I can get the 'targets' in order to calculate the scores?

Thank you.

craffel commented Jul 25, 2020

I see, I didn't realize you were using an exported model. To address your issue, I think the best solution would be to have the model generate log-likelihoods alongside samples in predict mode, rather than calling score separately. We have been discussing doing this but have no immediate plans to do so.

allen-q commented Jul 26, 2020

Hi @craffel, yes, we need to run this model for real-time predictions with a probability score (or some kind of confidence score for the prediction). At the moment, I'm planning to run an exported model via TF Serving. Is this the right way to use the model?
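
For context, the plan is to query the exported model through TF Serving's REST predict endpoint, roughly like the sketch below (the model name t5_model, host, and port are placeholders):

import json
import requests

# Hypothetical request against a TF Serving instance that hosts the exported
# SavedModel under the name "t5_model" on the default REST port.
url = "http://localhost:8501/v1/models/t5_model:predict"
payload = {"inputs": ["trivia question: What's the highest mountain in the world?"]}
resp = requests.post(url, data=json.dumps(payload))
print(resp.json())  # currently only the generated text comes back, no scores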

I will try to get the score included in the predict function, as this is a must-have feature for us to use this model. Can you please share some ideas or pseudocode about how this can be done? I'm happy to make a PR if I can get it working. Thank you.

craffel commented Jul 26, 2020

You should modify the sample_autoregressive and decode functions from the Mesh TensorFlow Transformer so that they return both the (sampled) predictions and the log-likelihood.
For example, in sample_autoregressive (https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer.py#L1009), you could use the logits computed in the while-loop body function here: https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer.py#L1159
You'd then make the scores part of the estimator spec in this code block:
https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/utils.py#L500

adarob commented Jul 27, 2020

Please send a PR when you get this to work or let us know if you need additional help.

allen-q commented Jul 28, 2020

Thanks @craffel and @adarob, I will give it a try. How do you debug mtf models and TF code in general? I used PyTorch in the past and find it a bit hard to debug TF code. Any tips would be appreciated!

allen-q commented Aug 11, 2020

Hi @craffel, @adarob, I've spent some time looking at the t5 code and managed to get the scores added to the model output. However, the scores don't look quite right when I compare them with the output from the model.score function. It'd be great if you could take a quick look to see where the problem might be. Below are the changes I made:

Changed file 1: /mesh_tensorflow/transformer/transformer.py
Changed function: decode

In the decode function I changed the following block

      return self.decoder.sample_autoregressive(
          partial_sequences,
          temperature=temperature,
          variable_dtype=variable_dtype,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs),
          shared_params=shared_params,
          has_partial_sequences=False,
          encoder_layer_outputs=encoder_layer_outputs)

to

      samples = self.decoder.sample_autoregressive(
          partial_sequences,
          temperature=temperature,
          variable_dtype=variable_dtype,
          encoder_output=encoder_output,
          encoder_sequence_id=encoder_sequence_id,
          encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs),
          shared_params=shared_params,
          has_partial_sequences=False,
          encoder_layer_outputs=encoder_layer_outputs)


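      # Re-run the model in 'score' mode on the sampled ids to get per-position
      # logits, which utils.py can then turn into per-sequence log-likelihoods.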
      logits, _ = self.call_simple(
              inputs=inputs,
              targets=samples,
              compute_loss=False,
              mode='score',
              variable_dtype=variable_dtype)
      return (samples, logits)

Changed file 2: /mesh_tensorflow/transformer/utils.py
Changed function: my_model_fn

In the my_model_fn function I extended the PREDICT path
elif mode == tf.estimator.ModeKeys.PREDICT:

to

    elif mode == tf.estimator.ModeKeys.PREDICT:
      inputs = mtf_features["inputs"]
      if predict_fn:
        mtf_samples = predict_fn(
            model=transformer_model,
            features=mtf_features,
            variable_dtype=get_variable_dtype())
      elif isinstance(transformer_model, transformer.Unitransformer):
        # pad so that there is enough room for the targets
        inputs = mtf.pad(
            inputs, [0, sequence_length["targets"]], length_dim.name)
        mtf_samples = transformer_model.sample_autoregressive(
            inputs, variable_dtype=get_variable_dtype(),
            remove_partial_sequences=True)
      elif isinstance(
          transformer_model,
          (transformer.Bitransformer, transformer.StudentTeacher)):
        (mtf_samples, logits) = transformer_model.decode(
            inputs, variable_dtype=get_variable_dtype())
      else:
        raise ValueError("unrecognized class")
      
      # use the generated samples as the targets to calculate score
      targets = mtf_features["targets"] = mtf_samples
      batch_dim, length_dim, vocab_dim = logits.shape.dims
      cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
          logits, mtf_features["targets"], vocab_dim)
      cross_entropy *= mtf.cast(
          mtf.not_equal(targets, 0), cross_entropy.dtype)
      if mode == "delimited_lm":
        cross_entropy *= mtf.cast(mtf.logical_not(
            transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
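      # Sequence score = negative sum of the per-token cross entropies, i.e.
      # the log-likelihood of the sampled sequence under the model.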
      scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
      #scores = mtf.exp(scores)
    
      scores = mtf.anonymize(scores)
    
      mtf_samples = mtf.anonymize(mtf_samples)
      inputs = mtf.anonymize(inputs)
      lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
      inputs = clean_decodes(lowering.export_to_tf_tensor(inputs))
      outputs = clean_decodes(lowering.export_to_tf_tensor(mtf_samples))

      # Detokenize in the graph if supported by vocabulary and accelerator.
      def _maybe_detokenize(ids, vocab):
        if not use_tpu and hasattr(vocab, "decode_tf"):
          return vocab.decode_tf(ids)
        return ids

      inputs = _maybe_detokenize(inputs, inputs_vocabulary(vocabulary))
      outputs = _maybe_detokenize(outputs, targets_vocabulary(vocabulary))

      predictions = {
          "inputs": inputs,
          "outputs": outputs,
          "scores": lowering.export_to_tf_tensor(scores)
      }

The exported model's SignatureDef looks like the following:

The given SavedModel SignatureDef contains the following input(s):
  inputs['input'] tensor_info:
      dtype: DT_STRING
      shape: (-1)
      name: inputs:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['inputs'] tensor_info:
      dtype: DT_STRING
      shape: (10)
      name: SentenceTokenizer/SentenceTokenizer/SentencepieceDetokenizeOp:0
  outputs['outputs'] tensor_info:
      dtype: DT_STRING
      shape: (10)
      name: SentenceTokenizer_1/SentenceTokenizer/SentencepieceDetokenizeOp:0
  outputs['scores'] tensor_info:
      dtype: DT_FLOAT
      shape: (10)
      name: reshape_9/parallel_0/Reshape:0
Method name is: tensorflow/serving/predict

The problem is that the scores look too small (I assume they are log probabilities) compared to the scores I get from the model.score(input, predicted_text) function.

daphnei commented Aug 11, 2020

Could you post a few examples of what the scores are with your code vs. the scoring code?

allen-q commented Aug 13, 2020

@daphnei, sure. With 10 input text samples, using my code the scores are:

[-26.102715, -42.592976, -30.208128, -22.404648, -12.230875, -4.1054745, -21.260942, -19.400509, -13.079597, -46.996883]

If I use the model.score function, the output is:

model.score(input_text_list,
            generated_text_list, 
            scores_file='./temp_scores.txt'
           )

[-1.324276, -0.981087, -1.518559, -0.966465, -1.113779, -4.105474, -2.4146799999999997, -2.640263, -1.239723, -1.553466]

allen-q commented Aug 13, 2020

@daphnei, I did some debugging and found the root cause of this issue. The problem is that the raw output of the self.decoder.sample_autoregressive function has tokens after the EOS token. In the score function these invalid tokens are removed by the clean_decodes function. After I removed the invalid tokens from the sample_autoregressive output, the scores from my code match the model.score function. Thanks for your help.
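
For anyone hitting the same problem, the idea is roughly the following (a minimal sketch in plain TensorFlow, not the library's actual clean_decodes, assuming the usual T5 convention of pad id 0 and EOS id 1):

import tensorflow as tf

def mask_after_eos(ids, eos_id=1):
  # ids: [batch, length] int32 token ids from sample_autoregressive.
  is_eos = tf.cast(tf.equal(ids, eos_id), tf.int32)
  # Exclusive cumsum: positions strictly after the first EOS get a value >= 1,
  # so the EOS itself is kept and everything following it becomes padding (0).
  after_eos = tf.cumsum(is_eos, axis=-1, exclusive=True) > 0
  return tf.where(after_eos, tf.zeros_like(ids), ids)

Only after masking like this do the summed per-token cross entropies match what model.score reports.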

allen-q commented Aug 20, 2020

I've raised a PR in the tensorflow/mesh repo to add this feature. @craffel, @daphnei, @adarob, please feel free to have a look.

daphnei commented Aug 20, 2020

Would it make sense to return log-probabilities instead of probabilities so as to align with the existing scoring code?

allen-q commented Aug 21, 2020

Thanks for your feedback @daphnei. I've thought about returning the log-prob, as the current score function does, but for the project I'm working on, the generated texts need to come with a probability. If I only expose the log-prob in the output, a post-processing step will be required to convert it to a probability, which is not ideal.

There are two options I can think of:

  1. Expose both the scores (log-prob) and the probabilities in the output. This would be easier to implement.

  2. Pass a parameter to the predict function to indicate whether we want to include the score/prob in the output. The parameter could be called something like 'include_scores' with three values {None|log_prob|prob}. The issue with this approach is that a few places would need updating.

Please let me know your thoughts.

daphnei commented Aug 27, 2020

Sorry for the delayed response. A couple of comments:

Personally, I think it is preferable to only output logprobs. It is trivial for downstream code using these scores to convert them to probs, and having an extra param seems like it would just over-complicate things.
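
For example, with the scores from the earlier comment, the conversion downstream is a one-liner (a sketch, assuming NumPy):

import numpy as np

scores = [-1.324276, -0.981087, -1.518559]  # log-likelihoods from the scoring code
probs = np.exp(scores)  # -> roughly [0.266, 0.375, 0.219]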

On an unrelated note, if possible, I think it would be a good idea to refactor the logic that is redundant with https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/utils.py#L487 into a helper function compute_scores.

adarob commented Aug 27, 2020

Thanks for the work @allen-q, and +1 to @daphnei's comments. It would be great to get this PR in ASAP.

allen-q commented Aug 27, 2020

@daphnei Thanks for the comments; that makes sense. I will make the changes and update the PR once it's ready, @adarob. Thank you.

allen-q commented Aug 31, 2020

Hey @adarob, @daphnei, @craffel, I've made the suggested changes to the PR. Please take a look and let me know if it looks OK to be merged.

adarob commented Sep 2, 2020

@daphnei would you mind taking a first pass if you have time since you're more familiar with this code?

daphnei commented Sep 7, 2020

Looks good to me.

allen-q commented Sep 12, 2020

@adarob, thanks for following this up. @daphnei, thanks for reviewing the code; I've changed the function name as you suggested and resolved the merge conflict. Would you be able to merge the pull request, or does it need to be done by someone else?
