Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Function to compare batch look-up for vocab #1290

Closed
wants to merge 4 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
modified tests
parmeet committed Apr 20, 2021
commit d54d92fd385a4e790438e308ae8d97ce59ff5e89
34 changes: 2 additions & 32 deletions test/experimental/test_vocab.py
Original file line number Diff line number Diff line change
@@ -5,12 +5,9 @@
import torch
import unittest
from test.common.torchtext_test_case import TorchtextTestCase
from torchtext.experimental.transforms import basic_english_normalize
from torchtext.experimental.vocab import (
vocab,
build_vocab_from_iterator,
build_vocab_from_text_file,
load_vocab_from_file,
)


@@ -220,38 +217,11 @@ def test_vocab_load_and_save(self):
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)

def test_build_vocab_from_vocab_file(self):
iterator = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', 'freq_low']
with self.subTest('buildfromvocabfile'):
vocab_path = os.path.join(self.test_dir, 'vocab.txt')
with open(vocab_path, 'w') as f:
f.write('\n'.join(iterator))
v = load_vocab_from_file(vocab_path)
expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', 'freq_low']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

def test_build_vocab_from_text_file(self):
iterator = ['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'freq_low', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T']
with self.subTest('buildfromtextfile'):
vocab_path = os.path.join(self.test_dir, 'vocab.txt')
with open(vocab_path, 'w') as f:
f.write(' '.join(iterator))
f.write('\n')
tokenizer = torch.jit.script(basic_english_normalize())
v = build_vocab_from_text_file(vocab_path, tokenizer)
expected_itos = ['<unk>', 'ᑌᑎiᑕoᗪᕮ_tᕮ᙭t', 'hello', 'world', 'freq_low']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

def test_build_vocab_iterator(self):
iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T',
'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'freq_low', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T']]
v = build_vocab_from_iterator(iterator)
expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', 'freq_low']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)
54 changes: 25 additions & 29 deletions test/experimental/test_with_asset.py
Original file line number Diff line number Diff line change
@@ -78,13 +78,12 @@ class TestTransformsWithAsset(TorchtextTestCase):
def test_vocab_transform(self):
asset_name = 'vocab_test2.txt'
asset_path = get_asset_path(asset_name)
with open(asset_path, 'r') as f:
vocab_transform = VocabTransform(load_vocab_from_file(f))
self.assertEqual(vocab_transform(['of', 'that', 'new']),
[7, 18, 24])
jit_vocab_transform = torch.jit.script(vocab_transform)
self.assertEqual(jit_vocab_transform(['of', 'that', 'new', 'that']),
[7, 18, 24, 18])
vocab_transform = VocabTransform(load_vocab_from_file(asset_path))
self.assertEqual(vocab_transform(['of', 'that', 'new']),
[7, 18, 24])
jit_vocab_transform = torch.jit.script(vocab_transform)
self.assertEqual(jit_vocab_transform(['of', 'that', 'new', 'that']),
[7, 18, 24, 18])

def test_errors_vectors_python(self):
tokens = []
@@ -179,27 +178,25 @@ def test_glove_different_dims(self):
def test_vocab_from_file(self):
asset_name = 'vocab_test.txt'
asset_path = get_asset_path(asset_name)
with open(asset_path, 'r') as f:
v = load_vocab_from_file(f, unk_token='<new_unk>')
expected_itos = ['<new_unk>', 'b', 'a', 'c']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)
v = load_vocab_from_file(asset_path, unk_token='<new_unk>')
expected_itos = ['<new_unk>', 'b', 'a', 'c']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

def test_vocab_from_raw_text_file(self):
asset_name = 'vocab_raw_text_test.txt'
asset_path = get_asset_path(asset_name)
with open(asset_path, 'r') as f:
tokenizer = basic_english_normalize()
jit_tokenizer = torch.jit.script(tokenizer)
v = build_vocab_from_text_file(f, jit_tokenizer, unk_token='<new_unk>')
expected_itos = ['<new_unk>', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall', 'parent',
'pension', 'representing', 'say', 'stricken', 't', 'they', 'turner',
'unions', 'with', 'workers']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)
tokenizer = basic_english_normalize()
jit_tokenizer = torch.jit.script(tokenizer)
v = build_vocab_from_text_file(asset_path, jit_tokenizer, unk_token='<new_unk>')
expected_itos = ['<new_unk>', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall', 'parent',
'pension', 'representing', 'say', 'stricken', 't', 'they', 'turner',
'unions', 'with', 'workers']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

def test_builtin_pretrained_sentencepiece_processor(self):
sp_model_path = download_from_url(PRETRAINED_SP_MODEL['text_unigram_25000'])
@@ -241,11 +238,10 @@ def batch_func(data):
def test_text_sequential_transform(self):
asset_name = 'vocab_test2.txt'
asset_path = get_asset_path(asset_name)
with open(asset_path, 'r') as f:
pipeline = TextSequentialTransforms(basic_english_normalize(), load_vocab_from_file(f))
jit_pipeline = torch.jit.script(pipeline)
self.assertEqual(pipeline('of that new'), [7, 18, 24])
self.assertEqual(jit_pipeline('of that new'), [7, 18, 24])
pipeline = TextSequentialTransforms(basic_english_normalize(), load_vocab_from_file(asset_path))
jit_pipeline = torch.jit.script(pipeline)
self.assertEqual(pipeline('of that new'), [7, 18, 24])
self.assertEqual(jit_pipeline('of that new'), [7, 18, 24])

def test_vectors_from_file(self):
asset_name = 'vectors_test.csv'