Skip to content

Commit

Permalink
Added APIs for default index and removed unk token (#1302)
Browse files Browse the repository at this point in the history
  • Loading branch information
parmeet committed May 12, 2021
1 parent 8f5267c commit 54833bd
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 184 deletions.
79 changes: 53 additions & 26 deletions test/experimental/test_vocab.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# -*- coding: utf-8 -*-
from collections import OrderedDict
import os
import platform
import torch
import unittest
from test.common.torchtext_test_case import TorchtextTestCase
from torchtext.experimental.vocab import (
vocab,
Expand All @@ -20,18 +18,12 @@ def tearDown(self):
def test_has_unk(self):
c = OrderedDict()
v = vocab(c)

# check if unk is mapped to the first index
self.assertEqual(v['not_in_it'], 0)
self.assertEqual(v['<unk>'], 0)

def test_new_unk(self):
c = OrderedDict()
v = vocab(c, unk_token="<new_unk>")

# check if new_unk is mapped to the first index
self.assertEqual(v['<new_unk>'], 0)
self.assertEqual(v['not_in_it'], 0)

def test_vocab_membership(self):
token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
Expand All @@ -54,6 +46,50 @@ def test_vocab_get_item(self):
self.assertEqual(v['a'], 1)
self.assertEqual(v['b'], 2)

def test_reassign_token(self):
token_to_freq = {'<unk>': 1, 'a': 2, 'b': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c, min_freq=1)

self.assertEqual(v['<unk>'], 2)
self.assertEqual(v['a'], 0)
self.assertEqual(v['b'], 1)
v.reassign_token('<unk>', 0)
self.assertEqual(v['<unk>'], 0)
self.assertEqual(v['a'], 1)
self.assertEqual(v['b'], 2)

self.assertEqual(v.get_itos(), ['<unk>', 'a', 'b'])

with self.assertRaises(RuntimeError):
v.reassign_token('not in vocab', 0)

with self.assertRaises(RuntimeError):
v.reassign_token('<unk>', 3)

def test_default_index(self):
token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c, min_freq=2)

self.assertTrue(v.get_default_index() is None)
with self.assertRaises(RuntimeError):
v['not in vocab']

v.set_default_index(0)
self.assertEqual(v['not in vocab'], 0)

def test_default_index_jit(self):
token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c, min_freq=2)
v.set_default_index(0)
v_jit = torch.jit.script(v)
self.assertEqual(v_jit['not in vocab'], 0)

def test_vocab_insert_token(self):
c = OrderedDict({'<unk>': 2, 'a': 2})

Expand Down Expand Up @@ -88,6 +124,10 @@ def test_vocab_append_token(self):
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

# token must not exist to be appended
with self.assertRaises(RuntimeError):
v.append_token('b')

def test_vocab_len(self):
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
Expand Down Expand Up @@ -149,6 +189,8 @@ def test_vocab_lookup_token(self):
v = vocab(c)

self.assertEqual(v.lookup_token(1), 'a')
with self.assertRaises(RuntimeError):
v.lookup_token(100)

def test_vocab_lookup_tokens(self):
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
Expand All @@ -172,24 +214,6 @@ def test_vocab_lookup_indices(self):

self.assertEqual(v.lookup_indices(tokens), expected_indices)

# we separate out these errors because Windows runs into seg faults when propagating
# exceptions from C++ using pybind11
@unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
def test_errors_vocab_cpp(self):
token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
c = OrderedDict(sorted_by_freq_tuples)

with self.assertRaises(RuntimeError):
# Test proper error raised when setting a token out of bounds
v = vocab(c, min_freq=3)
v.insert_token('new_token', 100)

with self.assertRaises(RuntimeError):
# Test proper error raised when looking up a token out of bounds
v = vocab(c)
v.lookup_token(100)

def test_errors_vocab_python(self):
token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
Expand All @@ -205,6 +229,7 @@ def test_vocab_load_and_save(self):

c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c, min_freq=3)
v.set_default_index(0)

expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
Expand All @@ -218,6 +243,7 @@ def test_vocab_load_and_save(self):
loaded_v = torch.load(vocab_path)
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
self.assertEqual(v['not in vocab'], 0)

with self.subTest('torchscript'):
vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
Expand All @@ -227,6 +253,7 @@ def test_vocab_load_and_save(self):
loaded_v = torch.load(vocab_path)
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
self.assertEqual(v['not in vocab'], 0)

def test_build_vocab_iterator(self):
iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T',
Expand Down
33 changes: 19 additions & 14 deletions torchtext/csrc/register_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@ namespace py = pybind11;

namespace {
Vocab build_vocab_from_text_file(const std::string &file_path,
const std::string &unk_token,
const int64_t min_freq, const int64_t num_cpus,
py::object fn) {
torch::jit::script::Module module(*torch::jit::as_module(fn));
return _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus,
module);
return _build_vocab_from_text_file(file_path, min_freq, num_cpus, module);
}
} // namespace

Expand Down Expand Up @@ -104,23 +102,27 @@ PYBIND11_MODULE(_torchtext, m) {
}));

py::class_<Vocab, c10::intrusive_ptr<Vocab>>(m, "Vocab")
.def(py::init<std::vector<std::string>, std::string>())
.def(py::init<StringList, c10::optional<int64_t>>())
.def_readonly("itos_", &Vocab::itos_)
.def_readonly("unk_token_", &Vocab::unk_token_)
.def("__contains__",
[](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> bool {
ssize_t length;
const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
return self->__contains__(c10::string_view{buffer, (size_t)length});
})
.def_readonly("default_index_", &Vocab::default_index_)
.def(
"__contains__",
[](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> bool {
ssize_t length;
const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
return self->__contains__(c10::string_view{buffer, (size_t)length});
})
.def("__getitem__",
[](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> int64_t {
ssize_t length;
const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
return self->__getitem__(c10::string_view{buffer, (size_t)length});
})
.def("__len__", &Vocab::__len__)
.def("reassign_token", &Vocab::reassign_token)
.def("insert_token", &Vocab::insert_token)
.def("set_default_index", &Vocab::set_default_index)
.def("get_default_index", &Vocab::get_default_index)
.def("__len__", &Vocab::__len__)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
Expand Down Expand Up @@ -234,15 +236,18 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
});

m.class_<Vocab>("Vocab")
.def(torch::init<StringList, std::string>())
.def(torch::init<StringList, c10::optional<int64_t>>())
.def("__contains__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> bool { return self->__contains__(c10::string_view{item}); })
.def("__getitem__",
[](const c10::intrusive_ptr<Vocab> &self, const std::string &item)
-> int64_t { return self->__getitem__(c10::string_view{item}); })
.def("__len__", &Vocab::__len__)
.def("reassign_token", &Vocab::reassign_token)
.def("insert_token", &Vocab::insert_token)
.def("__len__", &Vocab::__len__)
.def("set_default_index", &Vocab::set_default_index)
.def("get_default_index", &Vocab::get_default_index)
.def("append_token", &Vocab::append_token)
.def("lookup_token", &Vocab::lookup_token)
.def("lookup_tokens", &Vocab::lookup_tokens)
Expand Down
Loading

0 comments on commit 54833bd

Please sign in to comment.