Add GPT-2 BPE pre-tokenizer operator leveraging re2 regex library (#1459)

* Add GPT-2 BPE tokenizer operator leveraging re2

* Rename method to reflect that it performs pre-tokenization

* Add comments on GPT-2 BPE pre-tokenization implementation

* Remove gpt2_bpe_pre_tokenizer from functional and add unit tests for csrc

* Modify is_whitespace to use a range-based for loop

* Add newline at EOF

* Remove unnecessary include statement

* Address code review nit comments
abhinavarora authored Dec 9, 2021
1 parent 34b4938 commit 8ef1b15
Showing 8 changed files with 140 additions and 0 deletions.
Empty file added test/csrc/__init__.py
55 changes: 55 additions & 0 deletions test/csrc/test_gpt2_bpe_tokenizer.py
@@ -0,0 +1,55 @@
import regex as re
import torch
import torchtext # noqa: F401
from ..common.torchtext_test_case import TorchtextTestCase


class TestGPT2BPETokenizer(TorchtextTestCase):
    def test_gpt2_bpe_pre_tokenizer(self):
        # Regex pattern for GPT-2 BPE which includes the negative lookahead
        # Reference: https://github.com/pytorch/fairseq/blob/main/fairseq/data/encoders/gpt2_bpe_utils.py#L69
        gpt2_bpe_pattern = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
        test_cases = [
            # test spaces
            "Lorem ipsum dolor sit amet.",
            "Lorem ipsum dolor sit amet.",
            "Lorem ipsum dolor sit amet. ",
            "Lorem ipsum dolor sit amet ",
            "Lorem\x0d\x0dipsum dolor sit amet\r\r",
            "Lorem ipsum\x20dolor sit amet",
            "Lorem ipsum\x20\x20\x20dolor sit amet",
            "Lorem ipsum\x20\x20 dolor sit amet",
            # test tabs
            "Lorem ipsum dolor sit \t\t\t amet.",
            "Lorem ipsum dolor sit \t\t\t\tamet.",
            "Lorem ipsum dolor sit \x09\x09amet.",
            "Lorem ipsum dolor sit \x09\x09 amet.",
            "Lorem ipsum dolor sit \x09\x09 amet. ",
            "Lorem ipsum dolor sit \t \tamet.",
            "Lorem ipsum dolor sit amet \t",
            "Lorem ipsum\tdolor sit amet",
            # test carriage returns
            "Lorem ipsum\r\r dolor sit amet",
            "Lorem ipsum\r\r dolor sit amet\r\r",
            "Lorem ipsum \x0d\x0ddolor sit amet.",
            "Lorem ipsum\x0ddolor sit amet.",
            "Lorem ipsum\x0d\x0d dolor sit amet.",
            "Lorem ipsum\x0d\x0d dolor sit amet.\x0d",
            # test form feeds
            "Lorem ipsum\f\fdolor sit amet\f",
            "Lorem ipsum\f\f dolor sit amet\f ",
            "Lorem ipsum\x0c\x0c dolor sit amet",
            "Lorem \x0c\x0c\x0c\x0cipsum dolor sit amet",
            # test vertical tabs
            "Lorem ipsum dolor sit\vamet.",
            "Lorem ipsum dolor sit\v\vamet.",
            "Lorem ipsum dolor sit\v\v amet.",
            "Lorem ipsum dolor sit\v\v amet. \v",
            "Lorem ipsum dolor sit\x0b\x0b amet. \v ",
            "Lorem ipsum dolor sit\x0bamet.",
            "Lorem ipsum dolor sit\x0b\x0bamet.",
            "Lorem ipsum dolor sit\x0b\x0b amet.",
        ]
        for t in test_cases:
            self.assertEqual(re.findall(gpt2_bpe_pattern, t),
                             torch.ops.torchtext.gpt2_bpe_pre_tokenizer(t))
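
The test compares the operator's output against re.findall with the reference pattern. For a quick manual check, the new operator can also be invoked directly via torch.ops; the following is a minimal sketch, assuming a torchtext build that includes this commit, with the expected token split taken from the reference pattern.

# Minimal usage sketch (assumes a torchtext build containing this commit).
import torch
import torchtext  # noqa: F401  # importing torchtext loads the C++ operators

tokens = torch.ops.torchtext.gpt2_bpe_pre_tokenizer("Lorem ipsum dolor sit amet.")
# Should produce: ['Lorem', ' ipsum', ' dolor', ' sit', ' amet', '.']
print(tokens)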
68 changes: 68 additions & 0 deletions torchtext/csrc/gpt2_bpe_tokenizer.cpp
@@ -0,0 +1,68 @@
#include <algorithm>
#include <gpt2_bpe_tokenizer.h>
#include <regex.h> // @manual

namespace torchtext {
const Regex kGPT2Regex(
    "(\\'s|\\'t|\\'re|\\'ve|\\'m|\\'ll|\\'d| ?\\pL+|"
    " ?\\pN+| ?[^\\s\\v\\pL\\pN]+|[\\s\\v]+)"
);

bool is_whitespace(const std::string &input) {
  for (const char& c : input) {
    if (!isspace(c)) {
      return false;
    }
  }
  return true;
}

std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input) {
  // Python implementation: https://github.com/pytorch/fairseq/blob/main/fairseq/data/encoders/gpt2_bpe_utils.py#L69
  // The original regex contains a negative lookahead pattern, which is not
  // supported in re2. This implementation modifies the original regex in
  // the following two ways:
  // 1. Removes the negative lookahead and adds a post-processing step instead.
  // 2. Replaces all [\s] occurrences with [\s\v] because re2 does not include
  //    vertical tab (\v) in whitespace. PCRE and Python re include \v in \s.
  //
  // Pseudocode of the post-processing step:
  // - Loop over all tokens
  //   - IF token is all whitespace:
  //     - set prepend_space to False
  //     - IF token is the last token, add it to the return vector
  //     - ELSE
  //       - IF token length is > 1, add token[0:len(token) - 1] to the return list
  //       - IF token[-1] is a space (ASCII 32), carry it over for the next token
  //         and set prepend_space = True
  //       - ELSE make token[-1] its own token and add it to the return list
  //   - ELSE IF prepend_space == True, prepend a space to the token and add it
  //     to the return list
  //   - ELSE add the token to the return list
  std::string token;
  std::vector<std::string> tokens;
  re2::StringPiece inp(input);
  bool prepend_space = false;
  while (kGPT2Regex.FindAndConsume(&inp, &token)) {
    if (is_whitespace(token)) {
      prepend_space = false;
      if (inp.empty()) { // token is the last token
        tokens.push_back(token);
      } else {
        if (token.length() > 1) {
          tokens.push_back(token.substr(0, token.length() - 1));
        }
        if (token[token.length() - 1] == ' ') { // last char is a space
          prepend_space = true;
        } else { // push the last whitespace char as its own token if it is not a space
          tokens.push_back(token.substr(token.length() - 1));
        }
      }
    } else if (prepend_space) {
      tokens.push_back(" " + token);
      prepend_space = false;
    } else {
      tokens.push_back(token);
    }
  }
  return tokens;
}
} // namespace torchtext
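
For reference, the post-processing described in the comment above can be sketched in Python against the reference pattern from the test. The names pre_tokenize, lookahead_pattern, and no_lookahead_pattern below are illustrative and do not appear in the patch; the second pattern simply drops the (?!\S) lookahead, mirroring what the re2 version does.

# Python sketch of the lookahead-removal post-processing (assumed names).
import regex as re

lookahead_pattern = re.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
no_lookahead_pattern = re.compile(
    r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+""")


def pre_tokenize(text):
    tokens = []
    prepend_space = False
    matches = no_lookahead_pattern.findall(text)
    for i, token in enumerate(matches):
        if token.isspace():
            prepend_space = False
            if i == len(matches) - 1:    # last token: keep the whole whitespace run
                tokens.append(token)
            else:
                if len(token) > 1:       # all but the final char stay together
                    tokens.append(token[:-1])
                if token[-1] == " ":     # carry a single space over to the next token
                    prepend_space = True
                else:                    # other trailing whitespace becomes its own token
                    tokens.append(token[-1])
        elif prepend_space:
            tokens.append(" " + token)
            prepend_space = False
        else:
            tokens.append(token)
    return tokens


text = "Lorem  ipsum dolor sit amet.  "
assert pre_tokenize(text) == lookahead_pattern.findall(text)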
8 changes: 8 additions & 0 deletions torchtext/csrc/gpt2_bpe_tokenizer.h
@@ -0,0 +1,8 @@
#include <string>
#include <vector>

namespace torchtext {
// Applies the regex-based pre-tokenization step of the GPT-2 BPE tokenizer
// and returns a list of tokens.
std::vector<std::string> gpt2_bpe_pre_tokenizer(std::string input);
} // namespace torchtext
4 changes: 4 additions & 0 deletions torchtext/csrc/regex.cpp
@@ -11,6 +11,10 @@ std::string Regex::Sub(std::string str, const std::string &repl) const {
  return str;
}

bool Regex::FindAndConsume(re2::StringPiece* input, std::string* text) const {
  return RE2::FindAndConsume(input, *compiled_pattern_, text);
}

std::string _serialize_regex(const c10::intrusive_ptr<Regex> &self) {
  return self->re_str_;
}
2 changes: 2 additions & 0 deletions torchtext/csrc/regex.h
@@ -1,4 +1,5 @@
#include <re2/re2.h>
#include <re2/stringpiece.h>
#include <string>
#include <torch/script.h>

@@ -12,6 +13,7 @@ struct Regex : torch::CustomClassHolder {

  Regex(const std::string &re_str);
  std::string Sub(std::string str, const std::string &repl) const;
  bool FindAndConsume(re2::StringPiece* input, std::string* text) const;
};

std::string _serialize_regex(const c10::intrusive_ptr<Regex> &self);
1 change: 1 addition & 0 deletions torchtext/csrc/register_pybindings.cpp
@@ -30,6 +30,7 @@ PYBIND11_MODULE(_torchtext, m) {
  py::class_<Regex, c10::intrusive_ptr<Regex>>(m, "Regex")
      .def(py::init<std::string>())
      .def("Sub", &Regex::Sub)
      .def("FindAndConsume", &Regex::FindAndConsume)
      .def(py::pickle(
          // __getstate__
          [](const c10::intrusive_ptr<Regex> &self) -> std::string {
2 changes: 2 additions & 0 deletions torchtext/csrc/register_torchbindings.cpp
@@ -1,3 +1,4 @@
#include <gpt2_bpe_tokenizer.h> // @manual
#include <iostream>
#include <regex.h>
#include <regex_tokenizer.h> // @manual
@@ -124,6 +125,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
  m.def("torchtext::generate_sp_model", &generate_sp_model);
  m.def("torchtext::load_sp_model", &load_sp_model);
  m.def("torchtext::load_sp_model_string", &load_sp_model_string);
  m.def("torchtext::gpt2_bpe_pre_tokenizer", &gpt2_bpe_pre_tokenizer);
}

} // namespace torchtext
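
Because the operator is registered through TORCH_LIBRARY_FRAGMENT, it should also be visible to TorchScript; the snippet below sketches that assumption and is not part of the patch.

# Sketch (assumption): a TORCH_LIBRARY-registered op is callable from a scripted function.
from typing import List

import torch
import torchtext  # noqa: F401  # loads the torchtext C++ operators


@torch.jit.script
def pre_tokenize(text: str) -> List[str]:
    return torch.ops.torchtext.gpt2_bpe_pre_tokenizer(text)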
