Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

switch to_ivalue to __prepare_scriptable__ #1080

Merged
merged 15 commits into from
Feb 9, 2021
Merged
2 changes: 1 addition & 1 deletion benchmark/benchmark_basic_english_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _run_benchmark_lookup(train, tokenizer):

existing_basic_english_tokenizer = get_tokenizer("basic_english")
experimental_basic_english_normalize = basic_english_normalize()
experimental_jit_basic_english_normalize = torch.jit.script(experimental_basic_english_normalize.to_ivalue())
experimental_jit_basic_english_normalize = torch.jit.script(experimental_basic_english_normalize)

# existing eager lookup
train, _ = AG_NEWS()
Expand Down
2 changes: 1 addition & 1 deletion benchmark/benchmark_experimental_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _run_benchmark_lookup(tokens, vector):

# experimental FastText jit lookup
print("FastText Experimental - Jit Mode")
jit_fast_text_experimental = torch.jit.script(fast_text_experimental.to_ivalue())
jit_fast_text_experimental = torch.jit.script(fast_text_experimental)
_run_benchmark_lookup(tokens, jit_fast_text_experimental)


Expand Down
6 changes: 3 additions & 3 deletions benchmark/benchmark_experimental_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def benchmark_experimental_vocab_construction(vocab_file_path, is_raw_text=True,
print("Loading from raw text file with basic_english_normalize tokenizer")
for _ in range(num_iters):
tokenizer = basic_english_normalize()
jited_tokenizer = torch.jit.script(tokenizer.to_ivalue())
jited_tokenizer = torch.jit.script(tokenizer)
build_vocab_from_text_file(f, jited_tokenizer, num_cpus=1)
print("Construction time:", time.monotonic() - t0)
else:
Expand Down Expand Up @@ -140,7 +140,7 @@ def token_iterator(file_path):
t0 = time.monotonic()
v_experimental = VocabExperimental(ordered_dict)
print("Construction time:", time.monotonic() - t0)
jit_v_experimental = torch.jit.script(v_experimental.to_ivalue())
jit_v_experimental = torch.jit.script(v_experimental)

# existing Vocab eager lookup
print("Vocab - Eager Mode")
Expand All @@ -154,7 +154,7 @@ def token_iterator(file_path):
_run_benchmark_lookup([tokens], v_experimental)
_run_benchmark_lookup(tokens_lists, v_experimental)

jit_v_experimental = torch.jit.script(v_experimental.to_ivalue())
jit_v_experimental = torch.jit.script(v_experimental)
# experimental Vocab jit lookup
print("Vocab Experimental - Jit Mode")
_run_benchmark_lookup(tokens, jit_v_experimental)
Expand Down
2 changes: 1 addition & 1 deletion benchmark/benchmark_pytext_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def benchmark_experimental_vocab():
t0 = time.monotonic()
experimental_script_vocab = ExperimentalScriptVocabulary(ordered_dict, unk_token="<unk>")
print("Construction time:", time.monotonic() - t0)
jit_experimental_script_vocab = torch.jit.script(experimental_script_vocab.to_ivalue())
jit_experimental_script_vocab = torch.jit.script(experimental_script_vocab)

# pytext Vocab eager lookup
print("Pytext Vocabulary - Eager Mode")
Expand Down
21 changes: 10 additions & 11 deletions examples/data_pipeline/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,10 @@ def build_sp_pipeline(spm_file):
vocab = PretrainedSPVocab(load_sp_model(spm_file))

# Insert token in vocab to match a pretrained vocab
vocab.insert_token('<pad>', 1)
pipeline = TextSequentialTransforms(tokenizer, vocab)
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
jit_pipeline = torch.jit.script(pipeline)
print('jit sentencepiece pipeline success!')
return pipeline, pipeline.to_ivalue(), jit_pipeline
return pipeline, pipeline, jit_pipeline


def build_legacy_torchtext_vocab_pipeline(vocab_file):
Expand All @@ -59,9 +58,9 @@ def build_experimental_torchtext_pipeline(hf_vocab_file):
with open(hf_vocab_file, 'r') as f:
vocab = load_vocab_from_file(f)
pipeline = TextSequentialTransforms(tokenizer, vocab)
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
jit_pipeline = torch.jit.script(pipeline)
print('jit experimental torchtext pipeline success!')
return pipeline, pipeline.to_ivalue(), jit_pipeline
return pipeline, pipeline, jit_pipeline


def build_legacy_batch_torchtext_vocab_pipeline(vocab_file):
Expand Down Expand Up @@ -104,9 +103,9 @@ def build_legacy_pytext_script_vocab_pipeline(vocab_file):
vocab_list.insert(0, "<unk>")
pipeline = TextSequentialTransforms(tokenizer,
PyTextScriptVocabTransform(ScriptVocabulary(vocab_list)))
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
jit_pipeline = torch.jit.script(pipeline)
print('jit legacy PyText pipeline success!')
return pipeline, pipeline.to_ivalue(), jit_pipeline
return pipeline, pipeline, jit_pipeline


def build_experimental_pytext_script_pipeline(vocab_file):
Expand All @@ -125,9 +124,9 @@ def build_experimental_pytext_script_pipeline(vocab_file):
# Insert token in vocab to match a pretrained vocab
pipeline = TextSequentialTransforms(tokenizer,
PyTextScriptVocabTransform(script_vocab(ordered_dict)))
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
jit_pipeline = torch.jit.script(pipeline)
print('jit legacy PyText pipeline success!')
return pipeline, pipeline.to_ivalue(), jit_pipeline
return pipeline, pipeline, jit_pipeline


def build_legacy_fasttext_vector_pipeline():
Expand All @@ -143,10 +142,10 @@ def build_experimental_fasttext_vector_pipeline():
vector = FastTextExperimental()

pipeline = TextSequentialTransforms(tokenizer, vector)
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
jit_pipeline = torch.jit.script(pipeline)

print('jit legacy fasttext pipeline success!')
return pipeline, pipeline.to_ivalue(), jit_pipeline
return pipeline, pipeline, jit_pipeline


def run_benchmark_lookup(text_classification_dataset, pipeline):
Expand Down
14 changes: 0 additions & 14 deletions examples/data_pipeline/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,6 @@ def forward(self, tokens: List[str]) -> List[int]:
def insert_token(self, token: str, index: int) -> None:
self.vocab.insert_token(token, index)

def to_ivalue(self):
if hasattr(self.vocab, 'to_ivalue'):
sp_model = self.sp_model
new_module = PretrainedSPVocab(sp_model)
new_module.vocab = self.vocab.to_ivalue()
return new_module
return self


class PyTextVocabTransform(nn.Module):
r"""PyTextVocabTransform transform
Expand All @@ -57,12 +49,6 @@ def __init__(self, vocab):
def forward(self, tokens: List[str]) -> List[int]:
return self.vocab.lookup_indices_1d(tokens)

def to_ivalue(self):
if hasattr(self.vocab, 'to_ivalue'):
vocab = self.vocab.to_ivalue()
return PyTextScriptVocabTransform(vocab)
return self


class ToLongTensor(nn.Module):
r"""Convert a list of integers to long tensor
Expand Down
20 changes: 14 additions & 6 deletions test/data/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,16 @@ def test_BasicEnglishNormalize(self):
basic_eng_norm = basic_english_normalize()
experimental_eager_tokens = basic_eng_norm(test_sample)

jit_basic_eng_norm = torch.jit.script(basic_eng_norm.to_ivalue())
jit_basic_eng_norm = torch.jit.script(basic_eng_norm)
experimental_jit_tokens = jit_basic_eng_norm(test_sample)

basic_english_tokenizer = data.get_tokenizer("basic_english")
eager_tokens = basic_english_tokenizer(test_sample)

assert not basic_eng_norm.is_jitable
assert basic_eng_norm.to_ivalue().is_jitable
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
assert basic_eng_norm.__prepare_scriptable__().is_jitable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd interact with this via torch.jit.script. These methods are not meant to be called directly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we want to check if the building block backed by torchbind is jitable. Although we don't recommend the torchbind version for eager mode, I think we still need a test to cover it. By calling torch.jit.script, we are testing the jit mode, which is not the same thing with the torchbind building block.

Same thing for the pickle support.


self.assertEqual(experimental_jit_tokens, ref_results)
self.assertEqual(eager_tokens, ref_results)
Expand All @@ -121,7 +123,9 @@ def test_basicEnglishNormalize_load_and_save(self):

with self.subTest('torchscript'):
save_path = os.path.join(self.test_dir, 'ben_torchscrip.pt')
ben = basic_english_normalize().to_ivalue()
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
ben = basic_english_normalize().__prepare_scriptable__()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd interact with this via torch.jit.script. These methods are not meant to be called directly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the comment above.

torch.save(ben, save_path)
loaded_ben = torch.load(save_path)
self.assertEqual(loaded_ben(test_sample), ref_results)
Expand Down Expand Up @@ -149,11 +153,13 @@ def test_RegexTokenizer(self):
r_tokenizer = regex_tokenizer(patterns_list)
eager_tokens = r_tokenizer(test_sample)

jit_r_tokenizer = torch.jit.script(r_tokenizer.to_ivalue())
jit_r_tokenizer = torch.jit.script(r_tokenizer)
jit_tokens = jit_r_tokenizer(test_sample)

assert not r_tokenizer.is_jitable
assert r_tokenizer.to_ivalue().is_jitable
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
assert r_tokenizer.__prepare_scriptable__().is_jitable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the comment above.


self.assertEqual(eager_tokens, ref_results)
self.assertEqual(jit_tokens, ref_results)
Expand Down Expand Up @@ -186,7 +192,9 @@ def test_load_and_save(self):

with self.subTest('torchscript'):
save_path = os.path.join(self.test_dir, 'regex_torchscript.pt')
tokenizer = regex_tokenizer(patterns_list).to_ivalue()
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
tokenizer = regex_tokenizer(patterns_list).__prepare_scriptable__()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we shouldn't have to JIT anything just to save it, but if we want to exercise saving a JIT'd model then we should call torch.jit.script first

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See the comment above. We want to save a torchbind version model here, not exactly the JIT mode.

torch.save(tokenizer, save_path)
loaded_tokenizer = torch.load(save_path)
results = loaded_tokenizer(test_sample)
Expand Down
10 changes: 6 additions & 4 deletions test/experimental/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class TestTransforms(TorchtextTestCase):
def test_sentencepiece_processor(self):
model_path = get_asset_path('spm_example.model')
spm_transform = sentencepiece_processor(model_path)
jit_spm_transform = torch.jit.script(spm_transform.to_ivalue())
jit_spm_transform = torch.jit.script(spm_transform)
test_sample = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
ref_results = [15340, 4286, 981, 1207, 1681, 17, 84, 684, 8896, 5366,
144, 3689, 9, 5602, 12114, 6, 560, 649, 5602, 12114]
Expand All @@ -28,7 +28,7 @@ def test_sentencepiece_processor(self):
def test_sentencepiece_tokenizer(self):
model_path = get_asset_path('spm_example.model')
spm_tokenizer = sentencepiece_tokenizer(model_path)
jit_spm_tokenizer = torch.jit.script(spm_tokenizer.to_ivalue())
jit_spm_tokenizer = torch.jit.script(spm_tokenizer)
test_sample = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
ref_results = ['\u2581Sent', 'ence', 'P', 'ie', 'ce', '\u2581is',
'\u2581an', '\u2581un', 'super', 'vis', 'ed', '\u2581text',
Expand All @@ -48,7 +48,7 @@ def test_vector_transform(self):
data_path = os.path.join(dir_name, asset_name)
shutil.copy(asset_path, data_path)
vector_transform = VectorTransform(FastText(root=dir_name, validate_file=False))
jit_vector_transform = torch.jit.script(vector_transform.to_ivalue())
jit_vector_transform = torch.jit.script(vector_transform)
# The first 3 entries in each vector.
expected_fasttext_simple_en = torch.tensor([[-0.065334, -0.093031, -0.017571],
[-0.32423, -0.098845, -0.0073467]])
Expand All @@ -74,7 +74,9 @@ def test_sentencepiece_load_and_save(self):

with self.subTest('torchscript'):
save_path = os.path.join(self.test_dir, 'spm_torchscript.pt')
spm = sentencepiece_tokenizer((model_path)).to_ivalue()
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
spm = sentencepiece_tokenizer((model_path)).__prepare_scriptable__()
zhangguanheng66 marked this conversation as resolved.
Show resolved Hide resolved
torch.save(spm, save_path)
loaded_spm = torch.load(save_path)
self.assertEqual(expected, loaded_spm(input))
12 changes: 8 additions & 4 deletions test/experimental/test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@ def test_vectors_jit(self):
tokens = ['a', 'b']
vecs = torch.stack((tensorA, tensorB), 0)
vectors_obj = build_vectors(tokens, vecs, unk_tensor=unk_tensor)
jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue())
jit_vectors_obj = torch.jit.script(vectors_obj)

assert not vectors_obj.is_jitable
assert vectors_obj.to_ivalue().is_jitable
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
assert vectors_obj.__prepare_scriptable__().is_jitable

self.assertEqual(vectors_obj['a'], jit_vectors_obj['a'])
self.assertEqual(vectors_obj['b'], jit_vectors_obj['b'])
Expand All @@ -71,7 +73,7 @@ def test_vectors_forward(self):
tokens = ['a', 'b']
vecs = torch.stack((tensorA, tensorB), 0)
vectors_obj = build_vectors(tokens, vecs, unk_tensor=unk_tensor)
jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue())
jit_vectors_obj = torch.jit.script(vectors_obj)

tokens_to_lookup = ['a', 'b', 'c']
expected_vectors = torch.stack((tensorA, tensorB, unk_tensor), 0)
Expand Down Expand Up @@ -148,7 +150,9 @@ def test_vectors_load_and_save(self):

with self.subTest('torchscript'):
vector_path = os.path.join(self.test_dir, 'vectors_torchscript.pt')
torch.save(vectors_obj.to_ivalue(), vector_path)
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
torch.save(vectors_obj.__prepare_scriptable__(), vector_path)
loaded_vectors_obj = torch.load(vector_path)

self.assertEqual(loaded_vectors_obj['a'], tensorA)
Expand Down
12 changes: 8 additions & 4 deletions test/experimental/test_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,15 @@ def test_vocab_jit(self):

c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c, min_freq=3)
jit_v = torch.jit.script(v.to_ivalue())
jit_v = torch.jit.script(v)

expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}

assert not v.is_jitable
assert v.to_ivalue().is_jitable
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
assert v.__prepare_scriptable__().is_jitable

self.assertEqual(jit_v.get_itos(), expected_itos)
self.assertEqual(dict(jit_v.get_stoi()), expected_stoi)
Expand All @@ -121,7 +123,7 @@ def test_vocab_forward(self):

c = OrderedDict(sorted_by_freq_tuples)
v = vocab(c)
jit_v = torch.jit.script(v.to_ivalue())
jit_v = torch.jit.script(v)

tokens = ['b', 'a', 'c']
expected_indices = [2, 1, 3]
Expand Down Expand Up @@ -208,7 +210,9 @@ def test_vocab_load_and_save(self):

with self.subTest('torchscript'):
vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
torch.save(v.to_ivalue(), vocab_path)
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
torch.save(v.__prepare_scriptable__(), vocab_path)
loaded_v = torch.load(vocab_path)
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
Expand Down
10 changes: 5 additions & 5 deletions test/experimental/test_with_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_vocab_transform(self):
vocab_transform = VocabTransform(load_vocab_from_file(f))
self.assertEqual(vocab_transform(['of', 'that', 'new']),
[7, 18, 24])
jit_vocab_transform = torch.jit.script(vocab_transform.to_ivalue())
jit_vocab_transform = torch.jit.script(vocab_transform)
self.assertEqual(jit_vocab_transform(['of', 'that', 'new', 'that']),
[7, 18, 24, 18])

Expand Down Expand Up @@ -128,7 +128,7 @@ def test_glove(self):
data_path = os.path.join(dir_name, asset_name)
shutil.copy(asset_path, data_path)
vectors_obj = GloVe(root=dir_name, validate_file=False)
jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue())
jit_vectors_obj = torch.jit.script(vectors_obj)

# The first 3 entries in each vector.
expected_glove = {
Expand Down Expand Up @@ -191,7 +191,7 @@ def test_vocab_from_raw_text_file(self):
asset_path = get_asset_path(asset_name)
with open(asset_path, 'r') as f:
tokenizer = basic_english_normalize()
jit_tokenizer = torch.jit.script(tokenizer.to_ivalue())
jit_tokenizer = torch.jit.script(tokenizer)
v = build_vocab_from_text_file(f, jit_tokenizer, unk_token='<new_unk>')
expected_itos = ['<new_unk>', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall', 'parent',
Expand Down Expand Up @@ -243,7 +243,7 @@ def test_text_sequential_transform(self):
asset_path = get_asset_path(asset_name)
with open(asset_path, 'r') as f:
pipeline = TextSequentialTransforms(basic_english_normalize(), load_vocab_from_file(f))
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
jit_pipeline = torch.jit.script(pipeline)
self.assertEqual(pipeline('of that new'), [7, 18, 24])
self.assertEqual(jit_pipeline('of that new'), [7, 18, 24])

Expand All @@ -270,7 +270,7 @@ def test_fast_text(self):
data_path = os.path.join(dir_name, asset_name)
shutil.copy(asset_path, data_path)
vectors_obj = FastText(root=dir_name, validate_file=False)
jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue())
jit_vectors_obj = torch.jit.script(vectors_obj)

# The first 3 entries in each vector.
expected_fasttext_simple_en = {
Expand Down
Loading