From ec518ccc74b85e3b50304ab70ae5a1f069df0038 Mon Sep 17 00:00:00 2001 From: Gunnar Thor Date: Wed, 23 Feb 2022 11:31:56 +0000 Subject: [PATCH] Add progress bar to phonemization --- .../pyscripts/utils/convert_text_to_phn.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/egs2/TEMPLATE/asr1/pyscripts/utils/convert_text_to_phn.py b/egs2/TEMPLATE/asr1/pyscripts/utils/convert_text_to_phn.py index 21f8f4daf46..bb8be8b861b 100755 --- a/egs2/TEMPLATE/asr1/pyscripts/utils/convert_text_to_phn.py +++ b/egs2/TEMPLATE/asr1/pyscripts/utils/convert_text_to_phn.py @@ -7,9 +7,11 @@ import argparse import codecs +from tqdm import tqdm +import contextlib from joblib import delayed -from joblib import Parallel +from joblib import Parallel, parallel from espnet2.text.cleaner import TextCleaner from espnet2.text.phoneme_tokenizer import PhonemeTokenizer @@ -34,13 +36,29 @@ def main(): text = {line.split()[0]: " ".join(line.split()[1:]) for line in lines} if cleaner is not None: text = {k: cleaner(v) for k, v in text.items()} - phns_list = Parallel(n_jobs=args.nj)( - [delayed(phoneme_tokenizer.text2tokens)(sentence) for sentence in text.values()] - ) + with tqdm_joblib(tqdm(total=len(text.values()), desc="Phonemizing")) as progress_bar: + phns_list = Parallel(n_jobs=args.nj)( + [delayed(phoneme_tokenizer.text2tokens)(sentence) for sentence in text.values()] + ) with codecs.open(args.out_text, "w", encoding="utf8") as g: for utt_id, phns in zip(text.keys(), phns_list): g.write(f"{utt_id} " + " ".join(phns) + "\n") +@contextlib.contextmanager +def tqdm_joblib(tqdm_object): + """Context manager to patch joblib to report into tqdm progress bar given as argument""" + class TqdmBatchCompletionCallback(parallel.BatchCompletionCallBack): + def __call__(self, *args, **kwargs): + tqdm_object.update(n=self.batch_size) + return super().__call__(*args, **kwargs) + + old_batch_callback = parallel.BatchCompletionCallBack + parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback + try: + yield tqdm_object + finally: + parallel.BatchCompletionCallBack = old_batch_callback + tqdm_object.close() if __name__ == "__main__": main()