Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GPT-2 BPE pre-tokenizer operator leveraging re2 regex library #1459

Merged
merged 8 commits into from
Dec 9, 2021
50 changes: 50 additions & 0 deletions test/data/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tempfile
import uuid
import unittest
import regex as re

import torch
from torchtext.data.functional import (
Expand All @@ -12,6 +13,7 @@
sentencepiece_tokenizer,
custom_replace,
simple_space_split,
gpt2_bpe_tokenizer,
)

from ..common.torchtext_test_case import TorchtextTestCase
Expand Down Expand Up @@ -91,6 +93,54 @@ def test_simple_space_split(self):
self.assertEqual(list(simple_space_split(test_sample))[0],
ref_results)

def test_gpt2_bpe_tokenizer(self):
    """Check the C++ GPT-2 BPE pre-tokenizer against the reference regex.

    For every sample the tokens produced by ``gpt2_bpe_tokenizer`` must
    equal the matches of the original GPT-2 pattern (which uses a
    negative lookahead unavailable in RE2).
    """
    # Regex pattern for GPT-2 BPE which includes the negative lookahead
    # Reference: https://github.com/pytorch/fairseq/blob/main/fairseq/data/encoders/gpt2_bpe_utils.py#L69
    gpt2_bpe_pattern = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
    test_cases = [
        # test spaces
        "Lorem ipsum dolor sit amet.",
        "Lorem ipsum dolor sit amet.",
        "Lorem ipsum dolor sit amet. ",
        "Lorem ipsum dolor sit amet ",
        "Lorem\x0d\x0dipsum dolor sit amet\r\r",
        "Lorem ipsum\x20dolor sit amet",
        "Lorem ipsum\x20\x20\x20dolor sit amet",
        "Lorem ipsum\x20\x20 dolor sit amet",
        # test tabs
        "Lorem ipsum dolor sit \t\t\t amet.",
        "Lorem ipsum dolor sit \t\t\t\tamet.",
        "Lorem ipsum dolor sit \x09\x09amet.",
        "Lorem ipsum dolor sit \x09\x09 amet.",
        "Lorem ipsum dolor sit \x09\x09 amet. ",
        "Lorem ipsum dolor sit \t \tamet.",
        "Lorem ipsum dolor sit amet \t",
        "Lorem ipsum\tdolor sit amet",
        # test carriage returns
        "Lorem ipsum\r\r dolor sit amet",
        "Lorem ipsum\r\r dolor sit amet\r\r",
        "Lorem ipsum \x0d\x0ddolor sit amet.",
        "Lorem ipsum\x0ddolor sit amet.",
        "Lorem ipsum\x0d\x0d dolor sit amet.",
        "Lorem ipsum\x0d\x0d dolor sit amet.\x0d",
        # test form feeds
        "Lorem ipsum\f\fdolor sit amet\f",
        "Lorem ipsum\f\f dolor sit amet\f ",
        "Lorem ipsum\x0c\x0c dolor sit amet",
        "Lorem \x0c\x0c\x0c\x0cipsum dolor sit amet",
        # test vertical tabs
        "Lorem ipsum dolor sit\vamet.",
        "Lorem ipsum dolor sit\v\vamet.",
        "Lorem ipsum dolor sit\v\v amet.",
        "Lorem ipsum dolor sit\v\v amet. \v",
        "Lorem ipsum dolor sit\x0b\x0b amet. \v ",
        "Lorem ipsum dolor sit\x0bamet.",
        "Lorem ipsum dolor sit\x0b\x0bamet.",
        "Lorem ipsum dolor sit\x0b\x0b amet.",
    ]
    for sample in test_cases:
        expected = gpt2_bpe_pattern.findall(sample)
        self.assertEqual(expected, gpt2_bpe_tokenizer(sample))


class ScriptableSP(torch.jit.ScriptModule):
def __init__(self, model_path):
Expand Down
51 changes: 51 additions & 0 deletions torchtext/csrc/gpt2_bpe_tokenizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#include <algorithm>
#include <gpt2_bpe_tokenizer.h>
#include <regex.h> // @manual

namespace torchtext {
// RE2 pattern approximating the GPT-2 BPE pre-tokenizer regex. RE2 does not
// support negative lookahead, so the reference pattern's "\s+(?!\S)" branch
// cannot be expressed here; gpt2_bpe_tokenizer compensates by re-attaching
// the trailing space of a whitespace run to the following token.
const Regex kGPT2Regex(
"(\\'s|\\'t|\\'re|\\'ve|\\'m|\\'ll|\\'d| ?\\pL+|"
" ?\\pN+| ?[^\\s\\v\\pL\\pN]+|[\\s\\v]+)"
);

// Returns true if every byte of `input` is a whitespace character
// (space, \t, \n, \v, \f, \r). An empty string counts as whitespace,
// matching std::all_of on an empty range.
bool is_whitespace(const std::string &input) {
  return std::all_of(input.begin(), input.end(), [](unsigned char c) {
    // The cast to unsigned char matters: passing a negative plain char
    // (e.g. any UTF-8 continuation byte) to isspace is undefined behavior.
    return std::isspace(c) != 0;
  });
}

std::vector<std::string> gpt2_bpe_tokenizer(std::string input) {
abhinavarora marked this conversation as resolved.
Show resolved Hide resolved
std::string token;
std::vector<std::string> tokens;
re2::StringPiece inp(input);
bool append_space = false;
while (kGPT2Regex.FindAndConsume(&inp, &token)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: nit: I feel that the cyclomatic complexity of this main logic, almost hitting the threshold. (though it's not like I calculated or anything), so if, in the future, any update is made, I suggest to consider refactoring. (I understand it's not straight forward though.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. I don't expect this to change in future. If any changes occur in logic. I will try to abstract the logic out in an auxiliary self-contained method.

// tokens.push_back(token);
// Check if whitespace
abhinavarora marked this conversation as resolved.
Show resolved Hide resolved
if (is_whitespace(token)) {
append_space = false;
if (inp.empty()) {
tokens.push_back(token);
} else {
if (token.length() > 1) {
tokens.push_back(token.substr(0, token.length() - 1));
}
if (token[token.length() - 1] == ' ') {
append_space = true;
} else {
tokens.push_back(token.substr(token.length() - 1));
}
}
} else if (append_space) {
tokens.push_back(" " + token);
append_space = false;
} else {
tokens.push_back(token);
}
}
return tokens;
}
} // namespace torchtext
6 changes: 6 additions & 0 deletions torchtext/csrc/gpt2_bpe_tokenizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#include <string>
#include <vector>

namespace torchtext {
std::vector<std::string> gpt2_bpe_tokenizer(std::string input);
} // namespace torchtext
4 changes: 4 additions & 0 deletions torchtext/csrc/regex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ std::string Regex::Sub(std::string str, const std::string &repl) const {
return str;
}

// Thin wrapper over RE2::FindAndConsume: finds the next match of
// compiled_pattern_ in `input`, stores the first capture group in `text`,
// and advances `input` past the match. Returns false when no match remains.
bool Regex::FindAndConsume(re2::StringPiece* input, std::string* text) const {
return RE2::FindAndConsume(input, *compiled_pattern_, text);
}

std::string _serialize_regex(const c10::intrusive_ptr<Regex> &self) {
return self->re_str_;
}
Expand Down
2 changes: 2 additions & 0 deletions torchtext/csrc/regex.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <re2/re2.h>
#include <re2/stringpiece.h>
#include <string>
#include <torch/script.h>

Expand All @@ -12,6 +13,7 @@ struct Regex : torch::CustomClassHolder {

Regex(const std::string &re_str);
std::string Sub(std::string str, const std::string &repl) const;
bool FindAndConsume(re2::StringPiece* input, std::string* text) const;
};

std::string _serialize_regex(const c10::intrusive_ptr<Regex> &self);
Expand Down
2 changes: 2 additions & 0 deletions torchtext/csrc/register_pybindings.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <gpt2_bpe_tokenizer.h> // @manual
#include <iostream>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
Expand Down Expand Up @@ -30,6 +31,7 @@ PYBIND11_MODULE(_torchtext, m) {
py::class_<Regex, c10::intrusive_ptr<Regex>>(m, "Regex")
.def(py::init<std::string>())
.def("Sub", &Regex::Sub)
.def("FindAndConsume", &Regex::FindAndConsume)
.def(py::pickle(
// __getstate__
[](const c10::intrusive_ptr<Regex> &self) -> std::string {
Expand Down
3 changes: 3 additions & 0 deletions torchtext/csrc/register_torchbindings.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <gpt2_bpe_tokenizer.h> // @manual
#include <iostream>
#include <regex.h>
#include <regex_tokenizer.h> // @manual
Expand All @@ -11,6 +12,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
m.class_<Regex>("Regex")
.def(torch::init<std::string>())
.def("Sub", &Regex::Sub)
.def("FindAndConsume", &Regex::FindAndConsume)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Regex> &self) -> std::string {
Expand Down Expand Up @@ -124,6 +126,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
m.def("torchtext::generate_sp_model", &generate_sp_model);
m.def("torchtext::load_sp_model", &load_sp_model);
m.def("torchtext::load_sp_model_string", &load_sp_model_string);
m.def("torchtext::gpt2_bpe_tokenizer", &gpt2_bpe_tokenizer);
}

} // namespace torchtext
17 changes: 17 additions & 0 deletions torchtext/data/functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
import io
import torch
from typing import List

__all__ = [
"generate_sp_model", "load_sp_model",
Expand Down Expand Up @@ -282,3 +283,19 @@ def __getitem__(self, idx):
return self._data[idx]

return _MapStyleDataset(iter_data)


def gpt2_bpe_tokenizer(input: str) -> List[str]:
    r"""Regex tokenization for GPT-2 before applying BPE.

    Thin Python wrapper that dispatches to the registered C++ operator
    ``torchtext::gpt2_bpe_tokenizer``.

    Args:
        input: the text that needs to be tokenized.

    Outputs:
        output: list of tokens after applying regex tokenization for GPT-2.

    Examples:
        >>> from torchtext.data.functional import gpt2_bpe_tokenizer
        >>> tokens = gpt2_bpe_tokenizer('hello world')
    """
    return torch.ops.torchtext.gpt2_bpe_tokenizer(input)