
Merge branch 'master' into text_classification
parmeet authored Mar 24, 2021 · 2 parents d027ab6 + 50bd1b3 · commit 257db3a
Showing 5 changed files with 229 additions and 190 deletions.
12 changes: 7 additions & 5 deletions benchmark/benchmark_experimental_vocab.py
@@ -3,7 +3,7 @@
import time

import torch
from torchtext.experimental.datasets import AG_NEWS
from torchtext.experimental.datasets import DATASETS
from torchtext.experimental.vocab import (
vocab as VocabExperimental,
load_vocab_from_file,
@@ -76,7 +76,7 @@ def benchmark_experimental_vocab_construction(vocab_file_path, is_raw_text=True,
print("Construction time:", time.monotonic() - t0)


def benchmark_experimental_vocab_lookup(vocab_file_path=None):
def benchmark_experimental_vocab_lookup(vocab_file_path=None, dataset='AG_NEWS'):
def _run_benchmark_lookup(tokens, vocab):
t0 = time.monotonic()
# list lookup
@@ -94,7 +94,7 @@ def _run_benchmark_lookup(tokens, vocab):
tokens = []
tokens_lists = []

train = AG_NEWS(split='train')
train = DATASETS[dataset](split='train')
vocab = train.get_vocab()
for (_, text) in train:
cur_tokens = []
@@ -124,7 +124,7 @@ def token_iterator(file_path):
v_experimental = load_vocab_from_file(f)
print("Construction time:", time.monotonic() - t0)
else:
print("Loading Vocab from AG News")
print("Loading Vocab from {}".format(dataset))
counter = Counter(tokens)
sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
ordered_dict = OrderedDict(sorted_by_freq_tuples)
@@ -174,11 +174,13 @@ def token_iterator(file_path):
help='The name of vocab file used for construction')
parser.add_argument('--vocab-filename-lookup', type=str, default=None,
help='The name of vocab file used for lookup')
parser.add_argument('--dataset', type=str, default='AG_NEWS',
help='Dataset used for the lookup benchmark (default: AG_NEWS)')
args = parser.parse_args()

if args.run_construction_benchmark:
print("is_legacy", args.is_legacy)
benchmark_experimental_vocab_construction(args.vocab_filename_construction,
is_raw_text=args.is_raw_text, is_legacy=args.is_legacy)
else:
benchmark_experimental_vocab_lookup(args.vocab_filename_lookup)
benchmark_experimental_vocab_lookup(args.vocab_filename_lookup, args.dataset)
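
The benchmark change above replaces the hard-coded AG_NEWS import with the DATASETS registry, so the lookup benchmark can be pointed at any experimental dataset by name via the new --dataset flag. A minimal sketch of that pattern, assuming the torchtext.experimental.datasets.DATASETS mapping exposes an 'AG_NEWS' entry as the diff does:

from torchtext.experimental.datasets import DATASETS

dataset_name = 'AG_NEWS'                       # value supplied through the new --dataset flag
train = DATASETS[dataset_name](split='train')  # same call the updated benchmark makes
vocab = train.get_vocab()                      # vocabulary built for that dataset
print('vocab size:', len(vocab))

On the command line this corresponds to something like python benchmark/benchmark_experimental_vocab.py --dataset AG_NEWS, with the dataset name taken from the keys of DATASETS.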
243 changes: 140 additions & 103 deletions torchtext/csrc/register_bindings.cpp
@@ -1,26 +1,26 @@
#include <iostream>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <regex.h>
#include <regex_tokenizer.h> // @manual
#include <sentencepiece.h> // @manual
#include <regex_tokenizer.h> // @manual
#include <sentencepiece.h> // @manual
#include <torch/csrc/jit/python/pybind_utils.h> // @manual
#include <torch/csrc/utils/pybind.h> // @manual
#include <torch/csrc/utils/pybind.h> // @manual
#include <torch/script.h>
#include <vectors.h> // @manual
#include <vocab.h> // @manual

namespace torchtext {

namespace py = pybind11;

namespace {
Vocab build_vocab_from_text_file(const std::string &file_path,
const std::string &unk_token,
const int64_t min_freq,
const int64_t num_cpus,
const int64_t min_freq, const int64_t num_cpus,
py::object fn) {
torch::jit::script::Module module(*torch::jit::as_module(fn));
return _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus, module);
return _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus,
module);
}
} // namespace

@@ -40,23 +40,27 @@ PYBIND11_MODULE(_torchtext, m) {
return _deserialize_regex(std::move(state));
}));

py::class_<RegexTokenizer, c10::intrusive_ptr<RegexTokenizer>>(m, "RegexTokenizer")
py::class_<RegexTokenizer, c10::intrusive_ptr<RegexTokenizer>>(
m, "RegexTokenizer")
.def_readonly("patterns_", &RegexTokenizer::patterns_)
.def_readonly("replacements_", &RegexTokenizer::replacements_)
.def_readonly("to_lower_", &RegexTokenizer::to_lower_)
.def(py::init<std::vector<std::string>, std::vector<std::string>, bool>())
.def("forward", &RegexTokenizer::forward)
.def(py::pickle(
// __getstate__
[](const c10::intrusive_ptr<RegexTokenizer> &self) -> RegexTokenizerStates {
[](const c10::intrusive_ptr<RegexTokenizer> &self)
-> RegexTokenizerStates {
return _serialize_regex_tokenizer(self);
},
// __setstate__
[](RegexTokenizerStates states) -> c10::intrusive_ptr<RegexTokenizer> {
[](RegexTokenizerStates states)
-> c10::intrusive_ptr<RegexTokenizer> {
return _deserialize_regex_tokenizer(std::move(states));
}));

py::class_<SentencePiece, c10::intrusive_ptr<SentencePiece>>(m, "SentencePiece")
py::class_<SentencePiece, c10::intrusive_ptr<SentencePiece>>(m,
"SentencePiece")
.def(py::init<std::string>())
.def("_return_content",
[](const SentencePiece &self) { return py::bytes(self.content_); })
@@ -70,14 +70,14 @@ PYBIND11_MODULE(_torchtext, m) {
.def("PieceToId", &SentencePiece::PieceToId)
.def("IdToPiece", &SentencePiece::IdToPiece)
.def(py::pickle(
// __getstate__
[](const c10::intrusive_ptr<SentencePiece> &self) -> py::bytes{
return py::bytes(self->content_);
},
// __setstate__
[](py::bytes state) -> c10::intrusive_ptr<SentencePiece> {
return c10::make_intrusive<SentencePiece>(std::string(state));
}));
// __getstate__
[](const c10::intrusive_ptr<SentencePiece> &self) -> py::bytes {
return py::bytes(self->content_);
},
// __setstate__
[](py::bytes state) -> c10::intrusive_ptr<SentencePiece> {
return c10::make_intrusive<SentencePiece>(std::string(state));
}));

py::class_<Vectors, c10::intrusive_ptr<Vectors>>(m, "Vectors")
.def(py::init<std::vector<std::string>, std::vector<int64_t>,
@@ -103,13 +107,30 @@ PYBIND11_MODULE(_torchtext, m) {
.def(py::init<std::vector<std::string>, std::string>())
.def_readonly("itos_", &Vocab::itos_)
.def_readonly("unk_token_", &Vocab::unk_token_)
.def("__getitem__", &Vocab::__getitem__)
.def("__getitem__",
[](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> int64_t {
ssize_t length;
const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
return self->__getitem__(c10::string_view{buffer, (size_t)length});
})
.def("__len__", &Vocab::__len__)
.def("insert_token", &Vocab::insert_token)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
.def("lookup_indices", &Vocab::lookup_indices)
.def("lookup_indices",
[](const c10::intrusive_ptr<Vocab> &self, const py::list &items) {
std::vector<int64_t> indices(items.size());
int64_t counter = 0;
for (const auto &item : items) {
ssize_t length;
const char *buffer =
PyUnicode_AsUTF8AndSize(item.ptr(), &length);
indices[counter++] =
self->__getitem__(c10::string_view{buffer, (size_t)length});
}
return indices;
})
.def("get_stoi", &Vocab::get_stoi)
.def("get_itos", &Vocab::get_itos)
.def(py::pickle(
@@ -131,96 +152,112 @@ PYBIND11_MODULE(_torchtext, m) {

TORCH_LIBRARY_FRAGMENT(torchtext, m) {
m.class_<Regex>("Regex")
.def(torch::init<std::string>())
.def("Sub", &Regex::Sub)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Regex> &self) -> std::string {
return _serialize_regex(self);
},
// __setstate__
[](std::string state) -> c10::intrusive_ptr<Regex> {
return _deserialize_regex(std::move(state));
});
.def(torch::init<std::string>())
.def("Sub", &Regex::Sub)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Regex> &self) -> std::string {
return _serialize_regex(self);
},
// __setstate__
[](std::string state) -> c10::intrusive_ptr<Regex> {
return _deserialize_regex(std::move(state));
});

m.class_<RegexTokenizer>("RegexTokenizer")
.def(torch::init<std::vector<std::string>, std::vector<std::string>, bool>())
.def("forward", &RegexTokenizer::forward)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<RegexTokenizer> &self) -> RegexTokenizerStates {
return _serialize_regex_tokenizer(self);
},
// __setstate__
[](RegexTokenizerStates states) -> c10::intrusive_ptr<RegexTokenizer> {
return _deserialize_regex_tokenizer(std::move(states));
});
.def(torch::init<std::vector<std::string>, std::vector<std::string>,
bool>())
.def("forward", &RegexTokenizer::forward)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<RegexTokenizer> &self)
-> RegexTokenizerStates {
return _serialize_regex_tokenizer(self);
},
// __setstate__
[](RegexTokenizerStates states)
-> c10::intrusive_ptr<RegexTokenizer> {
return _deserialize_regex_tokenizer(std::move(states));
});

m.class_<SentencePiece>("SentencePiece")
.def(torch::init<std::string>())
.def("Encode", &SentencePiece::Encode)
.def("EncodeAsIds", &SentencePiece::EncodeAsIds)
.def("DecodeIds", &SentencePiece::DecodeIds)
.def("EncodeAsPieces", &SentencePiece::EncodeAsPieces)
.def("DecodePieces", &SentencePiece::DecodePieces)
.def("GetPieceSize", &SentencePiece::GetPieceSize)
.def("unk_id", &SentencePiece::unk_id)
.def("PieceToId", &SentencePiece::PieceToId)
.def("IdToPiece", &SentencePiece::IdToPiece)
.def_pickle(
// The underlying content of SentencePiece contains byte string,
// and returning it as std::string causes a UTF8 decoding error.
// Since TorchScript does not support byte string, we use byte Tensor to
// pass around the data.
// __getstate__
[](const c10::intrusive_ptr<SentencePiece> &self) -> torch::Tensor {
auto *data = static_cast<void*>(const_cast<char*>(self->content_.data()));
auto numel = static_cast<int64_t>(self->content_.size());
return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
},
// __setstate__
[](torch::Tensor state) -> c10::intrusive_ptr<SentencePiece> {
auto *data = static_cast<char*>(state.data_ptr());
auto numel = state.size(0);
return c10::make_intrusive<SentencePiece>(std::string(data, numel));
});
.def(torch::init<std::string>())
.def("Encode", &SentencePiece::Encode)
.def("EncodeAsIds", &SentencePiece::EncodeAsIds)
.def("DecodeIds", &SentencePiece::DecodeIds)
.def("EncodeAsPieces", &SentencePiece::EncodeAsPieces)
.def("DecodePieces", &SentencePiece::DecodePieces)
.def("GetPieceSize", &SentencePiece::GetPieceSize)
.def("unk_id", &SentencePiece::unk_id)
.def("PieceToId", &SentencePiece::PieceToId)
.def("IdToPiece", &SentencePiece::IdToPiece)
.def_pickle(
// The underlying content of SentencePiece contains byte string,
// and returning it as std::string causes a UTF8 decoding error.
// Since TorchScript does not support byte string, we use byte Tensor
// to pass around the data.
// __getstate__
[](const c10::intrusive_ptr<SentencePiece> &self) -> torch::Tensor {
auto *data =
static_cast<void *>(const_cast<char *>(self->content_.data()));
auto numel = static_cast<int64_t>(self->content_.size());
return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
},
// __setstate__
[](torch::Tensor state) -> c10::intrusive_ptr<SentencePiece> {
auto *data = static_cast<char *>(state.data_ptr());
auto numel = state.size(0);
return c10::make_intrusive<SentencePiece>(std::string(data, numel));
});

m.class_<Vectors>("Vectors")
.def(torch::init<std::vector<std::string>, std::vector<std::int64_t>, torch::Tensor, torch::Tensor>())
.def("__getitem__", &Vectors::__getitem__)
.def("lookup_vectors", &Vectors::lookup_vectors)
.def("__setitem__", &Vectors::__setitem__)
.def("__len__", &Vectors::__len__)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
return _serialize_vectors(self);
},
// __setstate__
[](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
return _deserialize_vectors(states);
});
.def(torch::init<std::vector<std::string>, std::vector<std::int64_t>,
torch::Tensor, torch::Tensor>())
.def("__getitem__", &Vectors::__getitem__)
.def("lookup_vectors", &Vectors::lookup_vectors)
.def("__setitem__", &Vectors::__setitem__)
.def("__len__", &Vectors::__len__)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
return _serialize_vectors(self);
},
// __setstate__
[](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
return _deserialize_vectors(states);
});

m.class_<Vocab>("Vocab")
.def(torch::init<StringList, std::string>())
.def("__getitem__", &Vocab::__getitem__)
.def("__len__", &Vocab::__len__)
.def("insert_token", &Vocab::insert_token)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
.def("lookup_indices", &Vocab::lookup_indices)
.def("get_stoi", &Vocab::get_stoi)
.def("get_itos", &Vocab::get_itos)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
return _serialize_vocab(self);
},
// __setstate__
[](VocabStates states) -> c10::intrusive_ptr<Vocab> {
return _deserialize_vocab(states);
});
.def(torch::init<StringList, std::string>())
.def("__getitem__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> int64_t { return self->__getitem__(c10::string_view{item}); })
.def("__len__", &Vocab::__len__)
.def("insert_token", &Vocab::insert_token)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
.def("lookup_indices",
[](const c10::intrusive_ptr<Vocab> &self,
const std::vector<std::string> &items) {
std::vector<int64_t> indices(items.size());
int64_t counter = 0;
for (const auto &item : items) {
indices[counter++] = self->__getitem__(c10::string_view{item});
}
return indices;
})
.def("get_stoi", &Vocab::get_stoi)
.def("get_itos", &Vocab::get_itos)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
return _serialize_vocab(self);
},
// __setstate__
[](VocabStates states) -> c10::intrusive_ptr<Vocab> {
return _deserialize_vocab(states);
});

m.def("torchtext::generate_sp_model", &generate_sp_model);
m.def("torchtext::load_sp_model", &load_sp_model);
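
For context on the Vocab binding changes in this file: the pybind11 __getitem__ and lookup_indices registrations now take the incoming py::str / py::list directly and read the UTF-8 buffer with PyUnicode_AsUTF8AndSize, passing a c10::string_view into the C++ Vocab rather than materialising a std::string per token. From Python the calls look the same as before; a rough usage sketch (the vocab factory call follows the benchmark script's ordered-dict construction and may differ slightly across torchtext versions):

from collections import OrderedDict
from torchtext.experimental.vocab import vocab

# build a tiny frequency-ordered vocab, the way the benchmark script does
v = vocab(OrderedDict([('hello', 2), ('world', 1)]))

print(v['hello'])                            # single-token lookup -> integer index
print(v.lookup_indices(['hello', 'world']))  # batched lookup -> list of indices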
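The pickle comment in the TORCH_LIBRARY_FRAGMENT SentencePiece registration explains why its __getstate__ returns a Tensor: TorchScript has no byte-string type, so the raw model bytes are moved through a uint8 Tensor and copied back into a std::string on load. A standalone sketch of that round trip in plain PyTorch (not torchtext's API; the content bytes below are made up):

import torch

content = b'\x00\xfe arbitrary model bytes'             # stands in for the SentencePiece content_ string

# __getstate__ direction: copy the raw bytes into a uint8 tensor
state = torch.tensor(list(content), dtype=torch.uint8)

# __setstate__ direction: rebuild the original byte string from the tensor
restored = bytes(state.tolist())

assert restored == content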
