From 5fb7dd619293dcd1cc02c6371c4079c22a40a23b Mon Sep 17 00:00:00 2001 From: ftshijt <728307998@qq.com> Date: Wed, 27 Apr 2022 00:53:46 -0400 Subject: [PATCH 1/4] remove requirement for src_token_list --- espnet2/tasks/st.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/espnet2/tasks/st.py b/espnet2/tasks/st.py index 182a335cc56..2b992f0be4e 100644 --- a/espnet2/tasks/st.py +++ b/espnet2/tasks/st.py @@ -206,7 +206,7 @@ def add_task_arguments(cls, parser: argparse.ArgumentParser): # NOTE(kamo): add_arguments(..., required=True) can't be used # to provide --print_config mode. Instead of it, do as required = parser.get_default("required") - required += ["src_token_list", "token_list"] + required += ["token_list"] group.add_argument( "--token_list", From d1e8ac3d8717f8717fb645592c25ee8cafc4060c Mon Sep 17 00:00:00 2001 From: ftshijt <728307998@qq.com> Date: Wed, 27 Apr 2022 01:15:18 -0400 Subject: [PATCH 2/4] update test --- test/espnet2/bin/test_st_inference.py | 79 +++++++++++++++++++++++++++ test/espnet2/bin/test_st_train.py | 15 +++++ 2 files changed, 94 insertions(+) create mode 100644 test/espnet2/bin/test_st_inference.py create mode 100644 test/espnet2/bin/test_st_train.py diff --git a/test/espnet2/bin/test_st_inference.py b/test/espnet2/bin/test_st_inference.py new file mode 100644 index 00000000000..7abc3c6ef7d --- /dev/null +++ b/test/espnet2/bin/test_st_inference.py @@ -0,0 +1,79 @@ +from argparse import ArgumentParser +from pathlib import Path +import string + +import numpy as np +import pytest + +from espnet.nets.beam_search import Hypothesis +from espnet2.bin.st_inference import get_parser +from espnet2.bin.st_inference import main +from espnet2.bin.st_inference import Speech2Text +from espnet2.tasks.st import STTask + + +def test_get_parser(): + assert isinstance(get_parser(), ArgumentParser) + + +def test_main(): + with pytest.raises(SystemExit): + main() + + +@pytest.fixture() +def token_list(tmp_path: Path): + with (tmp_path / "tokens.txt").open("w") as f: + f.write("\n") + for c in string.ascii_letters: + f.write(f"{c}\n") + f.write("\n") + f.write("\n") + return tmp_path / "tokens.txt" + + +@pytest.fixture() +def src_token_list(tmp_path: Path): + with (tmp_path / "src_tokens.txt").open("w") as f: + f.write("\n") + for c in string.ascii_letters: + f.write(f"{c}\n") + f.write("\n") + f.write("\n") + return tmp_path / "src_tokens.txt" + + +@pytest.fixture() +def st_config_file(tmp_path: Path, token_list, src_token_list): + # Write default configuration file + STTask.main( + cmd=[ + "--dry_run", + "true", + "--output_dir", + str(tmp_path / "st"), + "--token_list", + str(token_list), + "--src_token_list", + str(src_token_list), + "--token_type", + "char", + ] + ) + return tmp_path / "st" / "config.yaml" + + +@pytest.mark.execution_timeout(5) +def test_Speech2Text(st_config_file): + speech2text = Speech2Text( + st_train_config=st_config_file, beam_size=1 + ) + speech = np.random.randn(1000) + results = speech2text(speech) + for text, token, token_int, hyp in results: + assert isinstance(text, str) + assert isinstance(token[0], str) + assert isinstance(token_int[0], int) + assert isinstance(hyp, Hypothesis) + + diff --git a/test/espnet2/bin/test_st_train.py b/test/espnet2/bin/test_st_train.py new file mode 100644 index 00000000000..5be899f0a38 --- /dev/null +++ b/test/espnet2/bin/test_st_train.py @@ -0,0 +1,15 @@ +from argparse import ArgumentParser + +import pytest + +from espnet2.bin.st_train import get_parser +from espnet2.bin.st_train import main + + +def test_get_parser(): + assert isinstance(get_parser(), ArgumentParser) + + +def test_main(): + with pytest.raises(SystemExit): + main() From 72b6b21d509a26d30a454525811c3530ee6b297b Mon Sep 17 00:00:00 2001 From: ftshijt <728307998@qq.com> Date: Wed, 27 Apr 2022 01:27:09 -0400 Subject: [PATCH 3/4] add st unit test --- espnet2/st/espnet_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/espnet2/st/espnet_model.py b/espnet2/st/espnet_model.py index e298ef1822d..ee744681bd7 100644 --- a/espnet2/st/espnet_model.py +++ b/espnet2/st/espnet_model.py @@ -53,9 +53,9 @@ def __init__( decoder: AbsDecoder, extra_asr_decoder: Optional[AbsDecoder], extra_mt_decoder: Optional[AbsDecoder], - ctc: CTC, - src_vocab_size: int = 0, - src_token_list: Union[Tuple[str, ...], List[str]] = [], + ctc: Optional[CTC], + src_vocab_size: Optional[int], + src_token_list: Optional[Union[Tuple[str, ...], List[str]]], asr_weight: float = 0.0, mt_weight: float = 0.0, mtlalpha: float = 0.0, From c4b93e8fd870954ec2649abc3fc6172d78d92166 Mon Sep 17 00:00:00 2001 From: ftshijt <728307998@qq.com> Date: Wed, 27 Apr 2022 01:49:00 -0400 Subject: [PATCH 4/4] apply black --- test/espnet2/bin/test_st_inference.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/espnet2/bin/test_st_inference.py b/test/espnet2/bin/test_st_inference.py index 7abc3c6ef7d..3910479456b 100644 --- a/test/espnet2/bin/test_st_inference.py +++ b/test/espnet2/bin/test_st_inference.py @@ -65,9 +65,7 @@ def st_config_file(tmp_path: Path, token_list, src_token_list): @pytest.mark.execution_timeout(5) def test_Speech2Text(st_config_file): - speech2text = Speech2Text( - st_train_config=st_config_file, beam_size=1 - ) + speech2text = Speech2Text(st_train_config=st_config_file, beam_size=1) speech = np.random.randn(1000) results = speech2text(speech) for text, token, token_int, hyp in results: @@ -75,5 +73,3 @@ def test_Speech2Text(st_config_file): assert isinstance(token[0], str) assert isinstance(token_int[0], int) assert isinstance(hyp, Hypothesis) - -