Skip to content

Commit

Permalink
Update asr_inference_streaming.py
Browse files Browse the repository at this point in the history
Bugfix in streaming inference espnet#4216
  • Loading branch information
espnetUser authored May 10, 2022
1 parent 61b5013 commit 3aafdb9
Showing 1 changed file with 75 additions and 21 deletions.
96 changes: 75 additions & 21 deletions espnet2/bin/asr_inference_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from espnet2.utils.types import str2triple_str
from espnet2.utils.types import str_or_none
import logging
import math
import numpy as np
from pathlib import Path
import sys
Expand Down Expand Up @@ -181,6 +182,21 @@ def __init__(
self.device = device
self.dtype = dtype
self.nbest = nbest
if "n_fft" in asr_train_args.frontend_conf:
self.n_fft = asr_train_args.frontend_conf["n_fft"]
else:
self.n_fft = 512
if "hop_length" in asr_train_args.frontend_conf:
self.hop_length = asr_train_args.frontend_conf["hop_length"]
else:
self.hop_length = 128
if (
"win_length" in asr_train_args.frontend_conf
and asr_train_args.frontend_conf["win_length"] is not None
):
self.win_length = asr_train_args.frontend_conf["win_length"]
else:
self.win_length = self.n_fft

self.reset()

Expand All @@ -196,15 +212,34 @@ def apply_frontend(
buf = prev_states["waveform_buffer"]
speech = torch.cat([buf, speech], dim=0)

has_enough_samples = False if speech.size(0) <= self.win_length else True
if not has_enough_samples:
if is_final:
pad = torch.zeros(self.win_length - speech.size(0))
speech = torch.cat([speech, pad], dim=0)
else:
feats = None
feats_lengths = None
next_states = {"waveform_buffer": speech.clone()}
return feats, feats_lengths, next_states

if is_final:
speech_to_process = speech
waveform_buffer = None
else:
n_frames = (speech.size(0) - 384) // 128
n_residual = (speech.size(0) - 384) % 128
speech_to_process = speech.narrow(0, 0, 384 + n_frames * 128)
n_frames = (
speech.size(0) - (self.win_length - self.hop_length)
) // self.hop_length
n_residual = (
speech.size(0) - (self.win_length - self.hop_length)
) % self.hop_length
speech_to_process = speech.narrow(
0, 0, (self.win_length - self.hop_length) + n_frames * self.hop_length
)
waveform_buffer = speech.narrow(
0, speech.size(0) - 384 - n_residual, 384 + n_residual
0,
speech.size(0) - (self.win_length - self.hop_length) - n_residual,
(self.win_length - self.hop_length) + n_residual,
).clone()

# data: (Nsamples,) -> (1, Nsamples)
Expand All @@ -229,12 +264,27 @@ def apply_frontend(
if prev_states is None:
pass
else:
feats = feats.narrow(1, 2, feats.size(1) - 2)
feats = feats.narrow(
1,
math.ceil(math.ceil(self.win_length / self.hop_length) / 2),
feats.size(1)
- math.ceil(math.ceil(self.win_length / self.hop_length) / 2),
)
else:
if prev_states is None:
feats = feats.narrow(1, 0, feats.size(1) - 2)
feats = feats.narrow(
1,
0,
feats.size(1)
- math.ceil(math.ceil(self.win_length / self.hop_length) / 2),
)
else:
feats = feats.narrow(1, 2, feats.size(1) - 4)
feats = feats.narrow(
1,
math.ceil(math.ceil(self.win_length / self.hop_length) / 2),
feats.size(1)
- 2 * math.ceil(math.ceil(self.win_length / self.hop_length) / 2),
)

feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))

Expand Down Expand Up @@ -265,21 +315,25 @@ def __call__(
feats, feats_lengths, self.frontend_states = self.apply_frontend(
speech, self.frontend_states, is_final=is_final
)
enc, _, self.encoder_states = self.asr_model.encoder(
feats,
feats_lengths,
self.encoder_states,
is_final=is_final,
infer_mode=True,
)
nbest_hyps = self.beam_search(
x=enc[0],
maxlenratio=self.maxlenratio,
minlenratio=self.minlenratio,
is_final=is_final,
)

ret = self.assemble_hyps(nbest_hyps)
if feats is not None:
enc, _, self.encoder_states = self.asr_model.encoder(
feats,
feats_lengths,
self.encoder_states,
is_final=is_final,
infer_mode=True,
)
nbest_hyps = self.beam_search(
x=enc[0],
maxlenratio=self.maxlenratio,
minlenratio=self.minlenratio,
is_final=is_final,
)
ret = self.assemble_hyps(nbest_hyps)
else:
ret = []

if is_final:
self.reset()
return ret
Expand Down

0 comments on commit 3aafdb9

Please sign in to comment.