Skip to content

Commit

Permalink
Merge pull request espnet#4320 from cadia-lvl/add-progress-bar
Browse files Browse the repository at this point in the history
  • Loading branch information
kan-bayashi authored Apr 29, 2022
2 parents 930b380 + 664414c commit b757b89
Showing 1 changed file with 34 additions and 4 deletions.
38 changes: 34 additions & 4 deletions egs2/TEMPLATE/asr1/pyscripts/utils/convert_text_to_phn.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
#!/usr/bin/env python3

# Copyright 2021 Tomoki Hayashi
# Copyright 2021 Tomoki Hayashi and Gunnar Thor
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Convert kaldi-style text into phonemized sentences."""

import argparse
import codecs
import contextlib

from joblib import delayed
from joblib import Parallel
from joblib import parallel
from tqdm import tqdm

from espnet2.text.cleaner import TextCleaner
from espnet2.text.phoneme_tokenizer import PhonemeTokenizer
Expand All @@ -34,13 +37,40 @@ def main():
text = {line.split()[0]: " ".join(line.split()[1:]) for line in lines}
if cleaner is not None:
text = {k: cleaner(v) for k, v in text.items()}
phns_list = Parallel(n_jobs=args.nj)(
[delayed(phoneme_tokenizer.text2tokens)(sentence) for sentence in text.values()]
)
with tqdm_joblib(tqdm(total=len(text.values()), desc="Phonemizing")):
phns_list = Parallel(n_jobs=args.nj)(
[
delayed(phoneme_tokenizer.text2tokens)(sentence)
for sentence in text.values()
]
)
with codecs.open(args.out_text, "w", encoding="utf8") as g:
for utt_id, phns in zip(text.keys(), phns_list):
g.write(f"{utt_id} " + " ".join(phns) + "\n")


@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
"""Patch joblib to report into tqdm progress bar given as argument.
Reference:
https://stackoverflow.com/questions/24983493
"""

class TqdmBatchCompletionCallback(parallel.BatchCompletionCallBack):
def __call__(self, *args, **kwargs):
tqdm_object.update(n=self.batch_size)
return super().__call__(*args, **kwargs)

old_batch_callback = parallel.BatchCompletionCallBack
parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
try:
yield tqdm_object
finally:
parallel.BatchCompletionCallBack = old_batch_callback
tqdm_object.close()


if __name__ == "__main__":
main()

0 comments on commit b757b89

Please sign in to comment.