This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
Patch description
This is a follow-up PR to #4169. In that PR, I added support for logging token probabilities and token ranks for the outputs of ParlAI models. However, after using it, it became clear that we would like to log additional token-level metadata, such as the top 10 tokens and the top-ranked token (relevant for sampling-based decoding methods).
Rather than add these features directly, I am instead making the token-level metadata object more flexible. In this PR, each token has associated with it a typed dictionary, `_PathSelectionTokenDetails`, that contains the `token_score` and `token_rank` of the relevant token. No code outside the `TreeSearch:select_paths` method implementations and this typed dictionary's definition makes any reference to the specific fields in the dictionary. This makes it easy to override the dictionary's definition and a `TreeSearch:select_paths` implementation to add more verbose metadata. Since different research use cases may want to generate different token-level metadata, this approach will be more future-proof.

Additionally, I make a small change to how token probabilities are logged in nucleus sampling. Instead of logging token probabilities from the truncated (nucleus) distribution, we now log token probabilities from the non-truncated distribution.
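The pattern can be sketched as follows. The `_PathSelectionTokenDetails`, `token_score`, and `token_rank` names come from this PR; the extension class and its fields are hypothetical illustrations, not ParlAI's actual definitions:

```python
from typing import List, TypedDict


class _PathSelectionTokenDetails(TypedDict, total=False):
    """Metadata attached to each generated token by select_paths."""

    token_score: float  # normalized probability of the chosen token
    token_rank: int  # rank of the chosen token in the model's distribution


# Hypothetical extension for sampling-based decoding: a TreeSearch subclass
# could pair a widened dict with its own select_paths override, and no other
# code needs to change because nothing else reads the specific fields.
class _SamplingTokenDetails(_PathSelectionTokenDetails, total=False):
    top_tokens: List[str]  # e.g. the top 10 candidate tokens at this step
    top_token: str  # the single highest-ranked token


details: _SamplingTokenDetails = {
    "token_score": 0.42,
    "token_rank": 3,
    "top_token": "hello",
}
```

Since `TypedDict` is a plain `dict` at runtime, existing logging code that serializes the metadata keeps working regardless of which fields a given `select_paths` implementation fills in.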
Finally, token-level scores are now returned as normalized probabilities rather than log probabilities.
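The nucleus-sampling change can be illustrated with a toy distribution (made-up logits and a stdlib-only sketch, not ParlAI's implementation): the logged probability comes from the full softmax, not from the renormalized nucleus.

```python
import math

logits = [2.0, 1.0, 0.5, -1.0]

# Full softmax over all tokens: these normalized probabilities are what
# gets logged per token after this PR.
z = sum(math.exp(x) for x in logits)
full_probs = [math.exp(x) / z for x in logits]

# Nucleus (top-p) truncation with p = 0.9: keep the smallest prefix of
# tokens, sorted by probability, whose total mass reaches p, then
# renormalize within that prefix.
p = 0.9
sorted_probs = sorted(full_probs, reverse=True)
kept, mass = [], 0.0
for q in sorted_probs:
    kept.append(q)
    mass += q
    if mass >= p:
        break
nucleus_probs = [q / mass for q in kept]

# Sampling still draws from nucleus_probs, but the logged token_score for
# a sampled token is its entry in full_probs, not the inflated
# renormalized value.
```

The renormalized nucleus probabilities are always at least as large as the full-distribution ones, so logging from the truncated distribution would overstate how confident the model actually was in each sampled token.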
Testing steps
pytest tests/test_tga.py
parlai dm --model-file zoo:unittest/transformer_generator2/model --truncate 1024 -v --task integration_tests:multiturn_nocandidate -ne 1 --inference beam --beam-size 3