Skip to content

Commit

Permalink
Updated vocab and vectors with forward method (pytorch#953)
Browse files Browse the repository at this point in the history
* Updated vocab and vectors with forward method

* Added tests
  • Loading branch information
Nayef211 authored Sep 1, 2020
1 parent 87f0d44 commit 8aecbb9
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 2 deletions.
18 changes: 18 additions & 0 deletions test/experimental/test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,24 @@ def test_vectors_jit(self):
self.assertEqual(vectors_obj['b'], jit_vectors_obj['b'])
self.assertEqual(vectors_obj['not_in_it'], jit_vectors_obj['not_in_it'])

def test_vectors_forward(self):
tensorA = torch.tensor([1, 0], dtype=torch.float)
tensorB = torch.tensor([0, 1], dtype=torch.float)

unk_tensor = torch.tensor([0, 0], dtype=torch.float)
tokens = ['a', 'b']
vecs = torch.stack((tensorA, tensorB), 0)
vectors_obj = vectors(tokens, vecs, unk_tensor=unk_tensor)
jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue())

tokens_to_lookup = ['a', 'b', 'c']
expected_vectors = torch.stack((tensorA, tensorB, unk_tensor), 0)
vectors_by_tokens = vectors_obj(tokens_to_lookup)
jit_vectors_by_tokens = jit_vectors_obj(tokens_to_lookup)

self.assertEqual(expected_vectors, vectors_by_tokens)
self.assertEqual(expected_vectors, jit_vectors_by_tokens)

def test_vectors_lookup_vectors(self):
tensorA = torch.tensor([1, 0], dtype=torch.float)
tensorB = torch.tensor([0, 1], dtype=torch.float)
Expand Down
14 changes: 14 additions & 0 deletions test/experimental/test_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ def test_vocab_jit(self):
self.assertEqual(jit_v.get_itos(), expected_itos)
self.assertEqual(dict(jit_v.get_stoi()), expected_stoi)

def test_vocab_forward(self):
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)

c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c)
jit_v = torch.jit.script(v.to_ivalue())

tokens = ['b', 'a', 'c']
expected_indices = [2, 1, 3]

self.assertEqual(v(tokens), expected_indices)
self.assertEqual(jit_v(tokens), expected_indices)

def test_vocab_lookup_token(self):
token_to_freq = {'a': 2, 'b': 2, 'c': 2}
sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
Expand Down
2 changes: 1 addition & 1 deletion torchtext/experimental/vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def is_jitable(self):
return not isinstance(self.vectors, VectorsPybind)

@torch.jit.export
def __call__(self, tokens: List[str]) -> Tensor:
def forward(self, tokens: List[str]) -> Tensor:
r"""Calls the `lookup_vectors` method
Args:
tokens: a list of tokens
Expand Down
2 changes: 1 addition & 1 deletion torchtext/experimental/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def is_jitable(self):
return not isinstance(self.vocab, VocabPybind)

@torch.jit.export
def __call__(self, tokens: List[str]) -> List[int]:
def forward(self, tokens: List[str]) -> List[int]:
r"""Calls the `lookup_indices` method
Args:
tokens (List[str]): the tokens used to lookup their corresponding `indices`.
Expand Down

0 comments on commit 8aecbb9

Please sign in to comment.