Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] add buffered chunked streaming for nemo force aligner #6185

Merged
merged 4 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions nemo/collections/asr/parts/utils/streaming_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,7 @@ def __init__(
)

self.asr_model = asr_model
self.decoder = asr_model.decoder

self.batch_size = batch_size
self.all_logits = []
Expand All @@ -716,6 +717,7 @@ def __init__(
self.frame_buffers = []
self.reset()
cfg = copy.deepcopy(asr_model._cfg)
self.cfg = cfg
self.frame_len = frame_len
OmegaConf.set_struct(cfg.preprocessor, False)

Expand All @@ -725,6 +727,7 @@ def __init__(
cfg.preprocessor.normalize = "None"
self.raw_preprocessor = EncDecCTCModelBPE.from_config_dict(cfg.preprocessor)
self.raw_preprocessor.to(asr_model.device)
self.preprocessor = self.raw_preprocessor

def reset(self):
"""
Expand All @@ -750,17 +753,17 @@ def set_frame_reader(self, frame_reader):
self.frame_bufferer.set_frame_reader(frame_reader)

@torch.no_grad()
def infer_logits(self):
def infer_logits(self, keep_logits=False):
frame_buffers = self.frame_bufferer.get_buffers_batch()

while len(frame_buffers) > 0:
self.frame_buffers += frame_buffers[:]
self.data_layer.set_signal(frame_buffers[:])
self._get_batch_preds()
self._get_batch_preds(keep_logits)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to avoid changing signatures of these functions ? Ie set a bool value from config or some other way (class arg or setter function) and rest of the functions just use that ?

frame_buffers = self.frame_bufferer.get_buffers_batch()

@torch.no_grad()
def _get_batch_preds(self):
def _get_batch_preds(self, keep_logits=False):
device = self.asr_model.device
for batch in iter(self.data_loader):

Expand All @@ -772,19 +775,32 @@ def _get_batch_preds(self):
preds = torch.unbind(predictions)
for pred in preds:
self.all_preds.append(pred.cpu().numpy())
del log_probs
if keep_logits:
log_probs = torch.unbind(log_probs)
for log_prob in log_probs:
self.all_logits.append(log_prob.cpu())
else:
del log_probs
del encoded_len
del predictions

def transcribe(
self, tokens_per_chunk: int, delay: int,
):
self.infer_logits()
def transcribe(self, tokens_per_chunk: int, delay: int, keep_logits=False):
self.infer_logits(keep_logits)
self.unmerged = []
for pred in self.all_preds:
decoded = pred.tolist()
self.unmerged += decoded[len(decoded) - 1 - delay : len(decoded) - 1 - delay + tokens_per_chunk]
return self.greedy_merge(self.unmerged)
hypothesis = self.greedy_merge(self.unmerged)
if not keep_logits:
return hypothesis

all_logits = []
for log_prob in self.all_logits:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you put this in a function? I feel like it gets repeated a lot

T = log_prob.shape[0]
log_prob = log_prob[T - 1 - delay : T - 1 - delay + tokens_per_chunk, :]
all_logits.append(log_prob)
all_logits = torch.concat(all_logits, 0)
return hypothesis, all_logits

def greedy_merge(self, preds):
decoded_prediction = []
Expand Down
6 changes: 6 additions & 0 deletions tools/nemo_forced_aligner/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ Call the `align.py` script, specifying the parameters as follows:

* **[OPTIONAL]** `minimum_timestamp_duration`: a float indicating a minimum duration (in seconds) for timestamps in the CTM. If any line in the CTM has a duration lower than the `minimum_timestamp_duration`, it will be enlarged from the middle outwards until it meets the minimum_timestamp_duration, or reaches the beginning or end of the audio file. Note that this may cause timestamps to overlap. (Default: 0, i.e. no modifications to predicted duration).

* **[OPTIONAL]** `use_buffered_chunked_streaming`: a flag to indicate whether to do buffered chunk streaming. Notice only CTC models (e.g., stt_en_citrinet_1024_gamma_0_25)with `per_feature` preprocessor are supported. The below two params are needed if this option set to `True`.

* **[OPTIONAL]** `chunk_len_in_secs`: the chunk size for buffered chunked streaming inference. Default is 1.6 seconds.

* **[OPTIONAL]** `total_buffer_in_secs`: the buffer size for buffered chunked streaming inference. Default is 4.0 seconds.

# Input manifest file format
By default, NFA needs to be provided with a 'manifest' file where each line specifies the absolute "audio_filepath" and "text" of each utterance that you wish to produce alignments for, like the format below:
```json
Expand Down
62 changes: 60 additions & 2 deletions tools/nemo_forced_aligner/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import math
import os
from dataclasses import dataclass, is_dataclass
from typing import Optional
Expand All @@ -29,11 +31,11 @@
from utils.viterbi_decoding import viterbi_decoding

from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
from nemo.collections.asr.parts.utils.transcribe_utils import setup_model
from nemo.core.config import hydra_runner
from nemo.utils import logging


"""
Align the utterances in manifest_filepath.
Results are saved in ctm files in output_dir.
Expand Down Expand Up @@ -82,6 +84,16 @@
line in the CTM has a duration lower than the `minimum_timestamp_duration`, it will be enlarged from the
middle outwards until it meets the minimum_timestamp_duration, or reaches the beginning or end of the audio
file. Note that this may cause timestamps to overlap.
use_buffered_infer: False, if set True, using streaming to do get the logits for alignment
This flag is useful when aligning large audio file.
However, currently the chunk streaming inference does not support batch inference,
which means even you set batch_size > 1, it will only infer one by one instead of doing
the whole batch inference together.
chunk_len_in_secs: float chunk length in seconds
total_buffer_in_secs: float Length of buffer (chunk + left and right padding) in seconds
chunk_batch_size: int batch size for buffered chunk inference,
which will cut one audio into segments and do inference on chunk_batch_size segments at a time
"""


Expand All @@ -104,6 +116,12 @@ class AlignmentConfig:
minimum_timestamp_duration: float = 0
audio_filepath_parts_in_utt_id: int = 1

# Buffered chunked streaming configs
use_buffered_chunked_streaming: bool = False
chunk_len_in_secs: float = 1.6
total_buffer_in_secs: float = 4.0
chunk_batch_size: int = 32


@hydra_runner(config_name="AlignmentConfig", schema=AlignmentConfig)
def main(cfg: AlignmentConfig):
Expand Down Expand Up @@ -194,6 +212,41 @@ def main(cfg: AlignmentConfig):
"This may cause the alignments for some tokens/words/additional segments to be overlapping."
)

buffered_chunk_params = {}
if cfg.use_buffered_chunked_streaming:
model_cfg = copy.deepcopy(model._cfg)

OmegaConf.set_struct(model_cfg.preprocessor, False)
# some changes for streaming scenario
model_cfg.preprocessor.dither = 0.0
model_cfg.preprocessor.pad_to = 0

if model_cfg.preprocessor.normalize != "per_feature":
logging.error(
"Only EncDecCTCModelBPE models trained with per_feature normalization are supported currently"
)
# Disable config overwriting
OmegaConf.set_struct(model_cfg.preprocessor, True)

feature_stride = model_cfg.preprocessor['window_stride']
model_stride_in_secs = feature_stride * cfg.model_downsample_factor
total_buffer = cfg.total_buffer_in_secs
chunk_len = float(cfg.chunk_len_in_secs)
tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")

model = FrameBatchASR(
asr_model=model,
frame_len=chunk_len,
total_buffer=cfg.total_buffer_in_secs,
batch_size=cfg.chunk_batch_size,
)
buffered_chunk_params = {
"delay": mid_delay,
"model_stride_in_secs": model_stride_in_secs,
"tokens_per_chunk": tokens_per_chunk,
}
# get start and end line IDs of batches
starts, ends = get_batch_starts_ends(cfg.manifest_filepath, cfg.batch_size)

Expand All @@ -217,7 +270,12 @@ def main(cfg: AlignmentConfig):
segment_info_batch,
pred_text_batch,
) = get_batch_tensors_and_boundary_info(
manifest_lines_batch, model, cfg.additional_ctm_grouping_separator, cfg.align_using_pred_text,
manifest_lines_batch,
model,
cfg.additional_ctm_grouping_separator,
cfg.align_using_pred_text,
cfg.use_buffered_chunked_streaming,
buffered_chunk_params,
)

if cfg.align_using_pred_text:
Expand Down
42 changes: 32 additions & 10 deletions tools/nemo_forced_aligner/utils/data_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import soundfile as sf
import torch
from tqdm.auto import tqdm
from utils.constants import BLANK_TOKEN, SPACE_TOKEN, V_NEGATIVE_NUM


Expand Down Expand Up @@ -140,8 +141,10 @@ def get_y_and_boundary_info_for_utt(text, model, separator):
segments = [seg.strip() for seg in segments]

if hasattr(model, 'tokenizer'):

BLANK_ID = len(model.decoder.vocabulary) # TODO: check
if hasattr(model, 'blank_id'):
BLANK_ID = model.blank_id
else:
BLANK_ID = len(model.decoder.vocabulary) # TODO: check

y_token_ids_with_blanks = [BLANK_ID]
token_info = [{"text": BLANK_TOKEN, "s_start": 0, "s_end": 0,}]
Expand Down Expand Up @@ -283,7 +286,14 @@ def get_y_and_boundary_info_for_utt(text, model, separator):
raise RuntimeError("Cannot get tokens of this model.")


def get_batch_tensors_and_boundary_info(manifest_lines_batch, model, separator, align_using_pred_text):
def get_batch_tensors_and_boundary_info(
manifest_lines_batch,
model,
separator,
align_using_pred_text,
use_buffered_chunked_streaming=False,
buffered_chunk_params={},
):
"""
Returns:
log_probs, y, T, U (y and U are s.t. every other token is a blank) - these are the tensors we will need
Expand All @@ -299,16 +309,28 @@ def get_batch_tensors_and_boundary_info(manifest_lines_batch, model, separator,
# and (optionally) the predicted ASR text from the hypotheses
audio_filepaths_batch = [line["audio_filepath"] for line in manifest_lines_batch]
B = len(audio_filepaths_batch)
with torch.no_grad():
hypotheses = model.transcribe(audio_filepaths_batch, return_hypotheses=True, batch_size=B)

log_probs_list_batch = []
T_list_batch = []
pred_text_batch = []
for hypothesis in hypotheses:
log_probs_list_batch.append(hypothesis.y_sequence)
T_list_batch.append(hypothesis.y_sequence.shape[0])
pred_text_batch.append(hypothesis.text)

if not use_buffered_chunked_streaming:
with torch.no_grad():
hypotheses = model.transcribe(audio_filepaths_batch, return_hypotheses=True, batch_size=B)
for hypothesis in hypotheses:
log_probs_list_batch.append(hypothesis.y_sequence)
T_list_batch.append(hypothesis.y_sequence.shape[0])
pred_text_batch.append(hypothesis.text)
else:
delay = buffered_chunk_params["delay"]
model_stride_in_secs = buffered_chunk_params["model_stride_in_secs"]
tokens_per_chunk = buffered_chunk_params["tokens_per_chunk"]
for l in tqdm(audio_filepaths_batch, desc="Sample:"):
model.reset()
model.read_audio_file(l, delay, model_stride_in_secs)
hyp, logits = model.transcribe(tokens_per_chunk, delay, keep_logits=True)
log_probs_list_batch.append(logits)
T_list_batch.append(logits.shape[0])
pred_text_batch.append(hyp)

# we loop over every line in the manifest that is in our current batch,
# and record the y (list of tokens, including blanks), U (list of lengths of y) and
Expand Down