From e368dc926dc4a58967b8048f7ac011c3c5044c84 Mon Sep 17 00:00:00 2001 From: Guanheng George Zhang <6156351+zhangguanheng66@users.noreply.github.com> Date: Tue, 9 Feb 2021 13:36:55 -0500 Subject: [PATCH] switch to_ivalue to __prepare_scriptable__ (#1080) --- .../benchmark_basic_english_normalize.py | 2 +- benchmark/benchmark_experimental_vectors.py | 2 +- benchmark/benchmark_experimental_vocab.py | 6 +-- benchmark/benchmark_pytext_vocab.py | 2 +- examples/data_pipeline/pipelines.py | 21 ++++----- examples/data_pipeline/transforms.py | 14 ------ test/data/test_functional.py | 20 +++++--- test/experimental/test_transforms.py | 10 ++-- test/experimental/test_vectors.py | 12 +++-- test/experimental/test_vocab.py | 12 +++-- test/experimental/test_with_asset.py | 10 ++-- torchtext/experimental/transforms.py | 46 ++++--------------- torchtext/experimental/vectors.py | 2 +- torchtext/experimental/vocab.py | 4 +- 14 files changed, 70 insertions(+), 93 deletions(-) diff --git a/benchmark/benchmark_basic_english_normalize.py b/benchmark/benchmark_basic_english_normalize.py index d719e748a5..fa395b1299 100644 --- a/benchmark/benchmark_basic_english_normalize.py +++ b/benchmark/benchmark_basic_english_normalize.py @@ -15,7 +15,7 @@ def _run_benchmark_lookup(train, tokenizer): existing_basic_english_tokenizer = get_tokenizer("basic_english") experimental_basic_english_normalize = basic_english_normalize() - experimental_jit_basic_english_normalize = torch.jit.script(experimental_basic_english_normalize.to_ivalue()) + experimental_jit_basic_english_normalize = torch.jit.script(experimental_basic_english_normalize) # existing eager lookup train, _ = AG_NEWS() diff --git a/benchmark/benchmark_experimental_vectors.py b/benchmark/benchmark_experimental_vectors.py index 42fc008370..f644c14e62 100644 --- a/benchmark/benchmark_experimental_vectors.py +++ b/benchmark/benchmark_experimental_vectors.py @@ -42,7 +42,7 @@ def _run_benchmark_lookup(tokens, vector): # experimental FastText jit 
lookup print("FastText Experimental - Jit Mode") - jit_fast_text_experimental = torch.jit.script(fast_text_experimental.to_ivalue()) + jit_fast_text_experimental = torch.jit.script(fast_text_experimental) _run_benchmark_lookup(tokens, jit_fast_text_experimental) diff --git a/benchmark/benchmark_experimental_vocab.py b/benchmark/benchmark_experimental_vocab.py index 4183c27f6a..f815bf3648 100644 --- a/benchmark/benchmark_experimental_vocab.py +++ b/benchmark/benchmark_experimental_vocab.py @@ -67,7 +67,7 @@ def benchmark_experimental_vocab_construction(vocab_file_path, is_raw_text=True, print("Loading from raw text file with basic_english_normalize tokenizer") for _ in range(num_iters): tokenizer = basic_english_normalize() - jited_tokenizer = torch.jit.script(tokenizer.to_ivalue()) + jited_tokenizer = torch.jit.script(tokenizer) build_vocab_from_text_file(f, jited_tokenizer, num_cpus=1) print("Construction time:", time.monotonic() - t0) else: @@ -140,7 +140,7 @@ def token_iterator(file_path): t0 = time.monotonic() v_experimental = VocabExperimental(ordered_dict) print("Construction time:", time.monotonic() - t0) - jit_v_experimental = torch.jit.script(v_experimental.to_ivalue()) + jit_v_experimental = torch.jit.script(v_experimental) # existing Vocab eager lookup print("Vocab - Eager Mode") @@ -154,7 +154,7 @@ def token_iterator(file_path): _run_benchmark_lookup([tokens], v_experimental) _run_benchmark_lookup(tokens_lists, v_experimental) - jit_v_experimental = torch.jit.script(v_experimental.to_ivalue()) + jit_v_experimental = torch.jit.script(v_experimental) # experimental Vocab jit lookup print("Vocab Experimental - Jit Mode") _run_benchmark_lookup(tokens, jit_v_experimental) diff --git a/benchmark/benchmark_pytext_vocab.py b/benchmark/benchmark_pytext_vocab.py index 2e686dd5dc..6dbe200fd4 100644 --- a/benchmark/benchmark_pytext_vocab.py +++ b/benchmark/benchmark_pytext_vocab.py @@ -150,7 +150,7 @@ def benchmark_experimental_vocab(): t0 = time.monotonic() 
experimental_script_vocab = ExperimentalScriptVocabulary(ordered_dict, unk_token="") print("Construction time:", time.monotonic() - t0) - jit_experimental_script_vocab = torch.jit.script(experimental_script_vocab.to_ivalue()) + jit_experimental_script_vocab = torch.jit.script(experimental_script_vocab) # pytext Vocab eager lookup print("Pytext Vocabulary - Eager Mode") diff --git a/examples/data_pipeline/pipelines.py b/examples/data_pipeline/pipelines.py index 4e2db98021..d8c6f71f7d 100644 --- a/examples/data_pipeline/pipelines.py +++ b/examples/data_pipeline/pipelines.py @@ -32,11 +32,10 @@ def build_sp_pipeline(spm_file): vocab = PretrainedSPVocab(load_sp_model(spm_file)) # Insert token in vocab to match a pretrained vocab - vocab.insert_token('', 1) pipeline = TextSequentialTransforms(tokenizer, vocab) - jit_pipeline = torch.jit.script(pipeline.to_ivalue()) + jit_pipeline = torch.jit.script(pipeline) print('jit sentencepiece pipeline success!') - return pipeline, pipeline.to_ivalue(), jit_pipeline + return pipeline, pipeline, jit_pipeline def build_legacy_torchtext_vocab_pipeline(vocab_file): @@ -59,9 +58,9 @@ def build_experimental_torchtext_pipeline(hf_vocab_file): with open(hf_vocab_file, 'r') as f: vocab = load_vocab_from_file(f) pipeline = TextSequentialTransforms(tokenizer, vocab) - jit_pipeline = torch.jit.script(pipeline.to_ivalue()) + jit_pipeline = torch.jit.script(pipeline) print('jit experimental torchtext pipeline success!') - return pipeline, pipeline.to_ivalue(), jit_pipeline + return pipeline, pipeline, jit_pipeline def build_legacy_batch_torchtext_vocab_pipeline(vocab_file): @@ -104,9 +103,9 @@ def build_legacy_pytext_script_vocab_pipeline(vocab_file): vocab_list.insert(0, "") pipeline = TextSequentialTransforms(tokenizer, PyTextScriptVocabTransform(ScriptVocabulary(vocab_list))) - jit_pipeline = torch.jit.script(pipeline.to_ivalue()) + jit_pipeline = torch.jit.script(pipeline) print('jit legacy PyText pipeline success!') - return pipeline, 
pipeline.to_ivalue(), jit_pipeline + return pipeline, pipeline, jit_pipeline def build_experimental_pytext_script_pipeline(vocab_file): @@ -125,9 +124,9 @@ def build_experimental_pytext_script_pipeline(vocab_file): # Insert token in vocab to match a pretrained vocab pipeline = TextSequentialTransforms(tokenizer, PyTextScriptVocabTransform(script_vocab(ordered_dict))) - jit_pipeline = torch.jit.script(pipeline.to_ivalue()) + jit_pipeline = torch.jit.script(pipeline) print('jit legacy PyText pipeline success!') - return pipeline, pipeline.to_ivalue(), jit_pipeline + return pipeline, pipeline, jit_pipeline def build_legacy_fasttext_vector_pipeline(): @@ -143,10 +142,10 @@ def build_experimental_fasttext_vector_pipeline(): vector = FastTextExperimental() pipeline = TextSequentialTransforms(tokenizer, vector) - jit_pipeline = torch.jit.script(pipeline.to_ivalue()) + jit_pipeline = torch.jit.script(pipeline) print('jit legacy fasttext pipeline success!') - return pipeline, pipeline.to_ivalue(), jit_pipeline + return pipeline, pipeline, jit_pipeline def run_benchmark_lookup(text_classification_dataset, pipeline): diff --git a/examples/data_pipeline/transforms.py b/examples/data_pipeline/transforms.py index 7a6d9214e5..2bcb8ff34c 100644 --- a/examples/data_pipeline/transforms.py +++ b/examples/data_pipeline/transforms.py @@ -24,14 +24,6 @@ def forward(self, tokens: List[str]) -> List[int]: def insert_token(self, token: str, index: int) -> None: self.vocab.insert_token(token, index) - def to_ivalue(self): - if hasattr(self.vocab, 'to_ivalue'): - sp_model = self.sp_model - new_module = PretrainedSPVocab(sp_model) - new_module.vocab = self.vocab.to_ivalue() - return new_module - return self - class PyTextVocabTransform(nn.Module): r"""PyTextVocabTransform transform @@ -57,12 +49,6 @@ def __init__(self, vocab): def forward(self, tokens: List[str]) -> List[int]: return self.vocab.lookup_indices_1d(tokens) - def to_ivalue(self): - if hasattr(self.vocab, 'to_ivalue'): - vocab = 
self.vocab.to_ivalue() - return PyTextScriptVocabTransform(vocab) - return self - class ToLongTensor(nn.Module): r"""Convert a list of integers to long tensor diff --git a/test/data/test_functional.py b/test/data/test_functional.py index 509a0e9fc3..199c022786 100644 --- a/test/data/test_functional.py +++ b/test/data/test_functional.py @@ -94,14 +94,16 @@ def test_BasicEnglishNormalize(self): basic_eng_norm = basic_english_normalize() experimental_eager_tokens = basic_eng_norm(test_sample) - jit_basic_eng_norm = torch.jit.script(basic_eng_norm.to_ivalue()) + jit_basic_eng_norm = torch.jit.script(basic_eng_norm) experimental_jit_tokens = jit_basic_eng_norm(test_sample) basic_english_tokenizer = data.get_tokenizer("basic_english") eager_tokens = basic_english_tokenizer(test_sample) assert not basic_eng_norm.is_jitable - assert basic_eng_norm.to_ivalue().is_jitable + # Call the __prepare_scriptable__() func to convert the building block to the torchbind version. + # Users are not expected to use the torchbind version in eager mode, but a CI test is still needed here. + assert basic_eng_norm.__prepare_scriptable__().is_jitable self.assertEqual(experimental_jit_tokens, ref_results) self.assertEqual(eager_tokens, ref_results) @@ -121,7 +123,9 @@ def test_basicEnglishNormalize_load_and_save(self): with self.subTest('torchscript'): save_path = os.path.join(self.test_dir, 'ben_torchscrip.pt') - ben = basic_english_normalize().to_ivalue() + # Call the __prepare_scriptable__() func to convert the building block to the torchbind version. + # Users are not expected to use the torchbind version in eager mode, but a CI test is still needed here.
+ ben = basic_english_normalize().__prepare_scriptable__() torch.save(ben, save_path) loaded_ben = torch.load(save_path) self.assertEqual(loaded_ben(test_sample), ref_results) @@ -149,11 +153,13 @@ def test_RegexTokenizer(self): r_tokenizer = regex_tokenizer(patterns_list) eager_tokens = r_tokenizer(test_sample) - jit_r_tokenizer = torch.jit.script(r_tokenizer.to_ivalue()) + jit_r_tokenizer = torch.jit.script(r_tokenizer) jit_tokens = jit_r_tokenizer(test_sample) assert not r_tokenizer.is_jitable - assert r_tokenizer.to_ivalue().is_jitable + # Call the __prepare_scriptable__() func to convert the building block to the torchbind version. + # Users are not expected to use the torchbind version in eager mode, but a CI test is still needed here. + assert r_tokenizer.__prepare_scriptable__().is_jitable self.assertEqual(eager_tokens, ref_results) self.assertEqual(jit_tokens, ref_results) @@ -186,7 +192,9 @@ def test_load_and_save(self): with self.subTest('torchscript'): save_path = os.path.join(self.test_dir, 'regex_torchscript.pt') - tokenizer = regex_tokenizer(patterns_list).to_ivalue() + # Call the __prepare_scriptable__() func to convert the building block to the torchbind version. + # Users are not expected to use the torchbind version in eager mode, but a CI test is still needed here.
+ tokenizer = regex_tokenizer(patterns_list).__prepare_scriptable__() torch.save(tokenizer, save_path) loaded_tokenizer = torch.load(save_path) results = loaded_tokenizer(test_sample) diff --git a/test/experimental/test_transforms.py b/test/experimental/test_transforms.py index 1994c3c396..d3cc651ddc 100644 --- a/test/experimental/test_transforms.py +++ b/test/experimental/test_transforms.py @@ -16,7 +16,7 @@ class TestTransforms(TorchtextTestCase): def test_sentencepiece_processor(self): model_path = get_asset_path('spm_example.model') spm_transform = sentencepiece_processor(model_path) - jit_spm_transform = torch.jit.script(spm_transform.to_ivalue()) + jit_spm_transform = torch.jit.script(spm_transform) test_sample = 'SentencePiece is an unsupervised text tokenizer and detokenizer' ref_results = [15340, 4286, 981, 1207, 1681, 17, 84, 684, 8896, 5366, 144, 3689, 9, 5602, 12114, 6, 560, 649, 5602, 12114] @@ -28,7 +28,7 @@ def test_sentencepiece_processor(self): def test_sentencepiece_tokenizer(self): model_path = get_asset_path('spm_example.model') spm_tokenizer = sentencepiece_tokenizer(model_path) - jit_spm_tokenizer = torch.jit.script(spm_tokenizer.to_ivalue()) + jit_spm_tokenizer = torch.jit.script(spm_tokenizer) test_sample = 'SentencePiece is an unsupervised text tokenizer and detokenizer' ref_results = ['\u2581Sent', 'ence', 'P', 'ie', 'ce', '\u2581is', '\u2581an', '\u2581un', 'super', 'vis', 'ed', '\u2581text', @@ -48,7 +48,7 @@ def test_vector_transform(self): data_path = os.path.join(dir_name, asset_name) shutil.copy(asset_path, data_path) vector_transform = VectorTransform(FastText(root=dir_name, validate_file=False)) - jit_vector_transform = torch.jit.script(vector_transform.to_ivalue()) + jit_vector_transform = torch.jit.script(vector_transform) # The first 3 entries in each vector. 
expected_fasttext_simple_en = torch.tensor([[-0.065334, -0.093031, -0.017571], [-0.32423, -0.098845, -0.0073467]]) @@ -74,7 +74,9 @@ def test_sentencepiece_load_and_save(self): with self.subTest('torchscript'): save_path = os.path.join(self.test_dir, 'spm_torchscript.pt') - spm = sentencepiece_tokenizer((model_path)).to_ivalue() + # Call the __prepare_scriptable__() func to convert the building block to the torchbind version. + # Users are not expected to use the torchbind version in eager mode, but a CI test is still needed here. + spm = sentencepiece_tokenizer((model_path)).__prepare_scriptable__() torch.save(spm, save_path) loaded_spm = torch.load(save_path) self.assertEqual(expected, loaded_spm(input)) diff --git a/test/experimental/test_vectors.py b/test/experimental/test_vectors.py index ff5293b402..0cb32ffe1d 100644 --- a/test/experimental/test_vectors.py +++ b/test/experimental/test_vectors.py @@ -54,10 +54,12 @@ def test_vectors_jit(self): tokens = ['a', 'b'] vecs = torch.stack((tensorA, tensorB), 0) vectors_obj = build_vectors(tokens, vecs, unk_tensor=unk_tensor) - jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue()) + jit_vectors_obj = torch.jit.script(vectors_obj) assert not vectors_obj.is_jitable - assert vectors_obj.to_ivalue().is_jitable + # Call the __prepare_scriptable__() func to convert the building block to the torchbind version. + # Users are not expected to use the torchbind version in eager mode, but a CI test is still needed here.
+ assert vectors_obj.__prepare_scriptable__().is_jitable self.assertEqual(vectors_obj['a'], jit_vectors_obj['a']) self.assertEqual(vectors_obj['b'], jit_vectors_obj['b']) @@ -71,7 +73,7 @@ def test_vectors_forward(self): tokens = ['a', 'b'] vecs = torch.stack((tensorA, tensorB), 0) vectors_obj = build_vectors(tokens, vecs, unk_tensor=unk_tensor) - jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue()) + jit_vectors_obj = torch.jit.script(vectors_obj) tokens_to_lookup = ['a', 'b', 'c'] expected_vectors = torch.stack((tensorA, tensorB, unk_tensor), 0) @@ -148,7 +150,9 @@ def test_vectors_load_and_save(self): with self.subTest('torchscript'): vector_path = os.path.join(self.test_dir, 'vectors_torchscript.pt') - torch.save(vectors_obj.to_ivalue(), vector_path) + # Call the __prepare_scriptable__() func to convert the building block to the torchbind version. + # Users are not expected to use the torchbind version in eager mode, but a CI test is still needed here. + torch.save(vectors_obj.__prepare_scriptable__(), vector_path) loaded_vectors_obj = torch.load(vector_path) self.assertEqual(loaded_vectors_obj['a'], tensorA) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 662aa6667a..879c03e72d 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -104,13 +104,15 @@ def test_vocab_jit(self): c = OrderedDict(sorted_by_freq_tuples) v = vocab(c, min_freq=3) - jit_v = torch.jit.script(v.to_ivalue()) + jit_v = torch.jit.script(v) expected_itos = ['', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} assert not v.is_jitable - assert v.to_ivalue().is_jitable + # Call the __prepare_scriptable__() func to convert the building block to the torchbind version. + # Users are not expected to use the torchbind version in eager mode, but a CI test is still needed here.
+ assert v.__prepare_scriptable__().is_jitable self.assertEqual(jit_v.get_itos(), expected_itos) self.assertEqual(dict(jit_v.get_stoi()), expected_stoi) @@ -121,7 +123,7 @@ def test_vocab_forward(self): c = OrderedDict(sorted_by_freq_tuples) v = vocab(c) - jit_v = torch.jit.script(v.to_ivalue()) + jit_v = torch.jit.script(v) tokens = ['b', 'a', 'c'] expected_indices = [2, 1, 3] @@ -208,7 +210,9 @@ def test_vocab_load_and_save(self): with self.subTest('torchscript'): vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt') - torch.save(v.to_ivalue(), vocab_path) + # Call the __prepare_scriptable__() func to convert the building block to the torchbind version. + # Users are not expected to use the torchbind version in eager mode, but a CI test is still needed here. + torch.save(v.__prepare_scriptable__(), vocab_path) loaded_v = torch.load(vocab_path) self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi) diff --git a/test/experimental/test_with_asset.py b/test/experimental/test_with_asset.py index 50c01935a9..f900fb6752 100644 --- a/test/experimental/test_with_asset.py +++ b/test/experimental/test_with_asset.py @@ -82,7 +82,7 @@ def test_vocab_transform(self): vocab_transform = VocabTransform(load_vocab_from_file(f)) self.assertEqual(vocab_transform(['of', 'that', 'new']), [7, 18, 24]) - jit_vocab_transform = torch.jit.script(vocab_transform.to_ivalue()) + jit_vocab_transform = torch.jit.script(vocab_transform) self.assertEqual(jit_vocab_transform(['of', 'that', 'new', 'that']), [7, 18, 24, 18]) @@ -128,7 +128,7 @@ def test_glove(self): data_path = os.path.join(dir_name, asset_name) shutil.copy(asset_path, data_path) vectors_obj = GloVe(root=dir_name, validate_file=False) - jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue()) + jit_vectors_obj = torch.jit.script(vectors_obj) # The first 3 entries in each vector.
expected_glove = { @@ -191,7 +191,7 @@ def test_vocab_from_raw_text_file(self): asset_path = get_asset_path(asset_name) with open(asset_path, 'r') as f: tokenizer = basic_english_normalize() - jit_tokenizer = torch.jit.script(tokenizer.to_ivalue()) + jit_tokenizer = torch.jit.script(tokenizer) v = build_vocab_from_text_file(f, jit_tokenizer, unk_token='') expected_itos = ['', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed', 'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall', 'parent', @@ -243,7 +243,7 @@ def test_text_sequential_transform(self): asset_path = get_asset_path(asset_name) with open(asset_path, 'r') as f: pipeline = TextSequentialTransforms(basic_english_normalize(), load_vocab_from_file(f)) - jit_pipeline = torch.jit.script(pipeline.to_ivalue()) + jit_pipeline = torch.jit.script(pipeline) self.assertEqual(pipeline('of that new'), [7, 18, 24]) self.assertEqual(jit_pipeline('of that new'), [7, 18, 24]) @@ -270,7 +270,7 @@ def test_fast_text(self): data_path = os.path.join(dir_name, asset_name) shutil.copy(asset_path, data_path) vectors_obj = FastText(root=dir_name, validate_file=False) - jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue()) + jit_vectors_obj = torch.jit.script(vectors_obj) # The first 3 entries in each vector. 
expected_fasttext_simple_en = { diff --git a/torchtext/experimental/transforms.py b/torchtext/experimental/transforms.py index 1f62ea7032..3b542aeb45 100644 --- a/torchtext/experimental/transforms.py +++ b/torchtext/experimental/transforms.py @@ -2,7 +2,6 @@ import torch.nn as nn from typing import List from torchtext._torchtext import RegexTokenizer as RegexTokenizerPybind -from collections import OrderedDict from torch import Tensor from torchtext._torchtext import SentencePiece as SentencePiecePybind import io @@ -50,7 +49,7 @@ def basic_english_normalize(): >>> from torchtext.experimental.transforms import basic_english_normalize >>> test_sample = 'Basic English Normalization for a Line of Text' >>> basic_eng_norm = basic_english_normalize() - >>> jit_basic_eng_norm = torch.jit.script(basic_eng_norm.to_ivalue()) + >>> jit_basic_eng_norm = torch.jit.script(basic_eng_norm) >>> tokens = jit_basic_eng_norm(test_sample) """ @@ -124,10 +123,9 @@ def forward(self, line: str) -> List[str]: return self.regex_tokenizer.forward(line) - def to_ivalue(self): + def __prepare_scriptable__(self): r"""Return a JITable BasicEnglishNormalize. """ - regex_tokenizer = torch.classes.torchtext.RegexTokenizer(self.regex_tokenizer.patterns_, self.regex_tokenizer.replacements_, True) return BasicEnglishNormalize(regex_tokenizer) @@ -159,10 +157,9 @@ def forward(self, line: str) -> List[str]: return self.regex_tokenizer.forward(line) - def to_ivalue(self): + def __prepare_scriptable__(self): r"""Return a JITable RegexTokenizer. 
""" - regex_tokenizer = torch.classes.torchtext.RegexTokenizer(self.regex_tokenizer.patterns_, self.regex_tokenizer.replacements_, False) return RegexTokenizer(regex_tokenizer) @@ -177,7 +174,7 @@ class TextSequentialTransforms(nn.Sequential): >>> txt_pipeline = TextSequentialTransforms(tokenizer) >>> txt_pipeline('here is an example') ['here', 'is', 'an', 'example'] - >>> jit_txt_pipeline = torch.jit.script(txt_pipeline.to_ivalue()) + >>> jit_txt_pipeline = torch.jit.script(txt_pipeline) """ def forward(self, input: str): @@ -185,17 +182,6 @@ def forward(self, input: str): input = module(input) return input - def to_ivalue(self): - r"""Return a JITable TextSequentialTransforms. - """ - - module_list = [] - for _idx, _module in enumerate(self): - if hasattr(_module, 'to_ivalue'): - _module = _module.to_ivalue() - module_list.append((str(_idx), _module)) - return TextSequentialTransforms(OrderedDict(module_list)) - PRETRAINED_SP_MODEL = { 'text_unigram_15000': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_spm/text_unigram_15000.model', @@ -263,7 +249,7 @@ def sentencepiece_tokenizer(sp_model): >>> import torch >>> from torchtext.experimental.transforms import sentencepiece_tokenizer >>> spm_tokenizer = sentencepiece_tokenizer('m_user.model') - >>> jit_spm_tokenizer = torch.jit.script(spm_tokenizer.to_ivalue()) + >>> jit_spm_tokenizer = torch.jit.script(spm_tokenizer) """ spm = load_sp_model(sp_model) @@ -308,7 +294,7 @@ def decode(self, tokens: List[str]) -> str: return self.sp_model.DecodePieces(tokens) - def to_ivalue(self): + def __prepare_scriptable__(self): torchbind_spm = torch.classes.torchtext.SentencePiece(self.sp_model._return_content()) return SentencePieceTokenizer(torchbind_spm) @@ -323,7 +309,7 @@ def sentencepiece_processor(sp_model): >>> import torch >>> from torchtext.experimental.transforms import sentencepiece_processor >>> spm_processor = sentencepiece_processor('m_user.model') - >>> jit_spm_processor = 
torch.jit.script(spm_processor.to_ivalue()) + >>> jit_spm_processor = torch.jit.script(spm_processor) """ spm = load_sp_model(sp_model) @@ -366,7 +352,7 @@ def decode(self, ids: List[int]) -> str: return self.sp_model.DecodeIds(ids) - def to_ivalue(self): + def __prepare_scriptable__(self): torchbind_spm = torch.classes.torchtext.SentencePiece(self.sp_model._return_content()) return SentencePieceProcessor(torchbind_spm) @@ -382,7 +368,7 @@ class VocabTransform(nn.Module): >>> from torchtext.experimental.vocab import vocab_from_file_object >>> f = open('vocab.txt', 'r') >>> vocab_transform = VocabTransform(vocab_from_file_object(f)) - >>> jit_vocab_transform = torch.jit.script(vocab_transform.to_ivalue()) + >>> jit_vocab_transform = torch.jit.script(vocab_transform) """ def __init__(self, vocab): @@ -402,12 +388,6 @@ def forward(self, tokens: List[str]) -> List[int]: return self.vocab.lookup_indices(tokens) - def to_ivalue(self): - if hasattr(self.vocab, 'to_ivalue'): - vocab = self.vocab.to_ivalue() - return VocabTransform(vocab) - return self - class VectorTransform(nn.Module): r"""Vector transform @@ -419,7 +399,7 @@ class VectorTransform(nn.Module): >>> import torch >>> from torchtext.experimental.vectors import FastText >>> vector_transform = VectorTransform(FastText()) - >>> jit_vector_transform = torch.jit.script(vector_transform.to_ivalue()) + >>> jit_vector_transform = torch.jit.script(vector_transform) """ def __init__(self, vector): @@ -438,9 +418,3 @@ def forward(self, tokens: List[str]) -> Tensor: """ return self.vector.lookup_vectors(tokens) - - def to_ivalue(self): - if hasattr(self.vector, 'to_ivalue'): - vector = self.vector.to_ivalue() - return VectorTransform(vector) - return self diff --git a/torchtext/experimental/vectors.py b/torchtext/experimental/vectors.py index 72bae2351b..e7779d9ad4 100644 --- a/torchtext/experimental/vectors.py +++ b/torchtext/experimental/vectors.py @@ -285,7 +285,7 @@ def lookup_vectors(self, tokens: List[str]) -> 
Tensor: return self.vectors.lookup_vectors(tokens) - def to_ivalue(self): + def __prepare_scriptable__(self): r"""Return a JITable Vectors. """ stoi = self.vectors.get_stoi() diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 6883326938..606965daa9 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -43,7 +43,7 @@ def build_vocab_from_text_file(file_object, jited_tokenizer, min_freq=1, unk_tok >>> f = open('vocab.txt', 'r') >>> tokenizer = basic_english_normalize() >>> tokenizer = basic_english_normalize() - >>> jit_tokenizer = torch.jit.script(tokenizer.to_ivalue()) + >>> jit_tokenizer = torch.jit.script(tokenizer) >>> v = build_vocab_from_text_file(f, jit_tokenizer) """ vocab_obj = _build_vocab_from_text_file(file_object.name, unk_token, min_freq, num_cpus, jited_tokenizer) @@ -264,7 +264,7 @@ def get_itos(self) -> List[str]: """ return self.vocab.get_itos() - def to_ivalue(self): + def __prepare_scriptable__(self): r"""Return a JITable Vocab. """ cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_, self.vocab.unk_token_)