diff --git a/.circleci/unittest/linux/scripts/environment.yml b/.circleci/unittest/linux/scripts/environment.yml index dbf84c006e..e616d8f107 100644 --- a/.circleci/unittest/linux/scripts/environment.yml +++ b/.circleci/unittest/linux/scripts/environment.yml @@ -18,5 +18,6 @@ dependencies: - sphinx - sphinx-rtd-theme - tqdm + - expecttest - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0 - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0 diff --git a/.circleci/unittest/windows/scripts/environment.yml b/.circleci/unittest/windows/scripts/environment.yml index 21833a0c52..75e6d25c13 100644 --- a/.circleci/unittest/windows/scripts/environment.yml +++ b/.circleci/unittest/windows/scripts/environment.yml @@ -21,5 +21,6 @@ dependencies: - tqdm - certifi - future + - expecttest - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0 - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0 diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_pybindings.cpp similarity index 56% rename from torchtext/csrc/register_bindings.cpp rename to torchtext/csrc/register_pybindings.cpp index e7e2241a7f..34e8da149e 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_pybindings.cpp @@ -9,6 +9,8 @@ #include #include // @manual #include // @manual +#include + namespace torchtext { namespace py = pybind11; @@ -155,126 +157,8 @@ PYBIND11_MODULE(_torchtext, m) { &_load_token_and_vectors_from_file); m.def("_load_vocab_from_file", &_load_vocab_from_file); m.def("_build_vocab_from_text_file", &build_vocab_from_text_file); - m.def("_build_vocab_from_text_file_using_python_tokenizer", &_build_vocab_from_text_file_using_python_tokenizer); -} - -TORCH_LIBRARY_FRAGMENT(torchtext, m) { - m.class_("Regex") - .def(torch::init()) - .def("Sub", &Regex::Sub) - .def_pickle( - // __getstate__ - [](const c10::intrusive_ptr &self) -> std::string { - return _serialize_regex(self); - }, - // __setstate__ - [](std::string state) -> c10::intrusive_ptr { - return _deserialize_regex(std::move(state)); - }); - - m.class_("RegexTokenizer") - .def(torch::init, std::vector, - bool>()) - .def("forward", &RegexTokenizer::forward) - .def_pickle( - // __getstate__ - [](const c10::intrusive_ptr &self) - -> RegexTokenizerStates { - return _serialize_regex_tokenizer(self); - }, - // __setstate__ - [](RegexTokenizerStates states) - -> c10::intrusive_ptr { - return _deserialize_regex_tokenizer(std::move(states)); - }); - - m.class_("SentencePiece") - .def(torch::init()) - .def("Encode", &SentencePiece::Encode) - .def("EncodeAsIds", &SentencePiece::EncodeAsIds) - .def("DecodeIds", &SentencePiece::DecodeIds) - .def("EncodeAsPieces", &SentencePiece::EncodeAsPieces) - .def("DecodePieces", &SentencePiece::DecodePieces) - .def("GetPieceSize", &SentencePiece::GetPieceSize) - .def("unk_id", &SentencePiece::unk_id) - .def("PieceToId", &SentencePiece::PieceToId) - .def("IdToPiece", &SentencePiece::IdToPiece) - .def_pickle( - // The underlying content of SentencePiece contains byte string, - // and returing it as std::string cause UTF8 decoding error. - // Since TorchScript does not support byte string, we use byte Tensor - // to pass around the data. 
- // __getstate__ - [](const c10::intrusive_ptr &self) -> torch::Tensor { - auto *data = - static_cast(const_cast(self->content_.data())); - auto numel = static_cast(self->content_.size()); - return torch::from_blob(data, {numel}, {torch::kUInt8}).clone(); - }, - // __setstate__ - [](torch::Tensor state) -> c10::intrusive_ptr { - auto *data = static_cast(state.data_ptr()); - auto numel = state.size(0); - return c10::make_intrusive(std::string(data, numel)); - }); - - m.class_("Vectors") - .def(torch::init, std::vector, - torch::Tensor, torch::Tensor>()) - .def("__getitem__", &Vectors::__getitem__) - .def("lookup_vectors", &Vectors::lookup_vectors) - .def("__setitem__", &Vectors::__setitem__) - .def("__len__", &Vectors::__len__) - .def_pickle( - // __getstate__ - [](const c10::intrusive_ptr &self) -> VectorsStates { - return _serialize_vectors(self); - }, - // __setstate__ - [](VectorsStates states) -> c10::intrusive_ptr { - return _deserialize_vectors(states); - }); - - m.class_("Vocab") - .def(torch::init>()) - .def("__contains__", - [](const c10::intrusive_ptr &self, const std::string &item) - -> bool { return self->__contains__(c10::string_view{item}); }) - .def("__getitem__", - [](const c10::intrusive_ptr &self, const std::string &item) - -> int64_t { return self->__getitem__(c10::string_view{item}); }) - .def("insert_token", &Vocab::insert_token) - .def("__len__", &Vocab::__len__) - .def("set_default_index", &Vocab::set_default_index) - .def("get_default_index", &Vocab::get_default_index) - .def("append_token", &Vocab::append_token) - .def("lookup_token", &Vocab::lookup_token) - .def("lookup_tokens", &Vocab::lookup_tokens) - .def("lookup_indices", - [](const c10::intrusive_ptr &self, - const std::vector &items) { - std::vector indices(items.size()); - int64_t counter = 0; - for (const auto &item : items) { - indices[counter++] = self->__getitem__(c10::string_view{item}); - } - return indices; - }) - .def("get_stoi", &Vocab::get_stoi) - .def("get_itos", &Vocab::get_itos) - .def_pickle( - // __getstate__ - [](const c10::intrusive_ptr &self) -> VocabStates { - return _serialize_vocab(self); - }, - // __setstate__ - [](VocabStates states) -> c10::intrusive_ptr { - return _deserialize_vocab(states); - }); - - m.def("torchtext::generate_sp_model", &generate_sp_model); - m.def("torchtext::load_sp_model", &load_sp_model); - m.def("torchtext::load_sp_model_string", &load_sp_model_string); + m.def("_build_vocab_from_text_file_using_python_tokenizer", + &_build_vocab_from_text_file_using_python_tokenizer); } } // namespace torchtext diff --git a/torchtext/csrc/register_torchbindings.cpp b/torchtext/csrc/register_torchbindings.cpp new file mode 100644 index 0000000000..f701119601 --- /dev/null +++ b/torchtext/csrc/register_torchbindings.cpp @@ -0,0 +1,129 @@ +#include +#include +#include // @manual +#include // @manual +#include +#include // @manual +#include // @manual +namespace torchtext { + +TORCH_LIBRARY_FRAGMENT(torchtext, m) { + m.class_("Regex") + .def(torch::init()) + .def("Sub", &Regex::Sub) + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr &self) -> std::string { + return _serialize_regex(self); + }, + // __setstate__ + [](std::string state) -> c10::intrusive_ptr { + return _deserialize_regex(std::move(state)); + }); + + m.class_("RegexTokenizer") + .def(torch::init, std::vector, + bool>()) + .def("forward", &RegexTokenizer::forward) + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr &self) + -> RegexTokenizerStates { + return _serialize_regex_tokenizer(self); 
+ }, + // __setstate__ + [](RegexTokenizerStates states) + -> c10::intrusive_ptr { + return _deserialize_regex_tokenizer(std::move(states)); + }); + + m.class_("SentencePiece") + .def(torch::init()) + .def("Encode", &SentencePiece::Encode) + .def("EncodeAsIds", &SentencePiece::EncodeAsIds) + .def("DecodeIds", &SentencePiece::DecodeIds) + .def("EncodeAsPieces", &SentencePiece::EncodeAsPieces) + .def("DecodePieces", &SentencePiece::DecodePieces) + .def("GetPieceSize", &SentencePiece::GetPieceSize) + .def("unk_id", &SentencePiece::unk_id) + .def("PieceToId", &SentencePiece::PieceToId) + .def("IdToPiece", &SentencePiece::IdToPiece) + .def_pickle( + // The underlying content of SentencePiece contains byte string, + // and returing it as std::string cause UTF8 decoding error. + // Since TorchScript does not support byte string, we use byte Tensor + // to pass around the data. + // __getstate__ + [](const c10::intrusive_ptr &self) -> torch::Tensor { + auto *data = + static_cast(const_cast(self->content_.data())); + auto numel = static_cast(self->content_.size()); + return torch::from_blob(data, {numel}, {torch::kUInt8}).clone(); + }, + // __setstate__ + [](torch::Tensor state) -> c10::intrusive_ptr { + auto *data = static_cast(state.data_ptr()); + auto numel = state.size(0); + return c10::make_intrusive(std::string(data, numel)); + }); + + m.class_("Vectors") + .def(torch::init, std::vector, + torch::Tensor, torch::Tensor>()) + .def("__getitem__", &Vectors::__getitem__) + .def("lookup_vectors", &Vectors::lookup_vectors) + .def("__setitem__", &Vectors::__setitem__) + .def("__len__", &Vectors::__len__) + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr &self) -> VectorsStates { + return _serialize_vectors(self); + }, + // __setstate__ + [](VectorsStates states) -> c10::intrusive_ptr { + return _deserialize_vectors(states); + }); + + m.class_("Vocab") + .def(torch::init>()) + .def("__contains__", + [](const c10::intrusive_ptr &self, const std::string &item) + -> bool { return self->__contains__(c10::string_view{item}); }) + .def("__getitem__", + [](const c10::intrusive_ptr &self, const std::string &item) + -> int64_t { return self->__getitem__(c10::string_view{item}); }) + .def("insert_token", &Vocab::insert_token) + .def("__len__", &Vocab::__len__) + .def("set_default_index", &Vocab::set_default_index) + .def("get_default_index", &Vocab::get_default_index) + .def("append_token", &Vocab::append_token) + .def("lookup_token", &Vocab::lookup_token) + .def("lookup_tokens", &Vocab::lookup_tokens) + .def("lookup_indices", + [](const c10::intrusive_ptr &self, + const std::vector &items) { + std::vector indices(items.size()); + int64_t counter = 0; + for (const auto &item : items) { + indices[counter++] = self->__getitem__(c10::string_view{item}); + } + return indices; + }) + .def("get_stoi", &Vocab::get_stoi) + .def("get_itos", &Vocab::get_itos) + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr &self) -> VocabStates { + return _serialize_vocab(self); + }, + // __setstate__ + [](VocabStates states) -> c10::intrusive_ptr { + return _deserialize_vocab(states); + }); + + m.def("torchtext::generate_sp_model", &generate_sp_model); + m.def("torchtext::load_sp_model", &load_sp_model); + m.def("torchtext::load_sp_model_string", &load_sp_model_string); +} + +} // namespace torchtext diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 1ce67ed8b5..c83652c014 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -191,17 +191,6 @@ void 
parse_raw_text_file_chunk(const std::string &file_path, size_t offset, } } -// sorting using a custom object -struct CompareTokens { - bool operator()(const std::pair &a, - const std::pair &b) { - if (a.second == b.second) { - return a.first < b.first; - } - return a.second > b.second; - } -}; - StringList _concat_tokens(std::vector> chunk_counters, const int64_t min_freq, const int64_t num_lines, @@ -345,54 +334,6 @@ Vocab _build_vocab_from_text_file(const std::string &file_path, return Vocab(std::move(tokens)); } -Vocab _build_vocab_from_text_file_using_python_tokenizer( - const std::string &file_path, const int64_t min_freq, - py::object tokenizer) { - // find number of lines - int64_t num_lines = _infer_lines(file_path); - // Read text from file and add tokens - std::ifstream fin(file_path, std::ios::in); - TORCH_CHECK(fin.is_open(), "Cannot open input file " + file_path); - - IndexDict counter; - std::string line; - for (int64_t i = 0; i < num_lines; i++) { - std::getline(fin, line); - std::vector token_list = - tokenizer(line).cast>(); - - for (size_t i = 0; i < token_list.size(); i++) { - std::string token = token_list[i]; - - if (counter.find(token) == counter.end()) { - counter[token] = 1; - } else { - counter[token] += 1; - } - } - } - - // create tokens-frequency pairs - std::vector> token_freq_pairs; - for (const auto &item : counter) { - if (item.second >= min_freq) { - token_freq_pairs.push_back(item); - } - } - - // sort tokens by frequency - CompareTokens compare_tokens; - std::sort(token_freq_pairs.begin(), token_freq_pairs.end(), compare_tokens); - - // Create final list of tokens - StringList tokens; - for (const auto &token_freq_pair : token_freq_pairs) { - tokens.push_back(token_freq_pair.first); - } - - return Vocab(std::move(tokens)); -} - VocabStates _serialize_vocab(const c10::intrusive_ptr &self) { std::vector integers; StringList strings = self->itos_; diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index e7df276391..66f04aa1a0 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -1,10 +1,8 @@ +#pragma once #include #include -#include #include -namespace py = pybind11; - namespace torchtext { typedef std::vector StringList; @@ -14,6 +12,19 @@ typedef std::tuple, std::vector, std::vector> VocabStates; +// sorting using a custom object +struct CompareTokens { + bool operator()(const std::pair &a, + const std::pair &b) { + if (a.second == b.second) { + return a.first < b.first; + } + return a.second > b.second; + } +}; + +int64_t _infer_lines(const std::string &file_path); + struct Vocab : torch::CustomClassHolder { static const int32_t MAX_VOCAB_SIZE = 30000000; int64_t unk_index_; @@ -79,7 +90,4 @@ Vocab _build_vocab_from_text_file(const std::string &file_path, const int64_t min_freq, const int64_t num_cpus, torch::jit::script::Module tokenizer); -Vocab _build_vocab_from_text_file_using_python_tokenizer( - const std::string &file_path, const int64_t min_freq, py::object tokenizer); - } // namespace torchtext diff --git a/torchtext/csrc/vocab_factory.h b/torchtext/csrc/vocab_factory.h new file mode 100644 index 0000000000..3cdb6692bb --- /dev/null +++ b/torchtext/csrc/vocab_factory.h @@ -0,0 +1,55 @@ +#include +#include + +namespace py = pybind11; + +namespace torchtext { + +Vocab _build_vocab_from_text_file_using_python_tokenizer( + const std::string &file_path, const int64_t min_freq, + py::object tokenizer) { + // find number of lines + int64_t num_lines = _infer_lines(file_path); + // Read text from file and add tokens + 
std::ifstream fin(file_path, std::ios::in); + TORCH_CHECK(fin.is_open(), "Cannot open input file " + file_path); + + IndexDict counter; + std::string line; + for (int64_t i = 0; i < num_lines; i++) { + std::getline(fin, line); + std::vector token_list = + tokenizer(line).cast>(); + + for (size_t i = 0; i < token_list.size(); i++) { + std::string token = token_list[i]; + + if (counter.find(token) == counter.end()) { + counter[token] = 1; + } else { + counter[token] += 1; + } + } + } + + // create tokens-frequency pairs + std::vector> token_freq_pairs; + for (const auto &item : counter) { + if (item.second >= min_freq) { + token_freq_pairs.push_back(item); + } + } + + // sort tokens by frequency + CompareTokens compare_tokens; + std::sort(token_freq_pairs.begin(), token_freq_pairs.end(), compare_tokens); + + // Create final list of tokens + StringList tokens; + for (const auto &token_freq_pair : token_freq_pairs) { + tokens.push_back(token_freq_pair.first); + } + + return Vocab(std::move(tokens)); +} +} // namespace torchtext diff --git a/torchtext/experimental/vocab_factory.py b/torchtext/experimental/vocab_factory.py index bf39b87419..a157b4ba8b 100644 --- a/torchtext/experimental/vocab_factory.py +++ b/torchtext/experimental/vocab_factory.py @@ -27,7 +27,7 @@ def build_vocab_from_text_file(file_path: str, tokenizer: Optional[Callable] = N Returns: torchtext.vocab.Vocab: a `Vocab` object. Examples: - >>> from torchtext.vocab import build_vocab_from_text_file + >>> from torchtext.experimental.vocab_factory import build_vocab_from_text_file >>> v = build_vocab_from_text_file('vocab.txt') # using python split function as tokenizer >>> #using JIT'd tokenizer >>> from torchtext.experimental.transforms import basic_english_normalize diff --git a/torchtext/vocab/__init__.py b/torchtext/vocab/__init__.py new file mode 100644 index 0000000000..3b93a015bb --- /dev/null +++ b/torchtext/vocab/__init__.py @@ -0,0 +1,23 @@ +from .vocab import Vocab + +from .vectors import ( + GloVe, + FastText, + CharNGram, + pretrained_aliases, + Vectors, +) + +from .vocab_factory import ( + vocab, + build_vocab_from_iterator, +) + +__all__ = ["Vocab", + "vocab", + "build_vocab_from_iterator", + "GloVe", + "FastText", + "CharNGram", + "pretrained_aliases", + "Vectors"] diff --git a/torchtext/vocab.py b/torchtext/vocab/vectors.py old mode 100755 new mode 100644 similarity index 56% rename from torchtext/vocab.py rename to torchtext/vocab/vectors.py index 995e301869..bec350b9a1 --- a/torchtext/vocab.py +++ b/torchtext/vocab/vectors.py @@ -1,283 +1,17 @@ -from functools import partial +import torch import logging import os import zipfile import gzip -import torch -import torch.nn as nn from urllib.request import urlretrieve from tqdm import tqdm import tarfile -from typing import Dict, List, Optional, Iterable -from collections import Counter, OrderedDict -from torchtext._torchtext import ( - Vocab as VocabPybind, -) -from .utils import reporthook - -logger = logging.getLogger(__name__) - -__all__ = [ - 'build_vocab_from_iterator', - 'vocab', -] +from functools import partial +from ..utils import reporthook logger = logging.getLogger(__name__) -class Vocab(nn.Module): - __jit_unused_properties__ = ["is_jitable"] - r"""Creates a vocab object which maps tokens to indices. - - Args: - vocab (torch.classes.torchtext.Vocab or torchtext._torchtext.Vocab): a cpp vocab object. 
- """ - - def __init__(self, vocab): - super(Vocab, self).__init__() - self.vocab = vocab - torch._C._log_api_usage_once(f"torchtext.{self.__class__.__name__}") - - @property - def is_jitable(self): - return not isinstance(self.vocab, VocabPybind) - - @torch.jit.export - def forward(self, tokens: List[str]) -> List[int]: - r"""Calls the `lookup_indices` method - - Args: - tokens: a list of tokens used to lookup their corresponding `indices`. - - Returns: - The indices associated with a list of `tokens`. - """ - return self.vocab.lookup_indices(tokens) - - @torch.jit.export - def __len__(self) -> int: - r""" - Returns: - The length of the vocab. - """ - return len(self.vocab) - - @torch.jit.export - def __contains__(self, token: str) -> bool: - r""" - Args: - token: The token for which to check the membership. - - Returns: - Whether the token is member of vocab or not. - """ - return self.vocab.__contains__(token) - - @torch.jit.export - def __getitem__(self, token: str) -> int: - r""" - Args: - token: The token used to lookup the corresponding index. - - Returns: - The index corresponding to the associated token. - """ - return self.vocab[token] - - @torch.jit.export - def set_default_index(self, index: Optional[int]) -> None: - r""" - Args: - index: Value of default index. This index will be returned when OOV token is queried. - """ - self.vocab.set_default_index(index) - - @torch.jit.export - def get_default_index(self) -> Optional[int]: - r""" - Returns: - Value of default index if it is set. - """ - return self.vocab.get_default_index() - - @torch.jit.export - def insert_token(self, token: str, index: int) -> None: - r""" - Args: - token: The token used to lookup the corresponding index. - index: The index corresponding to the associated token. - Raises: - RuntimeError: If `index` is not in range [0, Vocab.size()] or if `token` already exists in the vocab. - """ - self.vocab.insert_token(token, index) - - @torch.jit.export - def append_token(self, token: str) -> None: - r""" - Args: - token: The token used to lookup the corresponding index. - - Raises: - RuntimeError: If `token` already exists in the vocab - """ - self.vocab.append_token(token) - - @torch.jit.export - def lookup_token(self, index: int) -> str: - r""" - Args: - index: The index corresponding to the associated token. - - Returns: - token: The token used to lookup the corresponding index. - - Raises: - RuntimeError: If `index` not in range [0, itos.size()). - """ - return self.vocab.lookup_token(index) - - @torch.jit.export - def lookup_tokens(self, indices: List[int]) -> List[str]: - r""" - Args: - indices: The `indices` used to lookup their corresponding`tokens`. - - Returns: - The `tokens` associated with `indices`. - - Raises: - RuntimeError: If an index within `indices` is not int range [0, itos.size()). - """ - return self.vocab.lookup_tokens(indices) - - @torch.jit.export - def lookup_indices(self, tokens: List[str]) -> List[int]: - r""" - Args: - tokens: the tokens used to lookup their corresponding `indices`. - - Returns: - The 'indices` associated with `tokens`. - """ - return self.vocab.lookup_indices(tokens) - - @torch.jit.export - def get_stoi(self) -> Dict[str, int]: - r""" - Returns: - Dictionary mapping tokens to indices. - """ - return self.vocab.get_stoi() - - @torch.jit.export - def get_itos(self) -> List[str]: - r""" - Returns: - List mapping indices to tokens. - """ - return self.vocab.get_itos() - - def __prepare_scriptable__(self): - r"""Return a JITable Vocab. 
- """ - if not self.is_jitable: - cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_, self.vocab.default_index_) - return Vocab(cpp_vocab) - return self - - -def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab: - r"""Factory method for creating a vocab object which maps tokens to indices. - - Note that the ordering in which key value pairs were inserted in the `ordered_dict` will be respected when building the vocab. - Therefore if sorting by token frequency is important to the user, the `ordered_dict` should be created in a way to reflect this. - - Args: - ordered_dict: Ordered Dictionary mapping tokens to their corresponding occurance frequencies. - min_freq: The minimum frequency needed to include a token in the vocabulary. - - Returns: - torchtext.vocab.Vocab: A `Vocab` object - - Examples: - >>> from torchtext.vocab import vocab - >>> from collections import Counter, OrderedDict - >>> counter = Counter(["a", "a", "b", "b", "b"]) - >>> sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True) - >>> ordered_dict = OrderedDict(sorted_by_freq_tuples) - >>> v1 = vocab(ordered_dict) - >>> print(v1['a']) #prints 1 - >>> print(v1['out of vocab']) #raise RuntimeError since default index is not set - >>> tokens = ['e', 'd', 'c', 'b', 'a'] - >>> v2 = vocab(OrderedDict([(token, 1) for token in tokens])) - >>> #adding token and default index - >>> unk_token = '' - >>> default_index = -1 - >>> if unk_token not in v2: v2.insert_token(unk_token, 0) - >>> v2.set_default_index(default_index) - >>> print(v2['']) #prints 0 - >>> print(v2['out of vocab']) #prints -1 - >>> #make default index same as index of unk_token - >>> v2.set_default_index(v2[unk_token]) - >>> v2['out of vocab'] is v2[unk_token] #prints True - """ - - tokens = [] - for token, freq in ordered_dict.items(): - if freq >= min_freq: - tokens.append(token) - - return Vocab(VocabPybind(tokens, None)) - - -def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: Optional[List[str]] = None, special_first: bool = True) -> Vocab: - """ - Build a Vocab from an iterator. - - Args: - iterator: Iterator used to build Vocab. Must yield list or iterator of tokens. - min_freq: The minimum frequency needed to include a token in the vocabulary. - specials: Special symbols to add. The order of supplied tokens will be preserved. - special_first: Indicates whether to insert symbols at the beginning or at the end. 
- - - Returns: - torchtext.vocab.Vocab: A `Vocab` object - - Examples: - >>> #generating vocab from text file - >>> import io - >>> from torchtext.vocab import build_vocab_from_iterator - >>> def yield_tokens(file_path): - >>> with io.open(file_path, encoding = 'utf-8') as f: - >>> for line in f: - >>> yield line.strip().split() - >>> vocab = build_vocab_from_iterator(yield_tokens_batch(file_path), specials=[""]) - """ - - counter = Counter() - for tokens in iterator: - counter.update(tokens) - - if specials is not None: - for tok in specials: - del counter[tok] - - sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[0]) - sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True) - ordered_dict = OrderedDict(sorted_by_freq_tuples) - - if specials is not None: - if special_first: - specials = specials[::-1] - for symbol in specials: - ordered_dict.update({symbol: min_freq}) - ordered_dict.move_to_end(symbol, last=not special_first) - - word_vocab = vocab(ordered_dict, min_freq=min_freq) - return word_vocab - - def _infer_shape(f): num_lines, vector_dim = 0, None for line in f: diff --git a/torchtext/vocab/vocab.py b/torchtext/vocab/vocab.py new file mode 100755 index 0000000000..64c4f52435 --- /dev/null +++ b/torchtext/vocab/vocab.py @@ -0,0 +1,164 @@ +import torch +import torch.nn as nn +from typing import Dict, List, Optional + + +class Vocab(nn.Module): + __jit_unused_properties__ = ["is_jitable"] + r"""Creates a vocab object which maps tokens to indices. + + Args: + vocab (torch.classes.torchtext.Vocab or torchtext._torchtext.Vocab): a cpp vocab object. + """ + + def __init__(self, vocab): + super(Vocab, self).__init__() + self.vocab = vocab + torch._C._log_api_usage_once(f"torchtext.{self.__class__.__name__}") + + @property + def is_jitable(self): + return isinstance(self.vocab, torch._C.ScriptObject) + + @torch.jit.export + def forward(self, tokens: List[str]) -> List[int]: + r"""Calls the `lookup_indices` method + + Args: + tokens: a list of tokens used to lookup their corresponding `indices`. + + Returns: + The indices associated with a list of `tokens`. + """ + return self.vocab.lookup_indices(tokens) + + @torch.jit.export + def __len__(self) -> int: + r""" + Returns: + The length of the vocab. + """ + return len(self.vocab) + + @torch.jit.export + def __contains__(self, token: str) -> bool: + r""" + Args: + token: The token for which to check the membership. + + Returns: + Whether the token is member of vocab or not. + """ + return self.vocab.__contains__(token) + + @torch.jit.export + def __getitem__(self, token: str) -> int: + r""" + Args: + token: The token used to lookup the corresponding index. + + Returns: + The index corresponding to the associated token. + """ + return self.vocab[token] + + @torch.jit.export + def set_default_index(self, index: Optional[int]) -> None: + r""" + Args: + index: Value of default index. This index will be returned when OOV token is queried. + """ + self.vocab.set_default_index(index) + + @torch.jit.export + def get_default_index(self) -> Optional[int]: + r""" + Returns: + Value of default index if it is set. + """ + return self.vocab.get_default_index() + + @torch.jit.export + def insert_token(self, token: str, index: int) -> None: + r""" + Args: + token: The token used to lookup the corresponding index. + index: The index corresponding to the associated token. + Raises: + RuntimeError: If `index` is not in range [0, Vocab.size()] or if `token` already exists in the vocab. 
+ """ + self.vocab.insert_token(token, index) + + @torch.jit.export + def append_token(self, token: str) -> None: + r""" + Args: + token: The token used to lookup the corresponding index. + + Raises: + RuntimeError: If `token` already exists in the vocab + """ + self.vocab.append_token(token) + + @torch.jit.export + def lookup_token(self, index: int) -> str: + r""" + Args: + index: The index corresponding to the associated token. + + Returns: + token: The token used to lookup the corresponding index. + + Raises: + RuntimeError: If `index` not in range [0, itos.size()). + """ + return self.vocab.lookup_token(index) + + @torch.jit.export + def lookup_tokens(self, indices: List[int]) -> List[str]: + r""" + Args: + indices: The `indices` used to lookup their corresponding`tokens`. + + Returns: + The `tokens` associated with `indices`. + + Raises: + RuntimeError: If an index within `indices` is not int range [0, itos.size()). + """ + return self.vocab.lookup_tokens(indices) + + @torch.jit.export + def lookup_indices(self, tokens: List[str]) -> List[int]: + r""" + Args: + tokens: the tokens used to lookup their corresponding `indices`. + + Returns: + The 'indices` associated with `tokens`. + """ + return self.vocab.lookup_indices(tokens) + + @torch.jit.export + def get_stoi(self) -> Dict[str, int]: + r""" + Returns: + Dictionary mapping tokens to indices. + """ + return self.vocab.get_stoi() + + @torch.jit.export + def get_itos(self) -> List[str]: + r""" + Returns: + List mapping indices to tokens. + """ + return self.vocab.get_itos() + + def __prepare_scriptable__(self): + r"""Return a JITable Vocab. + """ + if not self.is_jitable: + cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_, self.vocab.default_index_) + return Vocab(cpp_vocab) + return self diff --git a/torchtext/vocab/vocab_factory.py b/torchtext/vocab/vocab_factory.py new file mode 100644 index 0000000000..bdea76f0a6 --- /dev/null +++ b/torchtext/vocab/vocab_factory.py @@ -0,0 +1,98 @@ +from .vocab import Vocab +from typing import Dict, Iterable, Optional, List +from collections import Counter, OrderedDict +from torchtext._torchtext import ( + Vocab as VocabPybind, +) + + +def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab: + r"""Factory method for creating a vocab object which maps tokens to indices. + + Note that the ordering in which key value pairs were inserted in the `ordered_dict` will be respected when building the vocab. + Therefore if sorting by token frequency is important to the user, the `ordered_dict` should be created in a way to reflect this. + + Args: + ordered_dict: Ordered Dictionary mapping tokens to their corresponding occurance frequencies. + min_freq: The minimum frequency needed to include a token in the vocabulary. 
+ + Returns: + torchtext.vocab.Vocab: A `Vocab` object + + Examples: + >>> from torchtext.vocab import vocab + >>> from collections import Counter, OrderedDict + >>> counter = Counter(["a", "a", "b", "b", "b"]) + >>> sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True) + >>> ordered_dict = OrderedDict(sorted_by_freq_tuples) + >>> v1 = vocab(ordered_dict) + >>> print(v1['a']) #prints 1 + >>> print(v1['out of vocab']) #raise RuntimeError since default index is not set + >>> tokens = ['e', 'd', 'c', 'b', 'a'] + >>> v2 = vocab(OrderedDict([(token, 1) for token in tokens])) + >>> #adding token and default index + >>> unk_token = '' + >>> default_index = -1 + >>> if unk_token not in v2: v2.insert_token(unk_token, 0) + >>> v2.set_default_index(default_index) + >>> print(v2['']) #prints 0 + >>> print(v2['out of vocab']) #prints -1 + >>> #make default index same as index of unk_token + >>> v2.set_default_index(v2[unk_token]) + >>> v2['out of vocab'] is v2[unk_token] #prints True + """ + + tokens = [] + for token, freq in ordered_dict.items(): + if freq >= min_freq: + tokens.append(token) + + return Vocab(VocabPybind(tokens, None)) + + +def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: Optional[List[str]] = None, special_first: bool = True) -> Vocab: + """ + Build a Vocab from an iterator. + + Args: + iterator: Iterator used to build Vocab. Must yield list or iterator of tokens. + min_freq: The minimum frequency needed to include a token in the vocabulary. + specials: Special symbols to add. The order of supplied tokens will be preserved. + special_first: Indicates whether to insert symbols at the beginning or at the end. + + + Returns: + torchtext.vocab.Vocab: A `Vocab` object + + Examples: + >>> #generating vocab from text file + >>> import io + >>> from torchtext.vocab import build_vocab_from_iterator + >>> def yield_tokens(file_path): + >>> with io.open(file_path, encoding = 'utf-8') as f: + >>> for line in f: + >>> yield line.strip().split() + >>> vocab = build_vocab_from_iterator(yield_tokens_batch(file_path), specials=[""]) + """ + + counter = Counter() + for tokens in iterator: + counter.update(tokens) + + if specials is not None: + for tok in specials: + del counter[tok] + + sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[0]) + sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True) + ordered_dict = OrderedDict(sorted_by_freq_tuples) + + if specials is not None: + if special_first: + specials = specials[::-1] + for symbol in specials: + ordered_dict.update({symbol: min_freq}) + ordered_dict.move_to_end(symbol, last=not special_first) + + word_vocab = vocab(ordered_dict, min_freq=min_freq) + return word_vocab
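
On the Python side this PR turns torchtext/vocab.py into the torchtext/vocab/ package (vocab.py, vectors.py, vocab_factory.py), with __init__.py re-exporting Vocab, vocab, build_vocab_from_iterator and the vector classes; on the C++ side it splits the pybind11 registrations (register_pybindings.cpp, vocab_factory.h) from the TorchScript registrations (register_torchbindings.cpp), so the TORCH_LIBRARY_FRAGMENT translation unit no longer pulls in pybind11. A minimal usage sketch of the re-exported Python surface follows; it assumes the behavior documented in the docstrings above carries over unchanged, and the tokens, the '<unk>' special and the printed indices are illustrative rather than taken from this PR:

>>> from collections import Counter, OrderedDict
>>> from torchtext.vocab import vocab, build_vocab_from_iterator
>>> # build a Vocab from an ordered dict of token frequencies
>>> counter = Counter(["a", "a", "b", "b", "b"])
>>> ordered_dict = OrderedDict(sorted(counter.items(), key=lambda x: x[1], reverse=True))
>>> v1 = vocab(ordered_dict, min_freq=1)
>>> v1['a']                             # 'b' is more frequent, so 'a' lands at index 1
1
>>> # build a Vocab from an iterator of token lists, reserving '<unk>' at index 0
>>> v2 = build_vocab_from_iterator([["hello", "world"], ["hello"]], specials=['<unk>'])
>>> v2.set_default_index(v2['<unk>'])
>>> v2['out of vocab'] == v2['<unk>']   # OOV lookups fall back to the default index
True

Both factory functions still return the nn.Module wrapper now defined in torchtext/vocab/vocab.py, so existing imports of torchtext.vocab.Vocab and the scripting path through __prepare_scriptable__ should be unaffected by the file move.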