Skip to content

Commit

Permalink
switch to_ivalue to __prepare_scriptable__ (#1080)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangguanheng66 authored Feb 9, 2021
1 parent 4e295e4 commit e368dc9
Show file tree
Hide file tree
Showing 14 changed files with 70 additions and 93 deletions.
2 changes: 1 addition & 1 deletion benchmark/benchmark_basic_english_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _run_benchmark_lookup(train, tokenizer):

existing_basic_english_tokenizer = get_tokenizer("basic_english")
experimental_basic_english_normalize = basic_english_normalize()
experimental_jit_basic_english_normalize = torch.jit.script(experimental_basic_english_normalize.to_ivalue())
experimental_jit_basic_english_normalize = torch.jit.script(experimental_basic_english_normalize)

# existing eager lookup
train, _ = AG_NEWS()
Expand Down
2 changes: 1 addition & 1 deletion benchmark/benchmark_experimental_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _run_benchmark_lookup(tokens, vector):

# experimental FastText jit lookup
print("FastText Experimental - Jit Mode")
jit_fast_text_experimental = torch.jit.script(fast_text_experimental.to_ivalue())
jit_fast_text_experimental = torch.jit.script(fast_text_experimental)
_run_benchmark_lookup(tokens, jit_fast_text_experimental)


Expand Down
6 changes: 3 additions & 3 deletions benchmark/benchmark_experimental_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def benchmark_experimental_vocab_construction(vocab_file_path, is_raw_text=True,
print("Loading from raw text file with basic_english_normalize tokenizer")
for _ in range(num_iters):
tokenizer = basic_english_normalize()
jited_tokenizer = torch.jit.script(tokenizer.to_ivalue())
jited_tokenizer = torch.jit.script(tokenizer)
build_vocab_from_text_file(f, jited_tokenizer, num_cpus=1)
print("Construction time:", time.monotonic() - t0)
else:
Expand Down Expand Up @@ -140,7 +140,7 @@ def token_iterator(file_path):
t0 = time.monotonic()
v_experimental = VocabExperimental(ordered_dict)
print("Construction time:", time.monotonic() - t0)
jit_v_experimental = torch.jit.script(v_experimental.to_ivalue())
jit_v_experimental = torch.jit.script(v_experimental)

# existing Vocab eager lookup
print("Vocab - Eager Mode")
Expand All @@ -154,7 +154,7 @@ def token_iterator(file_path):
_run_benchmark_lookup([tokens], v_experimental)
_run_benchmark_lookup(tokens_lists, v_experimental)

jit_v_experimental = torch.jit.script(v_experimental.to_ivalue())
jit_v_experimental = torch.jit.script(v_experimental)
# experimental Vocab jit lookup
print("Vocab Experimental - Jit Mode")
_run_benchmark_lookup(tokens, jit_v_experimental)
Expand Down
2 changes: 1 addition & 1 deletion benchmark/benchmark_pytext_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def benchmark_experimental_vocab():
t0 = time.monotonic()
experimental_script_vocab = ExperimentalScriptVocabulary(ordered_dict, unk_token="<unk>")
print("Construction time:", time.monotonic() - t0)
jit_experimental_script_vocab = torch.jit.script(experimental_script_vocab.to_ivalue())
jit_experimental_script_vocab = torch.jit.script(experimental_script_vocab)

# pytext Vocab eager lookup
print("Pytext Vocabulary - Eager Mode")
Expand Down
21 changes: 10 additions & 11 deletions examples/data_pipeline/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,10 @@ def build_sp_pipeline(spm_file):
vocab = PretrainedSPVocab(load_sp_model(spm_file))

# Insert token in vocab to match a pretrained vocab
vocab.insert_token('<pad>', 1)
pipeline = TextSequentialTransforms(tokenizer, vocab)
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
jit_pipeline = torch.jit.script(pipeline)
print('jit sentencepiece pipeline success!')
return pipeline, pipeline.to_ivalue(), jit_pipeline
return pipeline, pipeline, jit_pipeline


def build_legacy_torchtext_vocab_pipeline(vocab_file):
Expand All @@ -59,9 +58,9 @@ def build_experimental_torchtext_pipeline(hf_vocab_file):
with open(hf_vocab_file, 'r') as f:
vocab = load_vocab_from_file(f)
pipeline = TextSequentialTransforms(tokenizer, vocab)
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
jit_pipeline = torch.jit.script(pipeline)
print('jit experimental torchtext pipeline success!')
return pipeline, pipeline.to_ivalue(), jit_pipeline
return pipeline, pipeline, jit_pipeline


def build_legacy_batch_torchtext_vocab_pipeline(vocab_file):
Expand Down Expand Up @@ -104,9 +103,9 @@ def build_legacy_pytext_script_vocab_pipeline(vocab_file):
vocab_list.insert(0, "<unk>")
pipeline = TextSequentialTransforms(tokenizer,
PyTextScriptVocabTransform(ScriptVocabulary(vocab_list)))
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
jit_pipeline = torch.jit.script(pipeline)
print('jit legacy PyText pipeline success!')
return pipeline, pipeline.to_ivalue(), jit_pipeline
return pipeline, pipeline, jit_pipeline


def build_experimental_pytext_script_pipeline(vocab_file):
Expand All @@ -125,9 +124,9 @@ def build_experimental_pytext_script_pipeline(vocab_file):
# Insert token in vocab to match a pretrained vocab
pipeline = TextSequentialTransforms(tokenizer,
PyTextScriptVocabTransform(script_vocab(ordered_dict)))
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
jit_pipeline = torch.jit.script(pipeline)
print('jit legacy PyText pipeline success!')
return pipeline, pipeline.to_ivalue(), jit_pipeline
return pipeline, pipeline, jit_pipeline


def build_legacy_fasttext_vector_pipeline():
Expand All @@ -143,10 +142,10 @@ def build_experimental_fasttext_vector_pipeline():
vector = FastTextExperimental()

pipeline = TextSequentialTransforms(tokenizer, vector)
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
jit_pipeline = torch.jit.script(pipeline)

print('jit legacy fasttext pipeline success!')
return pipeline, pipeline.to_ivalue(), jit_pipeline
return pipeline, pipeline, jit_pipeline


def run_benchmark_lookup(text_classification_dataset, pipeline):
Expand Down
14 changes: 0 additions & 14 deletions examples/data_pipeline/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,6 @@ def forward(self, tokens: List[str]) -> List[int]:
def insert_token(self, token: str, index: int) -> None:
self.vocab.insert_token(token, index)

def to_ivalue(self):
if hasattr(self.vocab, 'to_ivalue'):
sp_model = self.sp_model
new_module = PretrainedSPVocab(sp_model)
new_module.vocab = self.vocab.to_ivalue()
return new_module
return self


class PyTextVocabTransform(nn.Module):
r"""PyTextVocabTransform transform
Expand All @@ -57,12 +49,6 @@ def __init__(self, vocab):
def forward(self, tokens: List[str]) -> List[int]:
return self.vocab.lookup_indices_1d(tokens)

def to_ivalue(self):
if hasattr(self.vocab, 'to_ivalue'):
vocab = self.vocab.to_ivalue()
return PyTextScriptVocabTransform(vocab)
return self


class ToLongTensor(nn.Module):
r"""Convert a list of integers to long tensor
Expand Down
20 changes: 14 additions & 6 deletions test/data/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,16 @@ def test_BasicEnglishNormalize(self):
basic_eng_norm = basic_english_normalize()
experimental_eager_tokens = basic_eng_norm(test_sample)

jit_basic_eng_norm = torch.jit.script(basic_eng_norm.to_ivalue())
jit_basic_eng_norm = torch.jit.script(basic_eng_norm)
experimental_jit_tokens = jit_basic_eng_norm(test_sample)

basic_english_tokenizer = data.get_tokenizer("basic_english")
eager_tokens = basic_english_tokenizer(test_sample)

assert not basic_eng_norm.is_jitable
assert basic_eng_norm.to_ivalue().is_jitable
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
assert basic_eng_norm.__prepare_scriptable__().is_jitable

self.assertEqual(experimental_jit_tokens, ref_results)
self.assertEqual(eager_tokens, ref_results)
Expand All @@ -121,7 +123,9 @@ def test_basicEnglishNormalize_load_and_save(self):

with self.subTest('torchscript'):
save_path = os.path.join(self.test_dir, 'ben_torchscrip.pt')
ben = basic_english_normalize().to_ivalue()
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
ben = basic_english_normalize().__prepare_scriptable__()
torch.save(ben, save_path)
loaded_ben = torch.load(save_path)
self.assertEqual(loaded_ben(test_sample), ref_results)
Expand Down Expand Up @@ -149,11 +153,13 @@ def test_RegexTokenizer(self):
r_tokenizer = regex_tokenizer(patterns_list)
eager_tokens = r_tokenizer(test_sample)

jit_r_tokenizer = torch.jit.script(r_tokenizer.to_ivalue())
jit_r_tokenizer = torch.jit.script(r_tokenizer)
jit_tokens = jit_r_tokenizer(test_sample)

assert not r_tokenizer.is_jitable
assert r_tokenizer.to_ivalue().is_jitable
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
assert r_tokenizer.__prepare_scriptable__().is_jitable

self.assertEqual(eager_tokens, ref_results)
self.assertEqual(jit_tokens, ref_results)
Expand Down Expand Up @@ -186,7 +192,9 @@ def test_load_and_save(self):

with self.subTest('torchscript'):
save_path = os.path.join(self.test_dir, 'regex_torchscript.pt')
tokenizer = regex_tokenizer(patterns_list).to_ivalue()
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
tokenizer = regex_tokenizer(patterns_list).__prepare_scriptable__()
torch.save(tokenizer, save_path)
loaded_tokenizer = torch.load(save_path)
results = loaded_tokenizer(test_sample)
Expand Down
10 changes: 6 additions & 4 deletions test/experimental/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class TestTransforms(TorchtextTestCase):
def test_sentencepiece_processor(self):
model_path = get_asset_path('spm_example.model')
spm_transform = sentencepiece_processor(model_path)
jit_spm_transform = torch.jit.script(spm_transform.to_ivalue())
jit_spm_transform = torch.jit.script(spm_transform)
test_sample = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
ref_results = [15340, 4286, 981, 1207, 1681, 17, 84, 684, 8896, 5366,
144, 3689, 9, 5602, 12114, 6, 560, 649, 5602, 12114]
Expand All @@ -28,7 +28,7 @@ def test_sentencepiece_processor(self):
def test_sentencepiece_tokenizer(self):
model_path = get_asset_path('spm_example.model')
spm_tokenizer = sentencepiece_tokenizer(model_path)
jit_spm_tokenizer = torch.jit.script(spm_tokenizer.to_ivalue())
jit_spm_tokenizer = torch.jit.script(spm_tokenizer)
test_sample = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
ref_results = ['\u2581Sent', 'ence', 'P', 'ie', 'ce', '\u2581is',
'\u2581an', '\u2581un', 'super', 'vis', 'ed', '\u2581text',
Expand All @@ -48,7 +48,7 @@ def test_vector_transform(self):
data_path = os.path.join(dir_name, asset_name)
shutil.copy(asset_path, data_path)
vector_transform = VectorTransform(FastText(root=dir_name, validate_file=False))
jit_vector_transform = torch.jit.script(vector_transform.to_ivalue())
jit_vector_transform = torch.jit.script(vector_transform)
# The first 3 entries in each vector.
expected_fasttext_simple_en = torch.tensor([[-0.065334, -0.093031, -0.017571],
[-0.32423, -0.098845, -0.0073467]])
Expand All @@ -74,7 +74,9 @@ def test_sentencepiece_load_and_save(self):

with self.subTest('torchscript'):
save_path = os.path.join(self.test_dir, 'spm_torchscript.pt')
spm = sentencepiece_tokenizer((model_path)).to_ivalue()
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
spm = sentencepiece_tokenizer((model_path)).__prepare_scriptable__()
torch.save(spm, save_path)
loaded_spm = torch.load(save_path)
self.assertEqual(expected, loaded_spm(input))
12 changes: 8 additions & 4 deletions test/experimental/test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@ def test_vectors_jit(self):
tokens = ['a', 'b']
vecs = torch.stack((tensorA, tensorB), 0)
vectors_obj = build_vectors(tokens, vecs, unk_tensor=unk_tensor)
jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue())
jit_vectors_obj = torch.jit.script(vectors_obj)

assert not vectors_obj.is_jitable
assert vectors_obj.to_ivalue().is_jitable
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
assert vectors_obj.__prepare_scriptable__().is_jitable

self.assertEqual(vectors_obj['a'], jit_vectors_obj['a'])
self.assertEqual(vectors_obj['b'], jit_vectors_obj['b'])
Expand All @@ -71,7 +73,7 @@ def test_vectors_forward(self):
tokens = ['a', 'b']
vecs = torch.stack((tensorA, tensorB), 0)
vectors_obj = build_vectors(tokens, vecs, unk_tensor=unk_tensor)
jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue())
jit_vectors_obj = torch.jit.script(vectors_obj)

tokens_to_lookup = ['a', 'b', 'c']
expected_vectors = torch.stack((tensorA, tensorB, unk_tensor), 0)
Expand Down Expand Up @@ -148,7 +150,9 @@ def test_vectors_load_and_save(self):

with self.subTest('torchscript'):
vector_path = os.path.join(self.test_dir, 'vectors_torchscript.pt')
torch.save(vectors_obj.to_ivalue(), vector_path)
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
torch.save(vectors_obj.__prepare_scriptable__(), vector_path)
loaded_vectors_obj = torch.load(vector_path)

self.assertEqual(loaded_vectors_obj['a'], tensorA)
Expand Down
12 changes: 8 additions & 4 deletions test/experimental/test_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,15 @@ def test_vocab_jit(self):

c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c, min_freq=3)
jit_v = torch.jit.script(v.to_ivalue())
jit_v = torch.jit.script(v)

expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

assert not v.is_jitable
assert v.to_ivalue().is_jitable
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
assert v.__prepare_scriptable__().is_jitable

self.assertEqual(jit_v.get_itos(), expected_itos)
self.assertEqual(dict(jit_v.get_stoi()), expected_stoi)
Expand All @@ -121,7 +123,7 @@ def test_vocab_forward(self):

c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c)
jit_v = torch.jit.script(v.to_ivalue())
jit_v = torch.jit.script(v)

tokens = ['b', 'a', 'c']
expected_indices = [2, 1, 3]
Expand Down Expand Up @@ -208,7 +210,9 @@ def test_vocab_load_and_save(self):

with self.subTest('torchscript'):
vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
torch.save(v.to_ivalue(), vocab_path)
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
torch.save(v.__prepare_scriptable__(), vocab_path)
loaded_v = torch.load(vocab_path)
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
Expand Down
10 changes: 5 additions & 5 deletions test/experimental/test_with_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_vocab_transform(self):
vocab_transform = VocabTransform(load_vocab_from_file(f))
self.assertEqual(vocab_transform(['of', 'that', 'new']),
[7, 18, 24])
jit_vocab_transform = torch.jit.script(vocab_transform.to_ivalue())
jit_vocab_transform = torch.jit.script(vocab_transform)
self.assertEqual(jit_vocab_transform(['of', 'that', 'new', 'that']),
[7, 18, 24, 18])

Expand Down Expand Up @@ -128,7 +128,7 @@ def test_glove(self):
data_path = os.path.join(dir_name, asset_name)
shutil.copy(asset_path, data_path)
vectors_obj = GloVe(root=dir_name, validate_file=False)
jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue())
jit_vectors_obj = torch.jit.script(vectors_obj)

# The first 3 entries in each vector.
expected_glove = {
Expand Down Expand Up @@ -191,7 +191,7 @@ def test_vocab_from_raw_text_file(self):
asset_path = get_asset_path(asset_name)
with open(asset_path, 'r') as f:
tokenizer = basic_english_normalize()
jit_tokenizer = torch.jit.script(tokenizer.to_ivalue())
jit_tokenizer = torch.jit.script(tokenizer)
v = build_vocab_from_text_file(f, jit_tokenizer, unk_token='<new_unk>')
expected_itos = ['<new_unk>', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall', 'parent',
Expand Down Expand Up @@ -243,7 +243,7 @@ def test_text_sequential_transform(self):
asset_path = get_asset_path(asset_name)
with open(asset_path, 'r') as f:
pipeline = TextSequentialTransforms(basic_english_normalize(), load_vocab_from_file(f))
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
jit_pipeline = torch.jit.script(pipeline)
self.assertEqual(pipeline('of that new'), [7, 18, 24])
self.assertEqual(jit_pipeline('of that new'), [7, 18, 24])

Expand All @@ -270,7 +270,7 @@ def test_fast_text(self):
data_path = os.path.join(dir_name, asset_name)
shutil.copy(asset_path, data_path)
vectors_obj = FastText(root=dir_name, validate_file=False)
jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue())
jit_vectors_obj = torch.jit.script(vectors_obj)

# The first 3 entries in each vector.
expected_fasttext_simple_en = {
Expand Down
Loading

0 comments on commit e368dc9

Please sign in to comment.