Skip to content

Commit

Permalink
Enable using hybrid asr models in CTC Segmentation tool (#8828)
Browse files Browse the repository at this point in the history
* enable using hybrid asr models in ctc segmentation tool

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
erastorgueva-nv and pre-commit-ci[bot] authored Apr 20, 2024
1 parent c9c8408 commit 206e84a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 14 deletions.
16 changes: 15 additions & 1 deletion tools/ctc_segmentation/scripts/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from tqdm import tqdm

from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel
from nemo.utils import model_utils

try:
Expand Down Expand Up @@ -354,7 +356,19 @@ def _split(sentences, delimiter):
asr_model = ASRModel.from_pretrained(model_name=args.model) # type: ASRModel
model_name = args.model

vocabulary = asr_model.cfg.decoder.vocabulary
if not (isinstance(asr_model, EncDecCTCModel) or isinstance(asr_model, EncDecHybridRNNTCTCModel)):
raise NotImplementedError(
f"Model is not an instance of NeMo EncDecCTCModel or ENCDecHybridRNNTCTCModel."
" Currently only instances of these models are supported"
)

# get vocabulary list
if hasattr(asr_model, 'tokenizer'): # i.e. tokenization is BPE-based
vocabulary = asr_model.tokenizer.vocab
elif hasattr(asr_model.decoder, "vocabulary"): # i.e. tokenization is character-based
vocabulary = asr_model.cfg.decoder.vocabulary
else:
raise ValueError("Unexpected model type. Vocabulary list not found.")

if os.path.isdir(args.in_text):
text_files = glob(f"{args.in_text}/*.txt")
Expand Down
44 changes: 31 additions & 13 deletions tools/ctc_segmentation/scripts/run_ctc_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from utils import get_segments

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel

parser = argparse.ArgumentParser(description="CTC Segmentation")
parser.add_argument("--output_dir", default="output", type=str, help="Path to output directory")
Expand Down Expand Up @@ -72,27 +74,38 @@
logging.basicConfig(handlers=handlers, level=level)

if os.path.exists(args.model):
asr_model = nemo_asr.models.EncDecCTCModel.restore_from(args.model)
elif args.model in nemo_asr.models.EncDecCTCModel.get_available_model_names():
asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(args.model, strict=False)
asr_model = nemo_asr.models.ASRModel.restore_from(args.model)
else:
try:
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(args.model)
except:
raise ValueError(
f"Provide path to the pretrained checkpoint or choose from {nemo_asr.models.EncDecCTCModel.get_available_model_names()}"
)
asr_model = nemo_asr.models.ASRModel.from_pretrained(args.model, strict=False)

if not (isinstance(asr_model, EncDecCTCModel) or isinstance(asr_model, EncDecHybridRNNTCTCModel)):
raise NotImplementedError(
f"Model is not an instance of NeMo EncDecCTCModel or ENCDecHybridRNNTCTCModel."
" Currently only instances of these models are supported"
)

bpe_model = isinstance(asr_model, nemo_asr.models.EncDecCTCModelBPE)
bpe_model = isinstance(asr_model, nemo_asr.models.EncDecCTCModelBPE) or isinstance(
asr_model, nemo_asr.models.EncDecHybridRNNTCTCBPEModel
)

# get tokenizer used during training, None for char based models
if bpe_model:
tokenizer = asr_model.tokenizer
else:
tokenizer = None

if isinstance(asr_model, EncDecHybridRNNTCTCModel):
asr_model.change_decoding_strategy(decoder_type="ctc")

# extract ASR vocabulary and add blank symbol
vocabulary = ["ε"] + list(asr_model.cfg.decoder.vocabulary)
if hasattr(asr_model, 'tokenizer'): # i.e. tokenization is BPE-based
vocabulary = asr_model.tokenizer.vocab
elif hasattr(asr_model.decoder, "vocabulary"): # i.e. tokenization is character-based
vocabulary = asr_model.cfg.decoder.vocabulary
else:
raise ValueError("Unexpected model type. Vocabulary list not found.")

vocabulary = ["ε"] + list(vocabulary)
logging.debug(f"ASR Model vocabulary: {vocabulary}")

data = Path(args.data)
Expand Down Expand Up @@ -136,9 +149,14 @@
logging.debug(f"len(signal): {len(signal)}, sr: {sample_rate}")
logging.debug(f"Duration: {original_duration}s, file_name: {path_audio}")

log_probs = asr_model.transcribe(audio=[str(path_audio)], batch_size=1, return_hypotheses=True)[
hypotheses = asr_model.transcribe([str(path_audio)], batch_size=1, return_hypotheses=True)
# if hypotheses form a tuple (from Hybrid model), extract just "best" hypothesis
if type(hypotheses) == tuple and len(hypotheses) == 2:
hypotheses = hypotheses[0]
log_probs = hypotheses[
0
].alignments
].alignments # note: "[0]" is for batch dimension unpacking (and here batch size=1)

# move blank values to the first column (ctc-package compatibility)
blank_col = log_probs[:, -1].reshape((log_probs.shape[0], 1))
log_probs = np.concatenate((blank_col, log_probs[:, :-1]), axis=1)
Expand Down

0 comments on commit 206e84a

Please sign in to comment.