Merge pull request espnet#4352 from espnetUser/master
Add unit test to streaming ASR inference
sw005320 authored May 11, 2022
2 parents 39bae01 + 52c238d commit 1b12410
Showing 2 changed files with 149 additions and 31 deletions.
espnet2/bin/asr_inference_streaming.py (96 changes: 75 additions & 21 deletions)
@@ -25,6 +25,7 @@
 from espnet2.utils.types import str2triple_str
 from espnet2.utils.types import str_or_none
 import logging
+import math
 import numpy as np
 from pathlib import Path
 import sys
@@ -181,6 +182,21 @@ def __init__(
         self.device = device
         self.dtype = dtype
         self.nbest = nbest
+        if "n_fft" in asr_train_args.frontend_conf:
+            self.n_fft = asr_train_args.frontend_conf["n_fft"]
+        else:
+            self.n_fft = 512
+        if "hop_length" in asr_train_args.frontend_conf:
+            self.hop_length = asr_train_args.frontend_conf["hop_length"]
+        else:
+            self.hop_length = 128
+        if (
+            "win_length" in asr_train_args.frontend_conf
+            and asr_train_args.frontend_conf["win_length"] is not None
+        ):
+            self.win_length = asr_train_args.frontend_conf["win_length"]
+        else:
+            self.win_length = self.n_fft
 
         self.reset()
 
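With this hunk, the streaming decoder picks up the STFT geometry from the training config instead of assuming fixed values: n_fft defaults to 512, hop_length to 128, and win_length falls back to n_fft when absent or None. A minimal standalone sketch of that fallback logic (not espnet code; the plain frontend_conf dict stands in for asr_train_args.frontend_conf):

# Minimal sketch of the fallback resolution in the hunk above;
# `frontend_conf` is an illustrative stand-in for asr_train_args.frontend_conf.
def resolve_stft_params(frontend_conf: dict):
    n_fft = frontend_conf.get("n_fft", 512)
    hop_length = frontend_conf.get("hop_length", 128)
    win_length = frontend_conf.get("win_length")
    if win_length is None:
        win_length = n_fft
    return n_fft, hop_length, win_length

print(resolve_stft_params({}))                                 # (512, 128, 512)
print(resolve_stft_params({"n_fft": 256, "hop_length": 128}))  # (256, 128, 256)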
@@ -196,15 +212,34 @@ def apply_frontend(
             buf = prev_states["waveform_buffer"]
             speech = torch.cat([buf, speech], dim=0)
 
+        has_enough_samples = False if speech.size(0) <= self.win_length else True
+        if not has_enough_samples:
+            if is_final:
+                pad = torch.zeros(self.win_length - speech.size(0), dtype=speech.dtype)
+                speech = torch.cat([speech, pad], dim=0)
+            else:
+                feats = None
+                feats_lengths = None
+                next_states = {"waveform_buffer": speech.clone()}
+                return feats, feats_lengths, next_states
+
         if is_final:
             speech_to_process = speech
             waveform_buffer = None
         else:
-            n_frames = (speech.size(0) - 384) // 128
-            n_residual = (speech.size(0) - 384) % 128
-            speech_to_process = speech.narrow(0, 0, 384 + n_frames * 128)
+            n_frames = (
+                speech.size(0) - (self.win_length - self.hop_length)
+            ) // self.hop_length
+            n_residual = (
+                speech.size(0) - (self.win_length - self.hop_length)
+            ) % self.hop_length
+            speech_to_process = speech.narrow(
+                0, 0, (self.win_length - self.hop_length) + n_frames * self.hop_length
+            )
             waveform_buffer = speech.narrow(
-                0, speech.size(0) - 384 - n_residual, 384 + n_residual
+                0,
+                speech.size(0) - (self.win_length - self.hop_length) - n_residual,
+                (self.win_length - self.hop_length) + n_residual,
             ).clone()
 
         # data: (Nsamples,) -> (1, Nsamples)
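This hunk replaces the hard-coded constants 384 and 128 (which assumed win_length=512 and hop_length=128, so 384 = 512 - 128) with values derived from the frontend config. For a non-final chunk, only whole hop-sized frames beyond the window/hop overlap are processed; the overlap plus any residual samples stay in the waveform buffer for the next call. A small worked example under the old default geometry:

# Worked example of the buffering arithmetic above, assuming the old
# default geometry: win_length=512, hop_length=128 (overlap = 384).
win_length, hop_length = 512, 128
overlap = win_length - hop_length  # samples carried over between chunks

for n_samples in (2048, 2500):
    n_frames = (n_samples - overlap) // hop_length
    n_residual = (n_samples - overlap) % hop_length
    processed = overlap + n_frames * hop_length
    buffered = overlap + n_residual
    print(n_samples, n_frames, processed, buffered)
# 2048 samples -> 13 frames, 2048 processed, 384 buffered
# 2500 samples -> 16 frames, 2432 processed, 452 buffered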
@@ -229,12 +264,27 @@ def apply_frontend(
             if prev_states is None:
                 pass
             else:
-                feats = feats.narrow(1, 2, feats.size(1) - 2)
+                feats = feats.narrow(
+                    1,
+                    math.ceil(math.ceil(self.win_length / self.hop_length) / 2),
+                    feats.size(1)
+                    - math.ceil(math.ceil(self.win_length / self.hop_length) / 2),
+                )
         else:
             if prev_states is None:
-                feats = feats.narrow(1, 0, feats.size(1) - 2)
+                feats = feats.narrow(
+                    1,
+                    0,
+                    feats.size(1)
+                    - math.ceil(math.ceil(self.win_length / self.hop_length) / 2),
+                )
             else:
-                feats = feats.narrow(1, 2, feats.size(1) - 4)
+                feats = feats.narrow(
+                    1,
+                    math.ceil(math.ceil(self.win_length / self.hop_length) / 2),
+                    feats.size(1)
+                    - 2 * math.ceil(math.ceil(self.win_length / self.hop_length) / 2),
+                )
 
         feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
 
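The trimming above generalizes the previously hard-coded width of 2 frames: the number of boundary frames to drop is now math.ceil(math.ceil(win_length / hop_length) / 2), which evaluates to 2 for the 512/128 defaults and to 1 for the n_fft=256, hop_length=128 configuration exercised by the new unit test. A quick check of both cases:

import math

# The generalized trim width used in the hunk above; previously hard-coded as 2.
def n_trim_frames(win_length: int, hop_length: int) -> int:
    return math.ceil(math.ceil(win_length / hop_length) / 2)

print(n_trim_frames(512, 128))  # 2, matches the old hard-coded value
print(n_trim_frames(256, 128))  # 1, the n_fft=256 edge case in the new test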
@@ -265,21 +315,25 @@ def __call__(
         feats, feats_lengths, self.frontend_states = self.apply_frontend(
             speech, self.frontend_states, is_final=is_final
         )
-        enc, _, self.encoder_states = self.asr_model.encoder(
-            feats,
-            feats_lengths,
-            self.encoder_states,
-            is_final=is_final,
-            infer_mode=True,
-        )
-        nbest_hyps = self.beam_search(
-            x=enc[0],
-            maxlenratio=self.maxlenratio,
-            minlenratio=self.minlenratio,
-            is_final=is_final,
-        )
 
-        ret = self.assemble_hyps(nbest_hyps)
+        if feats is not None:
+            enc, _, self.encoder_states = self.asr_model.encoder(
+                feats,
+                feats_lengths,
+                self.encoder_states,
+                is_final=is_final,
+                infer_mode=True,
+            )
+            nbest_hyps = self.beam_search(
+                x=enc[0],
+                maxlenratio=self.maxlenratio,
+                minlenratio=self.minlenratio,
+                is_final=is_final,
+            )
+            ret = self.assemble_hyps(nbest_hyps)
+        else:
+            ret = []
+
         if is_final:
             self.reset()
         return ret
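With the `if feats is not None` guard, a chunk too short to fill one analysis window is simply buffered and the call returns an empty result list, instead of passing feats=None into the encoder. A hypothetical caller-side sketch (assuming `speech2text` is an already constructed Speech2TextStreaming instance):

import numpy as np

# Hypothetical usage: a chunk shorter than win_length (512 samples by default)
# is buffered internally; no decoding happens until more audio arrives.
tiny_chunk = np.random.randn(100)
results = speech2text(tiny_chunk, is_final=False)
assert results == []  # nothing decoded yet; samples wait in the waveform buffer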
test/espnet2/bin/test_asr_inference.py (84 changes: 74 additions & 10 deletions)
@@ -4,11 +4,13 @@
 
 import numpy as np
 import pytest
+import yaml
 
 from espnet.nets.beam_search import Hypothesis
 from espnet2.bin.asr_inference import get_parser
 from espnet2.bin.asr_inference import main
 from espnet2.bin.asr_inference import Speech2Text
+from espnet2.bin.asr_inference_streaming import Speech2TextStreaming
 from espnet2.tasks.asr import ASRTask
 from espnet2.tasks.enh_s2t import EnhS2TTask
 from espnet2.tasks.lm import LMTask
@@ -99,26 +101,88 @@ def asr_config_file_streaming(tmp_path: Path, token_list):
             "char",
             "--decoder",
             "transformer",
+            "--encoder",
+            "contextual_block_transformer",
         ]
     )
     return tmp_path / "asr_streaming" / "config.yaml"
 
 
-@pytest.mark.execution_timeout(10)
+@pytest.mark.execution_timeout(20)
 def test_Speech2Text_streaming(asr_config_file_streaming, lm_config_file):
-    speech2text = Speech2Text(
+    file = open(asr_config_file_streaming, "r", encoding="utf-8")
+    asr_train_config = file.read()
+    asr_train_config = yaml.full_load(asr_train_config)
+    asr_train_config["frontend"] = "default"
+    asr_train_config["encoder_conf"] = {
+        "look_ahead": 16,
+        "hop_size": 16,
+        "block_size": 40,
+    }
+    # Change the configuration file
+    with open(asr_config_file_streaming, "w", encoding="utf-8") as files:
+        yaml.dump(asr_train_config, files)
+    speech2text = Speech2TextStreaming(
         asr_train_config=asr_config_file_streaming,
         lm_train_config=lm_config_file,
         beam_size=1,
-        streaming=True,
     )
-    speech = np.random.randn(100000)
-    results = speech2text(speech)
-    for text, token, token_int, hyp in results:
-        assert isinstance(text, str)
-        assert isinstance(token[0], str)
-        assert isinstance(token_int[0], int)
-        assert isinstance(hyp, Hypothesis)
+    speech = np.random.randn(10000)
+    for sim_chunk_length in [1, 32, 128, 512, 1024, 2048]:
+        if (len(speech) // sim_chunk_length) > 1:
+            for i in range(len(speech) // sim_chunk_length):
+                speech2text(
+                    speech=speech[i * sim_chunk_length : (i + 1) * sim_chunk_length],
+                    is_final=False,
+                )
+            results = speech2text(
+                speech[(i + 1) * sim_chunk_length : len(speech)], is_final=True
+            )
+        else:
+            results = speech2text(speech)
+        for text, token, token_int, hyp in results:
+            assert isinstance(text, str)
+            assert isinstance(token[0], str)
+            assert isinstance(token_int[0], int)
+            assert isinstance(hyp, Hypothesis)
+
+    # Test edge case: https://github.com/espnet/espnet/pull/4216
+    file = open(asr_config_file_streaming, "r", encoding="utf-8")
+    asr_train_config = file.read()
+    asr_train_config = yaml.full_load(asr_train_config)
+    asr_train_config["frontend"] = "default"
+    asr_train_config["frontend_conf"] = {
+        "n_fft": 256,
+        "win_length": 256,
+        "hop_length": 128,
+    }
+    # Change the configuration file
+    with open(asr_config_file_streaming, "w", encoding="utf-8") as files:
+        yaml.dump(asr_train_config, files)
+    speech2text = Speech2TextStreaming(
+        asr_train_config=asr_config_file_streaming,
+        lm_train_config=lm_config_file,
+        beam_size=1,
+    )
+    # edge case: speech is exactly multiple of sim_chunk_length, e.g., 10240 = 5 x 2048
+    speech = np.random.randn(10240)
+    for sim_chunk_length in [1, 32, 64, 128, 512, 1024, 2048]:
+        if (len(speech) // sim_chunk_length) > 1:
+            for i in range(len(speech) // sim_chunk_length):
+                speech2text(
+                    speech=speech[i * sim_chunk_length : (i + 1) * sim_chunk_length],
+                    is_final=False,
+                )
+            results = speech2text(
+                speech[(i + 1) * sim_chunk_length : len(speech)], is_final=True
+            )
+        else:
+            results = speech2text(speech)
+        for text, token, token_int, hyp in results:
+            assert isinstance(text, str)
+            assert isinstance(token[0], str)
+            assert isinstance(token_int[0], int)
+            assert isinstance(hyp, Hypothesis)
 
 
 @pytest.fixture()
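The second half of the test targets the edge case referenced above (espnet#4216): when the utterance length is an exact multiple of sim_chunk_length, the chunking loop consumes every sample, so the final is_final=True call receives an empty slice. A short arithmetic check of why 10240 samples with a 2048-sample chunk hits this path:

# 10240 = 5 x 2048, so the loop sends 5 full chunks and the closing
# is_final=True call gets speech[10240:10240], i.e., zero samples.
speech_len, sim_chunk_length = 10240, 2048
n_chunks = speech_len // sim_chunk_length        # 5
leftover = speech_len - n_chunks * sim_chunk_length
print(n_chunks, leftover)                        # 5 0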
