Add GPT-2 BPE pre-tokenizer operator leveraging re2 regex library (#1459)

* Add GPT-2 BPE tokenizer operator leveraging re2
* rename method to reflect it is pre-tokenization
* Add comments on gpt-2 bpe pre-tokenization implementation
* remove gpt2_bpe_pre_tokenizer from functional and add unit tests for csrc
* modify is_whitespace to use range based for loop
* add new line at eof
* Remove unnecessary include statement
* Address code review nit comments
1 parent 34b4938 · commit 8ef1b15 · Showing 8 changed files with 140 additions and 0 deletions.
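For orientation, here is a minimal usage sketch of the operator this commit adds. It assumes torchtext has been imported so its TorchScript ops are registered; the expected output is inferred from the GPT-2 pattern rather than taken from this commit.

import torch
import torchtext  # noqa: F401  # importing torchtext registers torch.ops.torchtext.*

# Split raw text into GPT-2 style chunks; BPE merges would run on these afterwards.
chunks = torch.ops.torchtext.gpt2_bpe_pre_tokenizer("Lorem ipsum dolor sit amet.")
print(chunks)  # expected: ['Lorem', ' ipsum', ' dolor', ' sit', ' amet', '.']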
@@ -0,0 +1,55 @@
import regex as re
import torch
import torchtext  # noqa: F401
from ..common.torchtext_test_case import TorchtextTestCase


class TestGPT2BPETokenizer(TorchtextTestCase):
    def test_gpt2_bpe_pre_tokenizer(self):
        # Regex pattern for GPT-2 BPE which includes the negative lookahead
        # Reference: https://github.com/pytorch/fairseq/blob/main/fairseq/data/encoders/gpt2_bpe_utils.py#L69
        gpt2_bpe_pattern = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
        test_cases = [
            # test spaces
            "Lorem ipsum dolor sit amet.",
            "Lorem ipsum dolor sit amet.",
            "Lorem ipsum dolor sit amet. ",
            "Lorem ipsum dolor sit amet ",
            "Lorem\x0d\x0dipsum dolor sit amet\r\r",
            "Lorem ipsum\x20dolor sit amet",
            "Lorem ipsum\x20\x20\x20dolor sit amet",
            "Lorem ipsum\x20\x20 dolor sit amet",
            # test tabs
            "Lorem ipsum dolor sit \t\t\t amet.",
            "Lorem ipsum dolor sit \t\t\t\tamet.",
            "Lorem ipsum dolor sit \x09\x09amet.",
            "Lorem ipsum dolor sit \x09\x09 amet.",
            "Lorem ipsum dolor sit \x09\x09 amet. ",
            "Lorem ipsum dolor sit \t \tamet.",
            "Lorem ipsum dolor sit amet \t",
            "Lorem ipsum\tdolor sit amet",
            # test carriage returns
            "Lorem ipsum\r\r dolor sit amet",
            "Lorem ipsum\r\r dolor sit amet\r\r",
            "Lorem ipsum \x0d\x0ddolor sit amet.",
            "Lorem ipsum\x0ddolor sit amet.",
            "Lorem ipsum\x0d\x0d dolor sit amet.",
            "Lorem ipsum\x0d\x0d dolor sit amet.\x0d",
            # test form feeds
            "Lorem ipsum\f\fdolor sit amet\f",
            "Lorem ipsum\f\f dolor sit amet\f ",
            "Lorem ipsum\x0c\x0c dolor sit amet",
            "Lorem \x0c\x0c\x0c\x0cipsum dolor sit amet",
            # test vertical tabs
            "Lorem ipsum dolor sit\vamet.",
            "Lorem ipsum dolor sit\v\vamet.",
            "Lorem ipsum dolor sit\v\v amet.",
            "Lorem ipsum dolor sit\v\v amet. \v",
            "Lorem ipsum dolor sit\x0b\x0b amet. \v ",
            "Lorem ipsum dolor sit\x0bamet.",
            "Lorem ipsum dolor sit\x0b\x0bamet.",
            "Lorem ipsum dolor sit\x0b\x0b amet.",
        ]
        for t in test_cases:
            self.assertEqual(re.findall(gpt2_bpe_pattern, t),
                             torch.ops.torchtext.gpt2_bpe_pre_tokenizer(t))
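To make the lookahead behavior the test relies on concrete, here is how the reference pattern splits a run of spaces; the output shown is worked out from the regex above, so treat it as illustrative.

import regex as re

pattern = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

# The negative lookahead keeps the final space of a whitespace run attached to
# the following word, so " amet" still carries its leading space.
print(re.findall(pattern, "Lorem sit  amet "))
# expected: ['Lorem', ' sit', ' ', ' amet', ' ']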
@@ -0,0 +1,68 @@
#include <algorithm>
#include <gpt2_bpe_tokenizer.h>
#include <regex.h> // @manual

namespace torchtext {
const Regex kGPT2Regex(
    "(\\'s|\\'t|\\'re|\\'ve|\\'m|\\'ll|\\'d| ?\\pL+|"
    " ?\\pN+| ?[^\\s\\v\\pL\\pN]+|[\\s\\v]+)"
);

bool is_whitespace(const std::string &input) {
  for (const char& c : input) {
    if (!isspace(c)) {
      return false;
    }
  }
  return true;
}

std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input) {
  // Python implementation: https://github.com/pytorch/fairseq/blob/main/fairseq/data/encoders/gpt2_bpe_utils.py#L69
  // The original regex contains a negative lookahead pattern, which is not
  // supported in re2. This implementation modifies the original regex in
  // the following two ways:
  // 1. Removes the negative lookahead and adds a post-processing step instead.
  // 2. Replaces all [\s] occurrences with [\s\v] because re2 does not include
  //    vertical tab (\v) in whitespace. PCRE and Python re include \v in \s.
  //
  // Pseudocode of the post-processing step:
  // - Loop over all tokens
  //   - IF token is all whitespace:
  //     - set prepend_space to False
  //     - IF token is the last token, add it to the return vector
  //     - ELSE
  //       - IF token length is > 1, add token[0:len(token) - 1] to the return list
  //       - IF token[-1] is a space (ASCII 32), carry it over for the next token, set prepend_space = True
  //       - ELSE make token[-1] its own token and add it to the return list
  //   - ELSE IF prepend_space == True, prepend a space to the token and add it to the return list
  //   - ELSE add the token to the return list
  std::string token;
  std::vector<std::string> tokens;
  re2::StringPiece inp(input);
  bool prepend_space = false;
  while (kGPT2Regex.FindAndConsume(&inp, &token)) {
    if (is_whitespace(token)) {
      prepend_space = false;
      if (inp.empty()) { // token is the last token
        tokens.push_back(token);
      } else {
        if (token.length() > 1) {
          tokens.push_back(token.substr(0, token.length() - 1));
        }
        if (token[token.length() - 1] == ' ') { // last char is a space
          prepend_space = true;
        } else { // push the last whitespace char as a token if it is not a space
          tokens.push_back(token.substr(token.length() - 1));
        }
      }
    } else if (prepend_space) {
      tokens.push_back(" " + token);
      prepend_space = false;
    } else {
      tokens.push_back(token);
    }
  }
  return tokens;
}
} // namespace torchtext
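As a reading aid for the post-processing pseudocode above, here is a minimal Python sketch of the same logic applied to a lookahead-free pattern. It is an illustration of the algorithm under those assumptions, not the shipped C++ code.

import regex as re

# Lookahead-free variant of the GPT-2 pattern (Python's \s already covers \v).
RE2_STYLE = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+""")


def pre_tokenize(text):
    tokens = []
    prepend_space = False
    matches = RE2_STYLE.findall(text)
    for i, token in enumerate(matches):
        if token.isspace():
            prepend_space = False
            if i == len(matches) - 1:      # last token: keep the whole whitespace run
                tokens.append(token)
            else:
                if len(token) > 1:         # emit everything except the final whitespace char
                    tokens.append(token[:-1])
                if token[-1] == " ":       # carry a trailing space over to the next token
                    prepend_space = True
                else:                      # a non-space whitespace char becomes its own token
                    tokens.append(token[-1])
        elif prepend_space:
            tokens.append(" " + token)
            prepend_space = False
        else:
            tokens.append(token)
    return tokens


# Matches the lookahead-based reference on this input:
print(pre_tokenize("Lorem sit  amet "))  # ['Lorem', ' sit', ' ', ' amet', ' ']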
@@ -0,0 +1,8 @@
#include <string>
#include <vector>

namespace torchtext {
// Applies regex based pre-tokenization step for GPT-2 BPE tokenizer
// and returns a list of tokens.
std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input);
} // namespace torchtext