Skip to content

Commit

Permalink
Cherry pick commits in #6601 to main (#6611)
Browse files Browse the repository at this point in the history
* fix write

Signed-off-by: fayejf <fayejf07@gmail.com>

* decoding ctc

Signed-off-by: fayejf <fayejf07@gmail.com>

* temp set rnnt decoding return_best_hypothesis to true

Signed-off-by: fayejf <fayejf07@gmail.com>

* add wer cal back to transcribe_speech as requested

Signed-off-by: fayejf <fayejf07@gmail.com>

* add wer cal back to speech_to_text_buffered_infer_rnnt  as requested

Signed-off-by: fayejf <fayejf07@gmail.com>

* add wer cal back to speech_to_text_buffered_infer_ctc as requested

Signed-off-by: fayejf <fayejf07@gmail.com>

* style fix

Signed-off-by: fayejf <fayejf07@gmail.com>

* reflect change in asr_evaluator

Signed-off-by: fayejf <fayejf07@gmail.com>

* reflect som and vahid comment

Signed-off-by: fayejf <fayejf07@gmail.com>

* remove return_best_hy=true in transcribe_speech

Signed-off-by: fayejf <fayejf07@gmail.com>

* no text skip

Signed-off-by: fayejf <fayejf07@gmail.com>

* revert partial

Signed-off-by: fayejf <fayejf07@gmail.com>

---------

Signed-off-by: fayejf <fayejf07@gmail.com>
  • Loading branch information
fayejf authored May 10, 2023
1 parent fa89ba5 commit f7989f7
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
total_buffer_in_secs=4.0 \
chunk_len_in_secs=1.6 \
model_stride=4 \
batch_size=32
batch_size=32 \
clean_groundtruth_text=True \
langid='en'
# NOTE:
You can use `DEBUG=1 python speech_to_text_buffered_infer_ctc.py ...` to print out the
Expand All @@ -45,6 +47,8 @@
import torch
from omegaconf import OmegaConf

from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR
from nemo.collections.asr.parts.utils.transcribe_utils import (
compute_output_filename,
Expand Down Expand Up @@ -79,6 +83,9 @@ class TranscriptionConfig:
total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models",

# Decoding strategy for CTC models
decoding: CTCDecodingConfig = CTCDecodingConfig()

# Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
# device anyway, and do inference on CPU only if CUDA device is not found.
# If `cuda` is a negative number, inference will be on CPU only.
Expand All @@ -89,6 +96,12 @@ class TranscriptionConfig:
# Recompute model transcription, even if the output folder exists with scores.
overwrite_transcripts: bool = True

# Config for word / character error rate calculation
calculate_wer: bool = True
clean_groundtruth_text: bool = False
langid: str = "en" # specify this for convert_num_to_words step in groundtruth cleaning
use_cer: bool = False


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
Expand Down Expand Up @@ -188,11 +201,24 @@ def autocast():
manifest,
filepaths,
)
output_filename = write_transcription(
output_filename, pred_text_attr_name = write_transcription(
hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False
)
logging.info(f"Finished writing predictions to {output_filename}!")

if cfg.calculate_wer:
output_manifest_w_wer, total_res, _ = cal_write_wer(
pred_manifest=output_filename,
pred_text_attr_name=pred_text_attr_name,
clean_groundtruth_text=cfg.clean_groundtruth_text,
langid=cfg.langid,
use_cer=cfg.use_cer,
output_filename=None,
)
if output_manifest_w_wer:
logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!")
logging.info(f"{total_res}")

return cfg


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
total_buffer_in_secs=4.0 \
chunk_len_in_secs=1.6 \
model_stride=4 \
batch_size=32
batch_size=32 \
clean_groundtruth_text=True \
langid='en'
# Longer Common Subsequence (LCS) Merge algorithm
Expand Down Expand Up @@ -66,6 +68,7 @@
import torch
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.streaming_utils import (
BatchedFrameASRRNNT,
LongestCommonSubsequenceBatchedFrameASRRNNT,
Expand Down Expand Up @@ -101,7 +104,7 @@ class TranscriptionConfig:
# Chunked configs
chunk_len_in_secs: float = 1.6 # Chunk length in seconds
total_buffer_in_secs: float = 4.0 # Length of buffer (chunk + left and right padding) in seconds
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models",
model_stride: int = 8 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models

# Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
# device anyway, and do inference on CPU only if CUDA device is not found.
Expand All @@ -120,6 +123,12 @@ class TranscriptionConfig:
merge_algo: Optional[str] = 'middle' # choices=['middle', 'lcs'], choice of algorithm to apply during inference.
lcs_alignment_dir: Optional[str] = None # Path to a directory to store LCS algo alignments

# Config for word / character error rate calculation
calculate_wer: bool = True
clean_groundtruth_text: bool = False
langid: str = "en" # specify this for convert_num_to_words step in groundtruth cleaning
use_cer: bool = False


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
Expand Down Expand Up @@ -194,9 +203,13 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
decoding_cfg.strategy = "greedy_batch"
decoding_cfg.preserve_alignments = True # required to compute the middle token for transducers.
decoding_cfg.fused_batch_size = -1 # temporarily stop fused batch during inference.
decoding_cfg.beam.return_best_hypothesis = True

asr_model.change_decoding_strategy(decoding_cfg)

with open_dict(cfg):
cfg.decoding = decoding_cfg

feature_stride = model_cfg.preprocessor['window_stride']
model_stride_in_secs = feature_stride * cfg.model_stride
total_buffer = cfg.total_buffer_in_secs
Expand Down Expand Up @@ -242,11 +255,24 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
filepaths=filepaths,
)

output_filename = write_transcription(
output_filename, pred_text_attr_name = write_transcription(
hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False
)
logging.info(f"Finished writing predictions to {output_filename}!")

if cfg.calculate_wer:
output_manifest_w_wer, total_res, _ = cal_write_wer(
pred_manifest=output_filename,
pred_text_attr_name=pred_text_attr_name,
clean_groundtruth_text=cfg.clean_groundtruth_text,
langid=cfg.langid,
use_cer=cfg.use_cer,
output_filename=None,
)
if output_manifest_w_wer:
logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!")
logging.info(f"{total_res}")

return cfg


Expand Down
29 changes: 28 additions & 1 deletion examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel
from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig
from nemo.collections.asr.parts.utils.eval_utils import cal_write_wer
from nemo.collections.asr.parts.utils.transcribe_utils import (
compute_output_filename,
prepare_audio_data,
Expand Down Expand Up @@ -69,6 +70,11 @@
ctc_decoding: Decoding sub-config for CTC. Refer to documentation for specific values.
rnnt_decoding: Decoding sub-config for RNNT. Refer to documentation for specific values.
calculate_wer: Bool to decide whether to calculate wer/cer at end of this script
clean_groundtruth_text: Bool to clean groundtruth text
langid: Str used for convert_num_to_words during groundtruth cleaning
use_cer: Bool to use Character Error Rate (CER) or Word Error Rate (WER)
# Usage
ASR model can be specified by either "model_path" or "pretrained_name".
Data for transcription can be defined with either "audio_dir" or "dataset_manifest".
Expand All @@ -82,6 +88,8 @@
audio_dir="<remove or path to folder of audio files>" \
dataset_manifest="<remove or path to manifest>" \
output_filename="<remove or specify output filename>" \
clean_groundtruth_text=True \
langid='en' \
batch_size=32 \
compute_timestamps=False \
compute_langs=False \
Expand Down Expand Up @@ -149,6 +157,12 @@ class TranscriptionConfig:
# Use this for model-specific changes before transcription
model_change: ModelChangeConfig = ModelChangeConfig()

# Config for word / character error rate calculation
calculate_wer: bool = True
clean_groundtruth_text: bool = False
langid: str = "en" # specify this for convert_num_to_words step in groundtruth cleaning
use_cer: bool = False


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
Expand Down Expand Up @@ -322,7 +336,7 @@ def autocast():
transcriptions = transcriptions[0]

# write audio transcriptions
output_filename = write_transcription(
output_filename, pred_text_attr_name = write_transcription(
transcriptions,
cfg,
model_name,
Expand All @@ -332,6 +346,19 @@ def autocast():
)
logging.info(f"Finished writing predictions to {output_filename}!")

if cfg.calculate_wer:
output_manifest_w_wer, total_res, _ = cal_write_wer(
pred_manifest=output_filename,
pred_text_attr_name=pred_text_attr_name,
clean_groundtruth_text=cfg.clean_groundtruth_text,
langid=cfg.langid,
use_cer=cfg.use_cer,
output_filename=None,
)
if output_manifest_w_wer:
logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!")
logging.info(f"{total_res}")

return cfg


Expand Down
153 changes: 153 additions & 0 deletions nemo/collections/asr/parts/utils/eval_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import Tuple

from nemo.collections.asr.metrics.wer import word_error_rate_detail
from nemo.utils import logging


def clean_label(_str: str, num_to_words: bool = True, langid="en") -> str:
"""
Remove unauthorized characters in a string, lower it and remove unneeded spaces
"""
replace_with_space = [char for char in '/?*\",.:=?_{|}~¨«·»¡¿„…‧‹›≪≫!:;ː→']
replace_with_blank = [char for char in '`¨´‘’“”`ʻ‘’“"‘”']
replace_with_apos = [char for char in '‘’ʻ‘’‘']
_str = _str.strip()
_str = _str.lower()
for i in replace_with_blank:
_str = _str.replace(i, "")
for i in replace_with_space:
_str = _str.replace(i, " ")
for i in replace_with_apos:
_str = _str.replace(i, "'")
if num_to_words:
if langid == "en":
_str = convert_num_to_words(_str, langid="en")
else:
logging.info(
"Currently support basic num_to_words in English only. Please use Text Normalization to convert other languages! Skipping!"
)

ret = " ".join(_str.split())
return ret


def convert_num_to_words(_str: str, langid: str = "en") -> str:
"""
Convert digits to corresponding words. Note this is a naive approach and could be replaced with text normalization.
"""
if langid == "en":
num_to_words = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
_str = _str.strip()
words = _str.split()
out_str = ""
num_word = []
for word in words:
if word.isdigit():
num = int(word)
while num:
digit = num % 10
digit_word = num_to_words[digit]
num_word.append(digit_word)
num = int(num / 10)
if not (num):
num_str = ""
num_word = num_word[::-1]
for ele in num_word:
num_str += ele + " "
out_str += num_str + " "
num_word.clear()
else:
out_str += word + " "
out_str = out_str.strip()
else:
raise ValueError(
"Currently support basic num_to_words in English only. Please use Text Normalization to convert other languages!"
)
return out_str


def cal_write_wer(
pred_manifest: str = None,
pred_text_attr_name: str = "pred_text",
clean_groundtruth_text: bool = False,
langid: str = 'en',
use_cer: bool = False,
output_filename: str = None,
) -> Tuple[str, dict, str]:
"""
Calculate wer, inserion, deletion and substitution rate based on groundtruth text and pred_text_attr_name (pred_text)
We use WER in function name as a convention, but Error Rate (ER) currently support Word Error Rate (WER) and Character Error Rate (CER)
"""
samples = []
hyps = []
refs = []
eval_metric = "cer" if use_cer else "wer"

with open(pred_manifest, 'r') as fp:
for line in fp:
sample = json.loads(line)

if 'text' not in sample:
logging.info(
"ground-truth text is not present in manifest! Cannot calculate Word Error Rate. Returning!"
)
return None, None, eval_metric

hyp = sample[pred_text_attr_name]
ref = sample['text']

if clean_groundtruth_text:
ref = clean_label(ref, langid=langid)

wer, tokens, ins_rate, del_rate, sub_rate = word_error_rate_detail(
hypotheses=[hyp], references=[ref], use_cer=use_cer
)
sample[eval_metric] = wer # evaluatin metric, could be word error rate of character error rate
sample['tokens'] = tokens # number of word/characters/tokens
sample['ins_rate'] = ins_rate # insertion error rate
sample['del_rate'] = del_rate # deletion error rate
sample['sub_rate'] = sub_rate # substitution error rate

samples.append(sample)
hyps.append(hyp)
refs.append(ref)

total_wer, total_tokens, total_ins_rate, total_del_rate, total_sub_rate = word_error_rate_detail(
hypotheses=hyps, references=refs, use_cer=use_cer
)

if not output_filename:
output_manifest_w_wer = pred_manifest
else:
output_manifest_w_wer = output_filename

with open(output_manifest_w_wer, 'w') as fout:
for sample in samples:
json.dump(sample, fout)
fout.write('\n')
fout.flush()

total_res = {
"samples": len(samples),
"tokens": total_tokens,
eval_metric: total_wer,
"ins_rate": total_ins_rate,
"del_rate": total_del_rate,
"sub_rate": total_sub_rate,
}
return output_manifest_w_wer, total_res, eval_metric
Loading

0 comments on commit f7989f7

Please sign in to comment.