Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch transcribe and support offline transcribe for hybrid model #6550

Merged
merged 4 commits into from
May 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from omegaconf import OmegaConf, open_dict

from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
from nemo.collections.asr.metrics.wer import CTCDecodingConfig
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel
from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig
from nemo.collections.asr.parts.utils.transcribe_utils import (
compute_output_filename,
Expand Down Expand Up @@ -154,6 +154,9 @@ class TranscriptionConfig:
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

for key in cfg:
VahidooX marked this conversation as resolved.
Show resolved Hide resolved
cfg[key] = None if cfg[key] == 'None' else cfg[key]

if is_dataclass(cfg):
cfg = OmegaConf.structured(cfg)

Expand Down Expand Up @@ -223,7 +226,6 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
decoding_cfg.preserve_alignments = cfg.compute_timestamps
if 'compute_langs' in decoding_cfg:
decoding_cfg.compute_langs = cfg.compute_langs

asr_model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type)

# Check if ctc or rnnt model
Expand All @@ -243,6 +245,15 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:

asr_model.change_decoding_strategy(cfg.ctc_decoding)

# Setup decoding config based on model type and decoder_type
with open_dict(cfg):
if isinstance(asr_model, EncDecCTCModel) or (
isinstance(asr_model, EncDecHybridRNNTCTCModel) and cfg.decoder_type == "ctc"
):
cfg.decoding = cfg.ctc_decoding
else:
cfg.decoding = cfg.rnnt_decoding

# prepare audio filepaths and decide whether it's partial audio
filepaths, partial_audio = prepare_audio_data(cfg)

Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,14 +289,14 @@ def write_transcription(

if isinstance(transcriptions[0], rnnt_utils.Hypothesis): # List[rnnt_utils.Hypothesis]
best_hyps = transcriptions
assert cfg.ctc_decoding.beam.return_best_hypothesis, "Works only with return_best_hypothesis=true"
assert cfg.decoding.beam.return_best_hypothesis, "Works only with return_best_hypothesis=true"
elif isinstance(transcriptions[0], list) and isinstance(
transcriptions[0][0], rnnt_utils.Hypothesis
): # List[List[rnnt_utils.Hypothesis]] NBestHypothesis
best_hyps, beams = [], []
for hyps in transcriptions:
best_hyps.append(hyps[0])
if not cfg.ctc_decoding.beam.return_best_hypothesis:
if not cfg.decoding.beam.return_best_hypothesis:
beam = []
for hyp in hyps:
beam.append((hyp.text, hyp.score))
Expand Down
2 changes: 1 addition & 1 deletion tools/asr_evaluator/conf/eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ engine:
chunk_len_in_secs: 1.6 #null # Must be specified when using buffered inference (default for offline_by_chunked is 20)
total_buffer_in_secs: 4 #null # Must be specified when using buffered inference (default for offline_by_chunked is 22)
model_stride: 4 # Model downsampling factor, 8 for Citrinet models and 4 for Conformer models

decoder_type: null # Used for hybrid CTC RNNT model only. Specify decoder_type *ctc* or *rnnt* for hybrid CTC RNNT model.
test_ds:
manifest_filepath: null
sample_rate: 16000
Expand Down
3 changes: 2 additions & 1 deletion tools/asr_evaluator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def run_offline_inference(cfg: DictConfig) -> DictConfig:
f"output_filename={cfg.output_filename} "
f"batch_size={cfg.test_ds.batch_size} "
f"random_seed={cfg.random_seed} "
f"eval_config_yaml={f.name} ",
f"eval_config_yaml={f.name} "
f"decoder_type={cfg.inference.decoder_type} ",
shell=True,
check=True,
)
Expand Down