This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[torchscript] Support special tokens in torchscript module. #3644

Merged · 1 commit merged into master on May 18, 2021

Conversation

stephenroller (Contributor)

Patch description
Add special-token support for TorchScript. The implementation doesn't mirror the original, because the original relies on recursion, which TorchScript doesn't support.

Testing steps
New CI. Internal testing
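The core of the change described above can be sketched as follows. This is a hedged illustration, not ParlAI's actual code: `split_recursive` stands in for the original recursive splitter that TorchScript cannot compile, and `split_iterative` shows the same splitting done with a work list instead.

```python
# Hypothetical sketch of the rewrite described in the patch notes.
# Function names and structure are illustrative, not ParlAI's real API.
from typing import List


def split_recursive(text: str, specials: List[str]) -> List[str]:
    """Original style: recurse on each remaining special token."""
    if not specials:
        return [text] if text else []
    head, rest = specials[0], specials[1:]
    out: List[str] = []
    pieces = text.split(head)
    for i, piece in enumerate(pieces):
        out += split_recursive(piece, rest)
        if i + 1 < len(pieces):
            out.append(head)
    return out


def split_iterative(text: str, specials: List[str]) -> List[str]:
    """TorchScript-friendly rewrite: a flat work list, no recursion."""
    segments: List[str] = [text]
    for tok in specials:
        next_segments: List[str] = []
        for seg in segments:
            if seg in specials:
                # Already split out as a special token; keep verbatim.
                next_segments.append(seg)
                continue
            pieces = seg.split(tok)
            for i, piece in enumerate(pieces):
                if piece:
                    next_segments.append(piece)
                if i + 1 < len(pieces):
                    next_segments.append(tok)
        segments = next_segments
    return segments
```

On the example from the PR's test, both versions produce `["Don't have a ", 'Q00', ', man! Have a ', 'Q01', ' instead.']`, so the iterative rewrite is behavior-preserving for this case.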

@EricMichaelSmith (Contributor) left a comment

Makes sense, minor nits

@@ -481,7 +488,41 @@ def encode(self, text: str) -> List[str]:
"""
if self.add_prefix_space:
text = f' {text}'
return self.helper_encode(text)

# constants for readability
Contributor:

hmm perhaps it'd help to have a 1-sentence comment about how the special tokens code works at a high level, and maybe also what FINAL and SPLITABLE are?
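One plausible reading of the scheme this comment asks about is a tagged work list: every piece is marked either FINAL (a special token, emitted verbatim) or SPLITABLE (ordinary text still to be BPE-encoded). The sketch below is a guess at that high-level idea; the flag values, names, and the trivial whitespace "BPE" stand-in are mine, not ParlAI's.

```python
# Illustrative only: FINAL/SPLITABLE as tags on a flat work list,
# avoiding the recursion TorchScript cannot compile.
from typing import List, Tuple

FINAL = True       # piece is a special token; never split it further
SPLITABLE = False  # piece is ordinary text; still needs BPE encoding


def encode_with_specials(text: str, specials: List[str]) -> List[str]:
    work: List[Tuple[bool, str]] = [(SPLITABLE, text)]
    for tok in specials:
        nxt: List[Tuple[bool, str]] = []
        for flag, piece in work:
            if flag == FINAL:
                nxt.append((flag, piece))
                continue
            parts = piece.split(tok)
            for i, part in enumerate(parts):
                if part:
                    nxt.append((SPLITABLE, part))
                if i + 1 < len(parts):
                    nxt.append((FINAL, tok))
        work = nxt
    out: List[str] = []
    for flag, piece in work:
        if flag == FINAL:
            out.append(piece)
        else:
            out.extend(piece.split())  # stand-in for the real BPE step
    return out
```

The tag keeps already-extracted special tokens from being re-split or BPE-encoded on later passes, which is what the recursion handled implicitly in the original.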

@@ -62,6 +62,72 @@ def test_token_splitter(self):
if idx + 1 == num_examples:
break

def test_special_tokenization(self):
from parlai.core.dict import DictionaryAgent
Contributor:

nit: these imports could go to the top, right? Or at least the first 2?

from parlai.torchscript.modules import ScriptableDictionaryAgent

SPECIAL = ['Q00', 'Q01']
text = "Don't have a Q00, man! Have a Q01 instead."
Contributor:

😆

assert len(tokenized) == 15
assert sda.vec2txt(tokenized) == text
nice_tok = [sda.ind2tok[i] for i in tokenized]

Contributor:

Nit: a few variables are assigned but never used, per the lint messages

special_tokenized = sda.txt2vec(text)
assert len(special_tokenized) == 15
assert sda.vec2txt(special_tokenized) == text
assert special_tokenized != tokenized
Contributor:

Thought: it could be even more explicit to check the actual strings of the output tokens, instead of just their length and whether they match with/without special tokens. No strong opinion on this either way, though

@stephenroller (Contributor, Author):

Sorry, landing in expediency.

@stephenroller stephenroller merged commit 3bf87ea into master May 18, 2021
@stephenroller stephenroller deleted the torchspecial branch May 18, 2021 17:46
3 participants