[torchscript] Support special tokens in torchscript module. #3644
Conversation
Makes sense, minor nits
@@ -481,7 +488,41 @@ def encode(self, text: str) -> List[str]:
        """
        if self.add_prefix_space:
            text = f' {text}'
        return self.helper_encode(text)

    # constants for readability
hmm perhaps it'd help to have a 1-sentence comment about how the special tokens code works at a high level, and maybe also what FINAL and SPLITABLE are?
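For illustration, one way such a high-level comment and the two constants could read. This is only a sketch: FINAL and SPLITABLE appear in the diff, but the values and wording below are assumptions, not the merged code.

```python
# Sketch only -- FINAL and SPLITABLE come from the diff, but these values and
# this explanation are illustrative assumptions, not the actual patch.

# Special-token handling keeps a flat work list of text fragments instead of
# recursing (TorchScript cannot script recursive calls). A SPLITABLE fragment
# is plain text that may still contain special tokens and will be scanned
# again on the next pass; a FINAL fragment (a matched special token, or text
# already checked against every special token) is passed straight through to
# the regular BPE encoding.
SPLITABLE = 0
FINAL = 1
```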
@@ -62,6 +62,72 @@ def test_token_splitter(self):
            if idx + 1 == num_examples:
                break

    def test_special_tokenization(self):
        from parlai.core.dict import DictionaryAgent
nit: these imports could go to the top, right? Or at least the first 2?
        from parlai.torchscript.modules import ScriptableDictionaryAgent

        SPECIAL = ['Q00', 'Q01']
        text = "Don't have a Q00, man! Have a Q01 instead."
😆
        assert len(tokenized) == 15
        assert sda.vec2txt(tokenized) == text
        nice_tok = [sda.ind2tok[i] for i in tokenized]
Nit: a few variables that are assigned to but never used, given linting messages
        special_tokenized = sda.txt2vec(text)
        assert len(special_tokenized) == 15
        assert sda.vec2txt(special_tokenized) == text
        assert special_tokenized != tokenized
Thought: it could be even more explicit to check the actual strings of the output tokens, instead of just their length and whether they match with/without special tokens. No strong opinion on this either way, though
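For reference, a hedged sketch of what a more explicit check could look like. The exact subword strings depend on the trained BPE vocabulary, so rather than hard-coding the full expected list, the sketch only asserts that the special tokens survive as whole tokens:

```python
# Illustrative sketch, not the actual test. The surrounding subword strings
# depend on the BPE vocabulary, so we only check that the special tokens
# come out intact as single tokens.
nice_special = [sda.ind2tok[i] for i in special_tokenized]
assert 'Q00' in nice_special
assert 'Q01' in nice_special
```

A stricter variant could additionally assert that 'Q00' and 'Q01' do not appear as single tokens in the non-special tokenization, assuming they are not already in the base vocabulary.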
Sorry, landing in expediency.
Patch description
Add special tokens support for torchscript. The implementation doesn't mirror the original, because the original relied on recursion, which torchscript doesn't support.
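As a rough illustration of the non-recursive approach (a sketch under assumptions, not the code in this PR: the function name, the parallel flag list, and the use of str.split are mine), special-token splitting can be driven by a work list that is rescanned once per special token:

```python
from typing import List


def split_special(text: str, special_tokens: List[str]) -> List[str]:
    # Work list of fragments plus a parallel flag list: True means the
    # fragment is a matched special token and must never be split again.
    fragments: List[str] = [text]
    finished: List[bool] = [False]
    for special in special_tokens:
        new_fragments: List[str] = []
        new_finished: List[bool] = []
        for frag, done in zip(fragments, finished):
            if done or special not in frag:
                new_fragments.append(frag)
                new_finished.append(done)
                continue
            # Split around every occurrence of this special token, keeping
            # the token itself as its own finished fragment.
            pieces = frag.split(special)
            for i, piece in enumerate(pieces):
                if i > 0:
                    new_fragments.append(special)
                    new_finished.append(True)
                if piece != '':
                    new_fragments.append(piece)
                    new_finished.append(False)
        fragments = new_fragments
        finished = new_finished
    # Plain-text fragments would then go through the normal BPE encode;
    # special-token fragments map directly to their dictionary ids.
    return fragments
```

For example, split_special("Don't have a Q00, man! Have a Q01 instead.", ['Q00', 'Q01']) yields ["Don't have a ", 'Q00', ', man! Have a ', 'Q01', ' instead.'].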
Testing steps
New CI. Internal testing