diff --git a/test/experimental/test_vectors.py b/test/experimental/test_vectors.py index d09749388c..44cea07dbf 100644 --- a/test/experimental/test_vectors.py +++ b/test/experimental/test_vectors.py @@ -71,6 +71,24 @@ def test_vectors_jit(self): self.assertEqual(vectors_obj['b'], jit_vectors_obj['b']) self.assertEqual(vectors_obj['not_in_it'], jit_vectors_obj['not_in_it']) + def test_vectors_forward(self): + tensorA = torch.tensor([1, 0], dtype=torch.float) + tensorB = torch.tensor([0, 1], dtype=torch.float) + + unk_tensor = torch.tensor([0, 0], dtype=torch.float) + tokens = ['a', 'b'] + vecs = torch.stack((tensorA, tensorB), 0) + vectors_obj = vectors(tokens, vecs, unk_tensor=unk_tensor) + jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue()) + + tokens_to_lookup = ['a', 'b', 'c'] + expected_vectors = torch.stack((tensorA, tensorB, unk_tensor), 0) + vectors_by_tokens = vectors_obj(tokens_to_lookup) + jit_vectors_by_tokens = jit_vectors_obj(tokens_to_lookup) + + self.assertEqual(expected_vectors, vectors_by_tokens) + self.assertEqual(expected_vectors, jit_vectors_by_tokens) + def test_vectors_lookup_vectors(self): tensorA = torch.tensor([1, 0], dtype=torch.float) tensorB = torch.tensor([0, 1], dtype=torch.float) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 814ca631f9..cbd56ef2e3 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -119,6 +119,20 @@ def test_vocab_jit(self): self.assertEqual(jit_v.get_itos(), expected_itos) self.assertEqual(dict(jit_v.get_stoi()), expected_stoi) + def test_vocab_forward(self): + token_to_freq = {'a': 2, 'b': 2, 'c': 2} + sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) + + c = OrderedDict(sorted_by_freq_tuples) + v = vocab(c) + jit_v = torch.jit.script(v.to_ivalue()) + + tokens = ['b', 'a', 'c'] + expected_indices = [2, 1, 3] + + self.assertEqual(v(tokens), expected_indices) + self.assertEqual(jit_v(tokens), expected_indices) + def test_vocab_lookup_token(self): token_to_freq = {'a': 2, 'b': 2, 'c': 2} sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) diff --git a/torchtext/experimental/vectors.py b/torchtext/experimental/vectors.py index 45f1adc312..7fc3f120df 100644 --- a/torchtext/experimental/vectors.py +++ b/torchtext/experimental/vectors.py @@ -203,7 +203,7 @@ def is_jitable(self): return not isinstance(self.vectors, VectorsPybind) @torch.jit.export - def __call__(self, tokens: List[str]) -> Tensor: + def forward(self, tokens: List[str]) -> Tensor: r"""Calls the `lookup_vectors` method Args: tokens: a list of tokens diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 798f9f4186..752887dd60 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -128,7 +128,7 @@ def is_jitable(self): return not isinstance(self.vocab, VocabPybind) @torch.jit.export - def __call__(self, tokens: List[str]) -> List[int]: + def forward(self, tokens: List[str]) -> List[int]: r"""Calls the `lookup_indices` method Args: tokens (List[str]): the tokens used to lookup their corresponding `indices`.