splitting registration and refactoring vocab.py module #1352

Merged (5 commits, Jul 2, 2021). Diff below shows changes from all commits.
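This PR separates the Python-only pybind11 registrations from the TorchScript TORCH_LIBRARY_FRAGMENT registrations, moving the latter into a new torchtext/csrc/register_torchbindings.cpp, alongside the vocab refactoring. For context, a minimal sketch of the two registration mechanisms being separated (illustrative names, not the PR's code; shown in one file only for brevity):

#include <pybind11/pybind11.h>
#include <torch/script.h>

namespace py = pybind11;

struct Counter : torch::CustomClassHolder {
  int64_t n_ = 0;
  void bump() { ++n_; }
  int64_t value() { return n_; }
};

// Python-only bindings: visible to `import`, invisible to TorchScript.
// After this PR, bindings like these stay in the pybind module.
PYBIND11_MODULE(_example, m) {
  m.def("answer", []() { return 42; });
}

// TorchScript bindings: usable from scripted code and from C++.
// After this PR, bindings like these live in register_torchbindings.cpp.
TORCH_LIBRARY_FRAGMENT(example, m) {
  m.class_<Counter>("Counter")
      .def(torch::init<>())
      .def("bump", &Counter::bump)
      .def("value", &Counter::value);
}

Keeping the two in separate translation units lets the TorchScript operators be built and loaded without the Python extension.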
1 change: 1 addition & 0 deletions .circleci/unittest/linux/scripts/environment.yml
@@ -18,5 +18,6 @@ dependencies:
- sphinx
- sphinx-rtd-theme
- tqdm
- expecttest
- https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0
- https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0
1 change: 1 addition & 0 deletions .circleci/unittest/windows/scripts/environment.yml
@@ -21,5 +21,6 @@ dependencies:
- tqdm
- certifi
- future
- expecttest
- https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0
- https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0
@@ -9,6 +9,8 @@
#include <torch/script.h>
#include <vectors.h> // @manual
#include <vocab.h> // @manual
#include <vocab_factory.h>

namespace torchtext {

namespace py = pybind11;
@@ -155,126 +157,8 @@ PYBIND11_MODULE(_torchtext, m) {
&_load_token_and_vectors_from_file);
m.def("_load_vocab_from_file", &_load_vocab_from_file);
m.def("_build_vocab_from_text_file", &build_vocab_from_text_file);
m.def("_build_vocab_from_text_file_using_python_tokenizer", &_build_vocab_from_text_file_using_python_tokenizer);
}

TORCH_LIBRARY_FRAGMENT(torchtext, m) {
m.class_<Regex>("Regex")
.def(torch::init<std::string>())
.def("Sub", &Regex::Sub)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Regex> &self) -> std::string {
return _serialize_regex(self);
},
// __setstate__
[](std::string state) -> c10::intrusive_ptr<Regex> {
return _deserialize_regex(std::move(state));
});

m.class_<RegexTokenizer>("RegexTokenizer")
.def(torch::init<std::vector<std::string>, std::vector<std::string>,
bool>())
.def("forward", &RegexTokenizer::forward)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<RegexTokenizer> &self)
-> RegexTokenizerStates {
return _serialize_regex_tokenizer(self);
},
// __setstate__
[](RegexTokenizerStates states)
-> c10::intrusive_ptr<RegexTokenizer> {
return _deserialize_regex_tokenizer(std::move(states));
});

m.class_<SentencePiece>("SentencePiece")
.def(torch::init<std::string>())
.def("Encode", &SentencePiece::Encode)
.def("EncodeAsIds", &SentencePiece::EncodeAsIds)
.def("DecodeIds", &SentencePiece::DecodeIds)
.def("EncodeAsPieces", &SentencePiece::EncodeAsPieces)
.def("DecodePieces", &SentencePiece::DecodePieces)
.def("GetPieceSize", &SentencePiece::GetPieceSize)
.def("unk_id", &SentencePiece::unk_id)
.def("PieceToId", &SentencePiece::PieceToId)
.def("IdToPiece", &SentencePiece::IdToPiece)
.def_pickle(
// The underlying content of SentencePiece contains a byte string,
// and returning it as std::string causes a UTF-8 decoding error.
// Since TorchScript does not support byte strings, we use a byte Tensor
// to pass the data around.
// __getstate__
[](const c10::intrusive_ptr<SentencePiece> &self) -> torch::Tensor {
auto *data =
static_cast<void *>(const_cast<char *>(self->content_.data()));
auto numel = static_cast<int64_t>(self->content_.size());
return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
},
// __setstate__
[](torch::Tensor state) -> c10::intrusive_ptr<SentencePiece> {
auto *data = static_cast<char *>(state.data_ptr());
auto numel = state.size(0);
return c10::make_intrusive<SentencePiece>(std::string(data, numel));
});

m.class_<Vectors>("Vectors")
.def(torch::init<std::vector<std::string>, std::vector<std::int64_t>,
torch::Tensor, torch::Tensor>())
.def("__getitem__", &Vectors::__getitem__)
.def("lookup_vectors", &Vectors::lookup_vectors)
.def("__setitem__", &Vectors::__setitem__)
.def("__len__", &Vectors::__len__)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
return _serialize_vectors(self);
},
// __setstate__
[](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
return _deserialize_vectors(states);
});

m.class_<Vocab>("Vocab")
.def(torch::init<StringList, c10::optional<int64_t>>())
.def("__contains__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> bool { return self->__contains__(c10::string_view{item}); })
.def("__getitem__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> int64_t { return self->__getitem__(c10::string_view{item}); })
.def("insert_token", &Vocab::insert_token)
.def("__len__", &Vocab::__len__)
.def("set_default_index", &Vocab::set_default_index)
.def("get_default_index", &Vocab::get_default_index)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
.def("lookup_indices",
[](const c10::intrusive_ptr<Vocab> &self,
const std::vector<std::string> &items) {
std::vector<int64_t> indices(items.size());
int64_t counter = 0;
for (const auto &item : items) {
indices[counter++] = self->__getitem__(c10::string_view{item});
}
return indices;
})
.def("get_stoi", &Vocab::get_stoi)
.def("get_itos", &Vocab::get_itos)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
return _serialize_vocab(self);
},
// __setstate__
[](VocabStates states) -> c10::intrusive_ptr<Vocab> {
return _deserialize_vocab(states);
});

m.def("torchtext::generate_sp_model", &generate_sp_model);
m.def("torchtext::load_sp_model", &load_sp_model);
m.def("torchtext::load_sp_model_string", &load_sp_model_string);
m.def("_build_vocab_from_text_file_using_python_tokenizer",
&_build_vocab_from_text_file_using_python_tokenizer);
}

} // namespace torchtext
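The SentencePiece def_pickle above works around TorchScript's lack of a bytes type: the serialized model is a raw byte string, so it is round-tripped through a uint8 Tensor instead of a std::string. A self-contained sketch of that trick (helper names are hypothetical, not part of the PR):

#include <cstdint>
#include <string>
#include <torch/torch.h>

// Hypothetical helpers mirroring the __getstate__/__setstate__ lambdas.
torch::Tensor bytes_to_tensor(const std::string &content) {
  auto *data = static_cast<void *>(const_cast<char *>(content.data()));
  auto numel = static_cast<int64_t>(content.size());
  // clone() copies the blob, so the tensor owns its storage even after
  // `content` is destroyed.
  return torch::from_blob(data, {numel}, torch::kUInt8).clone();
}

std::string tensor_to_bytes(const torch::Tensor &state) {
  auto *data = static_cast<char *>(state.data_ptr());
  return std::string(data, static_cast<size_t>(state.size(0)));
}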
129 changes: 129 additions & 0 deletions torchtext/csrc/register_torchbindings.cpp
@@ -0,0 +1,129 @@
#include <iostream>
#include <regex.h>
#include <regex_tokenizer.h> // @manual
#include <sentencepiece.h> // @manual
#include <torch/script.h>
#include <vectors.h> // @manual
#include <vocab.h> // @manual
namespace torchtext {

TORCH_LIBRARY_FRAGMENT(torchtext, m) {
m.class_<Regex>("Regex")
.def(torch::init<std::string>())
.def("Sub", &Regex::Sub)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Regex> &self) -> std::string {
return _serialize_regex(self);
},
// __setstate__
[](std::string state) -> c10::intrusive_ptr<Regex> {
return _deserialize_regex(std::move(state));
});

m.class_<RegexTokenizer>("RegexTokenizer")
.def(torch::init<std::vector<std::string>, std::vector<std::string>,
bool>())
.def("forward", &RegexTokenizer::forward)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<RegexTokenizer> &self)
-> RegexTokenizerStates {
return _serialize_regex_tokenizer(self);
},
// __setstate__
[](RegexTokenizerStates states)
-> c10::intrusive_ptr<RegexTokenizer> {
return _deserialize_regex_tokenizer(std::move(states));
});

m.class_<SentencePiece>("SentencePiece")
.def(torch::init<std::string>())
.def("Encode", &SentencePiece::Encode)
.def("EncodeAsIds", &SentencePiece::EncodeAsIds)
.def("DecodeIds", &SentencePiece::DecodeIds)
.def("EncodeAsPieces", &SentencePiece::EncodeAsPieces)
.def("DecodePieces", &SentencePiece::DecodePieces)
.def("GetPieceSize", &SentencePiece::GetPieceSize)
.def("unk_id", &SentencePiece::unk_id)
.def("PieceToId", &SentencePiece::PieceToId)
.def("IdToPiece", &SentencePiece::IdToPiece)
.def_pickle(
// The underlying content of SentencePiece contains a byte string,
// and returning it as std::string causes a UTF-8 decoding error.
// Since TorchScript does not support byte strings, we use a byte Tensor
// to pass the data around.
// __getstate__
[](const c10::intrusive_ptr<SentencePiece> &self) -> torch::Tensor {
auto *data =
static_cast<void *>(const_cast<char *>(self->content_.data()));
auto numel = static_cast<int64_t>(self->content_.size());
return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
},
// __setstate__
[](torch::Tensor state) -> c10::intrusive_ptr<SentencePiece> {
auto *data = static_cast<char *>(state.data_ptr());
auto numel = state.size(0);
return c10::make_intrusive<SentencePiece>(std::string(data, numel));
});

m.class_<Vectors>("Vectors")
.def(torch::init<std::vector<std::string>, std::vector<std::int64_t>,
torch::Tensor, torch::Tensor>())
.def("__getitem__", &Vectors::__getitem__)
.def("lookup_vectors", &Vectors::lookup_vectors)
.def("__setitem__", &Vectors::__setitem__)
.def("__len__", &Vectors::__len__)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
return _serialize_vectors(self);
},
// __setstate__
[](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
return _deserialize_vectors(states);
});

m.class_<Vocab>("Vocab")
.def(torch::init<StringList, c10::optional<int64_t>>())
.def("__contains__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> bool { return self->__contains__(c10::string_view{item}); })
.def("__getitem__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> int64_t { return self->__getitem__(c10::string_view{item}); })
.def("insert_token", &Vocab::insert_token)
.def("__len__", &Vocab::__len__)
.def("set_default_index", &Vocab::set_default_index)
.def("get_default_index", &Vocab::get_default_index)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
.def("lookup_indices",
[](const c10::intrusive_ptr<Vocab> &self,
const std::vector<std::string> &items) {
std::vector<int64_t> indices(items.size());
int64_t counter = 0;
for (const auto &item : items) {
indices[counter++] = self->__getitem__(c10::string_view{item});
}
return indices;
})
.def("get_stoi", &Vocab::get_stoi)
.def("get_itos", &Vocab::get_itos)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
return _serialize_vocab(self);
},
// __setstate__
[](VocabStates states) -> c10::intrusive_ptr<Vocab> {
return _deserialize_vocab(states);
});

m.def("torchtext::generate_sp_model", &generate_sp_model);
m.def("torchtext::load_sp_model", &load_sp_model);
m.def("torchtext::load_sp_model_string", &load_sp_model_string);
}

} // namespace torchtext
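Note that the new file uses TORCH_LIBRARY_FRAGMENT rather than TORCH_LIBRARY, so the torchtext library namespace can be assembled from multiple translation units without a duplicate-registration error. A sketch of a hypothetical second fragment contributing to the same namespace (the operator shown is invented for illustration):

#include <torch/script.h>

// Hypothetical: another .cpp file extending the same library namespace.
TORCH_LIBRARY_FRAGMENT(torchtext, m) {
  m.def("torchtext::noop", []() {});
}

Once registered, operators are reachable from Python as torch.ops.torchtext.* and the bound classes as torch.classes.torchtext.*.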
59 changes: 0 additions & 59 deletions torchtext/csrc/vocab.cpp
@@ -191,17 +191,6 @@ void parse_raw_text_file_chunk(const std::string &file_path, size_t offset,
}
}

// custom comparator for sorting (token, frequency) pairs
struct CompareTokens {
bool operator()(const std::pair<std::string, int64_t> &a,
const std::pair<std::string, int64_t> &b) {
if (a.second == b.second) {
return a.first < b.first;
}
return a.second > b.second;
}
};

StringList
_concat_tokens(std::vector<std::shared_ptr<IndexDict>> chunk_counters,
const int64_t min_freq, const int64_t num_lines,
@@ -345,54 +334,6 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
return Vocab(std::move(tokens));
}

Vocab _build_vocab_from_text_file_using_python_tokenizer(
const std::string &file_path, const int64_t min_freq,
py::object tokenizer) {
// find number of lines
int64_t num_lines = _infer_lines(file_path);
// Read text from file and add tokens
std::ifstream fin(file_path, std::ios::in);
TORCH_CHECK(fin.is_open(), "Cannot open input file " + file_path);

IndexDict counter;
std::string line;
for (int64_t i = 0; i < num_lines; i++) {
std::getline(fin, line);
std::vector<std::string> token_list =
tokenizer(line).cast<std::vector<std::string>>();

for (size_t i = 0; i < token_list.size(); i++) {
std::string token = token_list[i];

if (counter.find(token) == counter.end()) {
counter[token] = 1;
} else {
counter[token] += 1;
}
}
}

// create tokens-frequency pairs
std::vector<std::pair<std::string, int64_t>> token_freq_pairs;
for (const auto &item : counter) {
if (item.second >= min_freq) {
token_freq_pairs.push_back(item);
}
}

// sort tokens by frequency
CompareTokens compare_tokens;
std::sort(token_freq_pairs.begin(), token_freq_pairs.end(), compare_tokens);

// Create final list of tokens
StringList tokens;
for (const auto &token_freq_pair : token_freq_pairs) {
tokens.push_back(token_freq_pair.first);
}

return Vocab(std::move(tokens));
}

VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self) {
std::vector<int64_t> integers;
StringList strings = self->itos_;
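The deleted _build_vocab_from_text_file_using_python_tokenizer (presumably relocated along with the new vocab_factory.h include above) boils down to: tokenize each line, count token frequencies, drop tokens below min_freq, sort with CompareTokens, and build a Vocab. A minimal sketch of the counting-and-filtering core (hypothetical helper, not the moved function):

#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical helper: count tokens, keep those reaching min_freq.
std::vector<std::pair<std::string, int64_t>>
count_and_filter(const std::vector<std::string> &tokens, int64_t min_freq) {
  std::unordered_map<std::string, int64_t> counter;
  for (const auto &token : tokens) {
    ++counter[token];  // operator[] value-initializes to 0 on first insertion
  }
  std::vector<std::pair<std::string, int64_t>> pairs;
  for (const auto &item : counter) {
    if (item.second >= min_freq) {
      pairs.push_back(item);
    }
  }
  return pairs;
}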
20 changes: 14 additions & 6 deletions torchtext/csrc/vocab.h
@@ -1,10 +1,8 @@
#pragma once
#include <algorithm>
#include <c10/util/string_view.h>
#include <pybind11/pybind11.h>
#include <torch/script.h>

namespace py = pybind11;

namespace torchtext {

typedef std::vector<std::string> StringList;
@@ -14,6 +12,19 @@ typedef std::tuple<std::string, std::vector<int64_t>, std::vector<std::string>,
std::vector<torch::Tensor>>
VocabStates;

// custom comparator for sorting (token, frequency) pairs
struct CompareTokens {
bool operator()(const std::pair<std::string, int64_t> &a,
const std::pair<std::string, int64_t> &b) {
if (a.second == b.second) {
return a.first < b.first;
}
return a.second > b.second;
}
};

int64_t _infer_lines(const std::string &file_path);

struct Vocab : torch::CustomClassHolder {
static const int32_t MAX_VOCAB_SIZE = 30000000;
int64_t unk_index_;
@@ -79,7 +90,4 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
const int64_t min_freq,
const int64_t num_cpus,
torch::jit::script::Module tokenizer);
Vocab _build_vocab_from_text_file_using_python_tokenizer(
const std::string &file_path, const int64_t min_freq, py::object tokenizer);

} // namespace torchtext
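CompareTokens, now exposed in vocab.h, orders token-frequency pairs by descending frequency, breaking ties by ascending token text. A small self-contained demo of that ordering (the struct is copied from the header above; the sample data is invented):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Copied from vocab.h above.
struct CompareTokens {
  bool operator()(const std::pair<std::string, int64_t> &a,
                  const std::pair<std::string, int64_t> &b) {
    if (a.second == b.second) {
      return a.first < b.first;
    }
    return a.second > b.second;
  }
};

int main() {
  std::vector<std::pair<std::string, int64_t>> pairs = {
      {"the", 5}, {"b", 2}, {"a", 2}};
  std::sort(pairs.begin(), pairs.end(), CompareTokens{});
  // Prints: "the 5", then "a 2" before "b 2"
  // (frequency descending, ties broken by token ascending).
  for (const auto &p : pairs) {
    std::cout << p.first << ' ' << p.second << '\n';
  }
}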