From e7ec4f6baa5a1d49d61f9dfaa66f1cdcdb230ef8 Mon Sep 17 00:00:00 2001 From: Zhaoheng Ni Date: Tue, 15 Feb 2022 20:01:36 +0000 Subject: [PATCH] refactor eval and pipeline_demo scripts --- examples/asr/emformer_rnnt/eval.py | 12 +++++++++--- examples/asr/emformer_rnnt/pipeline_demo.py | 5 +++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/examples/asr/emformer_rnnt/eval.py b/examples/asr/emformer_rnnt/eval.py index 144085f7ad..bea8c42a19 100644 --- a/examples/asr/emformer_rnnt/eval.py +++ b/examples/asr/emformer_rnnt/eval.py @@ -1,7 +1,13 @@ #!/usr/bin/env python3 +"""Evaluate the lightning module by loading the checkpoint, the SentencePiece model, and the global_stats.json. + +Example: +python eval.py --model-type tedlium3 --checkpoint-path ./experiments/checkpoints/epoch=119-step=254999.ckpt + --dataset-path ./datasets/tedlium --sp-model-path ./spm_bpe_500.model +""" import logging import pathlib -from argparse import ArgumentParser +from argparse import ArgumentParser, RawTextHelpFormatter import torch import torchaudio @@ -11,7 +17,7 @@ from tedlium3.lightning import TEDLIUM3RNNTModule -logger = logging.getLogger() +logger = logging.getLogger(__name__) def compute_word_level_distance(seq1, seq2): @@ -79,7 +85,7 @@ def get_lightning_module(args): def parse_args(): - parser = ArgumentParser() + parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter) parser.add_argument( "--model-type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, MODEL_TYPE_MUSTC], required=True ) diff --git a/examples/asr/emformer_rnnt/pipeline_demo.py b/examples/asr/emformer_rnnt/pipeline_demo.py index b6316017f8..0b0c7d6957 100644 --- a/examples/asr/emformer_rnnt/pipeline_demo.py +++ b/examples/asr/emformer_rnnt/pipeline_demo.py @@ -1,4 +1,9 @@ #!/usr/bin/env python3 +"""The demo script for testing the pre-trained Emformer RNNT pipelines. + +Example: +python pipeline_demo.py --model-type librispeech --dataset-path ./datasets/librispeech +""" import logging import pathlib from argparse import ArgumentParser