Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add GPT-2 BPE pre-tokenizer operator leveraging re2 regex library #1459

Merged
merged 8 commits into from
Dec 9, 2021
Empty file added test/csrc/__init__.py
Empty file.
55 changes: 55 additions & 0 deletions test/csrc/test_gpt2_bpe_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import regex as re
import torch
import torchtext # noqa: F401
from ..common.torchtext_test_case import TorchtextTestCase


class TestGPT2BPETokenizer(TorchtextTestCase):
    def test_gpt2_bpe_pre_tokenizer(self):
        """Verify the C++ gpt2_bpe_pre_tokenizer op against Python's regex module.

        For each input string, the token list produced by the torchtext C++
        operator must match what ``re.findall`` yields with the reference
        GPT-2 BPE pattern (which relies on a negative lookahead that the C++
        side emulates with post-processing).
        """
        # Regex pattern for GPT-2 BPE which includes the negative lookahead
        # Reference: https://github.com/pytorch/fairseq/blob/main/fairseq/data/encoders/gpt2_bpe_utils.py#L69
        gpt2_bpe_pattern = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
        # NOTE: the exact whitespace inside these literals is load-bearing —
        # each case targets a specific run of spaces/tabs/control characters.
        test_cases = [
            # test spaces
            "Lorem ipsum dolor sit amet.",
            "Lorem ipsum dolor sit amet.",
            "Lorem ipsum dolor sit amet. ",
            "Lorem ipsum dolor sit amet ",
            "Lorem\x0d\x0dipsum dolor sit amet\r\r",
            "Lorem ipsum\x20dolor sit amet",
            "Lorem ipsum\x20\x20\x20dolor sit amet",
            "Lorem ipsum\x20\x20 dolor sit amet",
            # test tabs
            "Lorem ipsum dolor sit \t\t\t amet.",
            "Lorem ipsum dolor sit \t\t\t\tamet.",
            "Lorem ipsum dolor sit \x09\x09amet.",
            "Lorem ipsum dolor sit \x09\x09 amet.",
            "Lorem ipsum dolor sit \x09\x09 amet. ",
            "Lorem ipsum dolor sit \t \tamet.",
            "Lorem ipsum dolor sit amet \t",
            "Lorem ipsum\tdolor sit amet",
            # test carriage returns
            "Lorem ipsum\r\r dolor sit amet",
            "Lorem ipsum\r\r dolor sit amet\r\r",
            "Lorem ipsum \x0d\x0ddolor sit amet.",
            "Lorem ipsum\x0ddolor sit amet.",
            "Lorem ipsum\x0d\x0d dolor sit amet.",
            "Lorem ipsum\x0d\x0d dolor sit amet.\x0d",
            # test form feeds
            "Lorem ipsum\f\fdolor sit amet\f",
            "Lorem ipsum\f\f dolor sit amet\f ",
            "Lorem ipsum\x0c\x0c dolor sit amet",
            "Lorem \x0c\x0c\x0c\x0cipsum dolor sit amet",
            # test vertical tabs
            "Lorem ipsum dolor sit\vamet.",
            "Lorem ipsum dolor sit\v\vamet.",
            "Lorem ipsum dolor sit\v\v amet.",
            "Lorem ipsum dolor sit\v\v amet. \v",
            "Lorem ipsum dolor sit\x0b\x0b amet. \v ",
            "Lorem ipsum dolor sit\x0bamet.",
            "Lorem ipsum dolor sit\x0b\x0bamet.",
            "Lorem ipsum dolor sit\x0b\x0b amet.",
        ]
        for t in test_cases:
            # Python regex output is the reference; the C++ op must agree.
            self.assertEqual(re.findall(gpt2_bpe_pattern, t),
                             torch.ops.torchtext.gpt2_bpe_pre_tokenizer(t))
68 changes: 68 additions & 0 deletions torchtext/csrc/gpt2_bpe_tokenizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#include <algorithm>
#include <gpt2_bpe_tokenizer.h>
#include <regex.h> // @manual

namespace torchtext {
// GPT-2 BPE pre-tokenization pattern, in re2 syntax. It mirrors the fairseq
// reference pattern but (1) drops the negative lookahead `(?!\S)` — re2 does
// not support lookaheads, so gpt2_bpe_pre_tokenizer compensates with a
// post-processing step — and (2) writes [\s\v] where the original uses \s,
// because re2's \s does not include vertical tab while PCRE/Python's does.
const Regex kGPT2Regex(
    "(\\'s|\\'t|\\'re|\\'ve|\\'m|\\'ll|\\'d| ?\\pL+|"
    " ?\\pN+| ?[^\\s\\v\\pL\\pN]+|[\\s\\v]+)"
);

// Returns true iff every character in `input` is a whitespace character
// according to isspace(). The empty string is (vacuously) all whitespace.
bool is_whitespace(const std::string &input) {
  for (const char& c : input) {
    // Cast to unsigned char before calling isspace: passing a negative
    // plain-char value (any byte >= 0x80 on platforms where char is signed,
    // e.g. UTF-8 continuation bytes) is undefined behavior.
    if (!isspace(static_cast<unsigned char>(c))) {
      return false;
    }
  }
  return true;
}

std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input) {
// Python implementation: https://github.com/pytorch/fairseq/blob/main/fairseq/data/encoders/gpt2_bpe_utils.py#L69
// Original regex contains a negative lookahead pattern, which is not
// supported in re2. This implementation modifies the original regex in
// the following two ways:
// 1. Removes negative lookahead and adds a post-processing step instead.
// 2. Replace all [\s] occurences with [\s\v] because re2 does not include
// vertical tab (\v) in whitespace. PCRE and Python re include \v in \s.
//
// Pseudocode of post-processing step:
// - Loop over all tokens
// - IF token is all whitespace:
// - set prepend_space to False
// - IF token is last token, add it to return vector
// - ELSE
// - If token length is >1, add token[0:len(token) - 1] to return list
// - IF token[-1] is space (ascii 32), then carry it over for next token, set append_space = True
// - ELSE make token[-1] its own token and add to return list
// - ELSE IF prepend_space == True, prepend a space to the token and add to return list
// - ELSE, add token to return list
std::string token;
std::vector<std::string> tokens;
re2::StringPiece inp(input);
bool prepend_space = false;
while (kGPT2Regex.FindAndConsume(&inp, &token)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: nit: I feel that the cyclomatic complexity of this main logic, almost hitting the threshold. (though it's not like I calculated or anything), so if, in the future, any update is made, I suggest to consider refactoring. (I understand it's not straight forward though.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. I don't expect this to change in future. If any changes occur in logic. I will try to abstract the logic out in an auxiliary self-contained method.

if (is_whitespace(token)) {
prepend_space = false;
if (inp.empty()) { // token is last token
tokens.push_back(token);
} else {
if (token.length() > 1) {
tokens.push_back(token.substr(0, token.length() - 1));
}
if (token[token.length() - 1] == ' ') { // last char is space
prepend_space = true;
} else { // push last whitespace char as a token if it is not a space
tokens.push_back(token.substr(token.length() - 1));
}
}
} else if (prepend_space) {
tokens.push_back(" " + token);
prepend_space = false;
} else {
tokens.push_back(token);
}
}
return tokens;
}
} // namespace torchtext
8 changes: 8 additions & 0 deletions torchtext/csrc/gpt2_bpe_tokenizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#pragma once

#include <string>
#include <vector>

namespace torchtext {
// Applies the regex-based pre-tokenization step for the GPT-2 BPE tokenizer
// and returns the resulting list of tokens.
std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input);
} // namespace torchtext
4 changes: 4 additions & 0 deletions torchtext/csrc/regex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ std::string Regex::Sub(std::string str, const std::string &repl) const {
return str;
}

// Finds the next occurrence of the compiled pattern in *input, stores the
// text matched by the pattern's first capture group in *text, and advances
// *input past the end of the match. Returns false when no match remains,
// which makes it suitable as a scanning-loop condition.
bool Regex::FindAndConsume(re2::StringPiece* input, std::string* text) const {
  return RE2::FindAndConsume(input, *compiled_pattern_, text);
}

// Serializes a Regex by returning the pattern string it was constructed
// from; deserialization can rebuild an equivalent Regex from this string.
std::string _serialize_regex(const c10::intrusive_ptr<Regex> &self) {
  return self->re_str_;
}
Expand Down
2 changes: 2 additions & 0 deletions torchtext/csrc/regex.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <re2/re2.h>
#include <re2/stringpiece.h>
#include <string>
#include <torch/script.h>

Expand All @@ -12,6 +13,7 @@ struct Regex : torch::CustomClassHolder {

Regex(const std::string &re_str);
std::string Sub(std::string str, const std::string &repl) const;
bool FindAndConsume(re2::StringPiece* input, std::string* text) const;
};

std::string _serialize_regex(const c10::intrusive_ptr<Regex> &self);
Expand Down
1 change: 1 addition & 0 deletions torchtext/csrc/register_pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ PYBIND11_MODULE(_torchtext, m) {
py::class_<Regex, c10::intrusive_ptr<Regex>>(m, "Regex")
.def(py::init<std::string>())
.def("Sub", &Regex::Sub)
.def("FindAndConsume", &Regex::FindAndConsume)
.def(py::pickle(
// __getstate__
[](const c10::intrusive_ptr<Regex> &self) -> std::string {
Expand Down
2 changes: 2 additions & 0 deletions torchtext/csrc/register_torchbindings.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <gpt2_bpe_tokenizer.h> // @manual
#include <iostream>
#include <regex.h>
#include <regex_tokenizer.h> // @manual
Expand Down Expand Up @@ -124,6 +125,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
m.def("torchtext::generate_sp_model", &generate_sp_model);
m.def("torchtext::load_sp_model", &load_sp_model);
m.def("torchtext::load_sp_model_string", &load_sp_model_string);
m.def("torchtext::gpt2_bpe_pre_tokenizer", &gpt2_bpe_pre_tokenizer);
}

} // namespace torchtext