Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[TGA] Allow setting prefix tokens #3760

Merged
merged 3 commits into from
Jul 6, 2021
Merged

[TGA] Allow setting prefix tokens #3760

merged 3 commits into from
Jul 6, 2021

Conversation

emilydinan
Copy link
Contributor

Patch description
Allow the setting of prefix tokens at generation time. This code already existed but was never accessed (and also seemingly did not work). I fixed some bugs in implementation and added a test to confirm that this works.

Testing steps
Wrote a test.

stephenroller
stephenroller previously approved these changes Jul 1, 2021
@stephenroller stephenroller dismissed their stale review July 1, 2021 15:11

didn't notice test failure

@emilydinan
Copy link
Contributor Author

fixed test failure...happened due to spacing issue

Copy link
Contributor

@klshuster klshuster left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

Comment on lines +1133 to +1137
prefix_toks = prefix_tokens[:, _ts]
prefix_mask = torch.ones_like(score, dtype=torch.bool)
prefix_mask[
:, :, prefix_toks
] = False # everything except prefix toks should be neginf
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks this is way cleaner

@emilydinan emilydinan merged commit dd4fd1f into master Jul 6, 2021
@emilydinan emilydinan deleted the prefix_toks branch July 6, 2021 19:39
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants