diff --git a/tools/ctc_segmentation/scripts/prepare_data.py b/tools/ctc_segmentation/scripts/prepare_data.py
index c6ea024273fb..476e719eb51b 100644
--- a/tools/ctc_segmentation/scripts/prepare_data.py
+++ b/tools/ctc_segmentation/scripts/prepare_data.py
@@ -26,6 +26,8 @@
 from tqdm import tqdm
 
 from nemo.collections.asr.models import ASRModel
+from nemo.collections.asr.models.ctc_models import EncDecCTCModel
+from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel
 from nemo.utils import model_utils
 
 try:
@@ -354,7 +356,19 @@ def _split(sentences, delimiter):
         asr_model = ASRModel.from_pretrained(model_name=args.model)  # type: ASRModel
         model_name = args.model
 
-    vocabulary = asr_model.cfg.decoder.vocabulary
+    if not (isinstance(asr_model, EncDecCTCModel) or isinstance(asr_model, EncDecHybridRNNTCTCModel)):
+        raise NotImplementedError(
+            f"Model is not an instance of NeMo EncDecCTCModel or ENCDecHybridRNNTCTCModel."
+            " Currently only instances of these models are supported"
+        )
+
+    # get vocabulary list
+    if hasattr(asr_model, 'tokenizer'):  # i.e. tokenization is BPE-based
+        vocabulary = asr_model.tokenizer.vocab
+    elif hasattr(asr_model.decoder, "vocabulary"):  # i.e. tokenization is character-based
+        vocabulary = asr_model.cfg.decoder.vocabulary
+    else:
+        raise ValueError("Unexpected model type. Vocabulary list not found.")
 
     if os.path.isdir(args.in_text):
         text_files = glob(f"{args.in_text}/*.txt")
diff --git a/tools/ctc_segmentation/scripts/run_ctc_segmentation.py b/tools/ctc_segmentation/scripts/run_ctc_segmentation.py
index dddeb9a42dc2..c9d9ed2d8731 100644
--- a/tools/ctc_segmentation/scripts/run_ctc_segmentation.py
+++ b/tools/ctc_segmentation/scripts/run_ctc_segmentation.py
@@ -27,6 +27,8 @@
 from utils import get_segments
 
 import nemo.collections.asr as nemo_asr
+from nemo.collections.asr.models.ctc_models import EncDecCTCModel
+from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel
 
 parser = argparse.ArgumentParser(description="CTC Segmentation")
 parser.add_argument("--output_dir", default="output", type=str, help="Path to output directory")
@@ -72,18 +74,19 @@
     logging.basicConfig(handlers=handlers, level=level)
 
     if os.path.exists(args.model):
-        asr_model = nemo_asr.models.EncDecCTCModel.restore_from(args.model)
-    elif args.model in nemo_asr.models.EncDecCTCModel.get_available_model_names():
-        asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(args.model, strict=False)
+        asr_model = nemo_asr.models.ASRModel.restore_from(args.model)
     else:
-        try:
-            asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(args.model)
-        except:
-            raise ValueError(
-                f"Provide path to the pretrained checkpoint or choose from {nemo_asr.models.EncDecCTCModel.get_available_model_names()}"
-            )
+        asr_model = nemo_asr.models.ASRModel.from_pretrained(args.model, strict=False)
+
+    if not (isinstance(asr_model, EncDecCTCModel) or isinstance(asr_model, EncDecHybridRNNTCTCModel)):
+        raise NotImplementedError(
+            f"Model is not an instance of NeMo EncDecCTCModel or ENCDecHybridRNNTCTCModel."
+            " Currently only instances of these models are supported"
+        )
 
-    bpe_model = isinstance(asr_model, nemo_asr.models.EncDecCTCModelBPE)
+    bpe_model = isinstance(asr_model, nemo_asr.models.EncDecCTCModelBPE) or isinstance(
+        asr_model, nemo_asr.models.EncDecHybridRNNTCTCBPEModel
+    )
 
     # get tokenizer used during training, None for char based models
     if bpe_model:
@@ -91,8 +94,18 @@
     else:
         tokenizer = None
 
+    if isinstance(asr_model, EncDecHybridRNNTCTCModel):
+        asr_model.change_decoding_strategy(decoder_type="ctc")
+
     # extract ASR vocabulary and add blank symbol
-    vocabulary = ["ε"] + list(asr_model.cfg.decoder.vocabulary)
+    if hasattr(asr_model, 'tokenizer'):  # i.e. tokenization is BPE-based
+        vocabulary = asr_model.tokenizer.vocab
+    elif hasattr(asr_model.decoder, "vocabulary"):  # i.e. tokenization is character-based
+        vocabulary = asr_model.cfg.decoder.vocabulary
+    else:
+        raise ValueError("Unexpected model type. Vocabulary list not found.")
+
+    vocabulary = ["ε"] + list(vocabulary)
     logging.debug(f"ASR Model vocabulary: {vocabulary}")
 
     data = Path(args.data)
@@ -136,9 +149,14 @@
         logging.debug(f"len(signal): {len(signal)}, sr: {sample_rate}")
         logging.debug(f"Duration: {original_duration}s, file_name: {path_audio}")
 
-        log_probs = asr_model.transcribe(audio=[str(path_audio)], batch_size=1, return_hypotheses=True)[
+        hypotheses = asr_model.transcribe([str(path_audio)], batch_size=1, return_hypotheses=True)
+        # if hypotheses form a tuple (from Hybrid model), extract just "best" hypothesis
+        if type(hypotheses) == tuple and len(hypotheses) == 2:
+            hypotheses = hypotheses[0]
+        log_probs = hypotheses[
             0
-        ].alignments
+        ].alignments  # note: "[0]" is for batch dimension unpacking (and here batch size=1)
+
         # move blank values to the first column (ctc-package compatibility)
         blank_col = log_probs[:, -1].reshape((log_probs.shape[0], 1))
         log_probs = np.concatenate((blank_col, log_probs[:, :-1]), axis=1)