Skip to content

Commit

Permalink
Merge pull request espnet#4324 from ftshijt/master
Browse files Browse the repository at this point in the history
Add Test Functions for ST Train and Inference
  • Loading branch information
ftshijt authored Apr 27, 2022
2 parents 0ae3773 + c4b93e8 commit 44971ff
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 4 deletions.
6 changes: 3 additions & 3 deletions espnet2/st/espnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def __init__(
decoder: AbsDecoder,
extra_asr_decoder: Optional[AbsDecoder],
extra_mt_decoder: Optional[AbsDecoder],
ctc: CTC,
src_vocab_size: int = 0,
src_token_list: Union[Tuple[str, ...], List[str]] = [],
ctc: Optional[CTC],
src_vocab_size: Optional[int],
src_token_list: Optional[Union[Tuple[str, ...], List[str]]],
asr_weight: float = 0.0,
mt_weight: float = 0.0,
mtlalpha: float = 0.0,
Expand Down
2 changes: 1 addition & 1 deletion espnet2/tasks/st.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def add_task_arguments(cls, parser: argparse.ArgumentParser):
# NOTE(kamo): add_arguments(..., required=True) can't be used
# to provide --print_config mode. Instead of it, do as
required = parser.get_default("required")
required += ["src_token_list", "token_list"]
required += ["token_list"]

group.add_argument(
"--token_list",
Expand Down
75 changes: 75 additions & 0 deletions test/espnet2/bin/test_st_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from argparse import ArgumentParser
from pathlib import Path
import string

import numpy as np
import pytest

from espnet.nets.beam_search import Hypothesis
from espnet2.bin.st_inference import get_parser
from espnet2.bin.st_inference import main
from espnet2.bin.st_inference import Speech2Text
from espnet2.tasks.st import STTask


def test_get_parser():
assert isinstance(get_parser(), ArgumentParser)


def test_main():
with pytest.raises(SystemExit):
main()


@pytest.fixture()
def token_list(tmp_path: Path):
with (tmp_path / "tokens.txt").open("w") as f:
f.write("<blank>\n")
for c in string.ascii_letters:
f.write(f"{c}\n")
f.write("<unk>\n")
f.write("<sos/eos>\n")
return tmp_path / "tokens.txt"


@pytest.fixture()
def src_token_list(tmp_path: Path):
with (tmp_path / "src_tokens.txt").open("w") as f:
f.write("<blank>\n")
for c in string.ascii_letters:
f.write(f"{c}\n")
f.write("<unk>\n")
f.write("<sos/eos>\n")
return tmp_path / "src_tokens.txt"


@pytest.fixture()
def st_config_file(tmp_path: Path, token_list, src_token_list):
# Write default configuration file
STTask.main(
cmd=[
"--dry_run",
"true",
"--output_dir",
str(tmp_path / "st"),
"--token_list",
str(token_list),
"--src_token_list",
str(src_token_list),
"--token_type",
"char",
]
)
return tmp_path / "st" / "config.yaml"


@pytest.mark.execution_timeout(5)
def test_Speech2Text(st_config_file):
speech2text = Speech2Text(st_train_config=st_config_file, beam_size=1)
speech = np.random.randn(1000)
results = speech2text(speech)
for text, token, token_int, hyp in results:
assert isinstance(text, str)
assert isinstance(token[0], str)
assert isinstance(token_int[0], int)
assert isinstance(hyp, Hypothesis)
15 changes: 15 additions & 0 deletions test/espnet2/bin/test_st_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from argparse import ArgumentParser

import pytest

from espnet2.bin.st_train import get_parser
from espnet2.bin.st_train import main


def test_get_parser():
assert isinstance(get_parser(), ArgumentParser)


def test_main():
with pytest.raises(SystemExit):
main()

0 comments on commit 44971ff

Please sign in to comment.