diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index fa821a3ca83ee..a40cbd9294ad6 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -483,7 +483,7 @@ def transcribe( """ if timestamps: raise NotImplementedError("Timestamps are not supported for this model yet.") - + if override_config is None: trcfg = MultiTaskTranscriptionConfig( batch_size=batch_size, @@ -501,7 +501,6 @@ def transcribe( f"but got {type(override_config)}" ) trcfg = override_config - import pdb; pdb.set_trace() return super().transcribe(audio=audio, override_config=trcfg) def _setup_dataloader_from_config(self, config: Optional[Dict]): diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 7aa2ca36db3f2..51bb72f0729e9 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -155,7 +155,7 @@ def transcribe( """ if timestamps: logging.info("Timestamps requested, setting decoding timestamps to True") - return_hypotheses=True + return_hypotheses = True with open_dict(self.cfg.decoding): self.cfg.decoding.compute_timestamps = True self.cfg.decoding.preserve_alignments = True @@ -719,8 +719,12 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen del logits, logits_len if trcfg.timestamps: - current_hypotheses = process_timestamp_outputs(current_hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']) - all_hyp = process_timestamp_outputs(all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']) + current_hypotheses = process_timestamp_outputs( + current_hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + all_hyp = process_timestamp_outputs( + all_hyp, self.encoder.subsampling_factor, 
self.cfg['preprocessor']['window_stride'] + ) hypotheses = [] if all_hyp is None: diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py index 53fef2f3536d1..aaa01a3687c3e 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -415,7 +415,9 @@ def change_vocabulary( logging.info(f"Changed tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.") - def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None, verbose: bool = True): + def change_decoding_strategy( + self, decoding_cfg: DictConfig = None, decoder_type: str = None, verbose: bool = True + ): """ Changes decoding strategy used during RNNT decoding process. Args: @@ -468,7 +470,9 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type self.cur_decoder = "rnnt" if verbose: - logging.info(f"Changed decoding strategy of the RNNT decoder to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + logging.info( + f"Changed decoding strategy of the RNNT decoder to \n{OmegaConf.to_yaml(self.cfg.decoding)}" + ) elif decoder_type == 'ctc': if not hasattr(self, 'ctc_decoding'): @@ -501,7 +505,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type self.cur_decoder = "ctc" if verbose: logging.info( - f"Changed decoding strategy of the CTC decoder to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}" + f"Changed decoding strategy of the CTC decoder to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}" ) else: raise ValueError(f"decoder_type={decoder_type} is not supported. 
Supported values: [ctc,rnnt]") diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index 37cda1c64bb36..54ac35d1bf835 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -142,12 +142,12 @@ def transcribe( if timestamps: logging.info("Timestamps requested, setting decoding timestamps to True") - return_hypotheses=True + return_hypotheses = True with open_dict(decoding_cfg): decoding_cfg.decoding.compute_timestamps = True decoding_cfg.decoding.preserve_alignments = True self.change_decoding_strategy(decoding_cfg, self.cur_decoder, verbose=False) - else: # This is done to ensure the timestamps are not computed if not requested + else: # This is done to ensure the timestamps are not computed if not requested with open_dict(decoding_cfg): decoding_cfg.compute_timestamps = decoding_cfg.get('compute_timestamps', False) decoding_cfg.preserve_alignments = decoding_cfg.get('preserve_alignments', False) @@ -221,8 +221,12 @@ def _transcribe_output_processing( # logits_list.append(logit[:elen]) if trcfg.timestamps: - hypotheses = process_timestamp_outputs(best_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']) - all_hyp = process_timestamp_outputs(all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']) + hypotheses = process_timestamp_outputs( + best_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + all_hyp = process_timestamp_outputs( + all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) del logits, encoded_len @@ -318,7 +322,9 @@ def change_vocabulary( logging.info(f"Changed the tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.") - def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None, verbose: bool = True): + def change_decoding_strategy( 
+ self, decoding_cfg: DictConfig = None, decoder_type: str = None, verbose: bool = True + ): """ Changes decoding strategy used during RNNT decoding process. diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index a1fe6874ca41a..f710d5939531d 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -281,7 +281,7 @@ def transcribe( """ if timestamps: logging.info("Timestamps requested, setting decoding timestamps to True") - return_hypotheses=True + return_hypotheses = True with open_dict(self.cfg.decoding): self.cfg.decoding.compute_timestamps = True self.cfg.decoding.preserve_alignments = True @@ -941,8 +941,12 @@ def _transcribe_output_processing( del encoded, encoded_len if trcfg.timestamps: - hypotheses = process_timestamp_outputs(best_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']) - all_hyp = process_timestamp_outputs(all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']) + hypotheses = process_timestamp_outputs( + best_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + all_hyp = process_timestamp_outputs( + all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) hypotheses = [] all_hypotheses = [] @@ -955,7 +959,6 @@ def _transcribe_output_processing( return (hypotheses, all_hypotheses) - def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': """ Setup function for a temporary data loader which wraps the provided audio file. 
diff --git a/nemo/collections/asr/modules/conv_asr.py b/nemo/collections/asr/modules/conv_asr.py index 8193c46953255..b9edb19b6a4e0 100644 --- a/nemo/collections/asr/modules/conv_asr.py +++ b/nemo/collections/asr/modules/conv_asr.py @@ -133,7 +133,7 @@ def __init__( residual_panes = [] encoder_layers = [] self.dense_residual = False - self._subsampling_factor = 1 + self._subsampling_factor = 1 for layer_idx, lcfg in enumerate(jasper): dense_res = [] if lcfg.get('residual_dense', False): @@ -182,7 +182,9 @@ def __init__( ) ) feat_in = lcfg['filters'] - self._subsampling_factor *= int(lcfg['stride'][0]) if isinstance(lcfg['stride'], List) else int(lcfg['stride']) + self._subsampling_factor *= ( + int(lcfg['stride'][0]) if isinstance(lcfg['stride'], List) else int(lcfg['stride']) + ) self._feat_out = feat_in @@ -235,6 +237,7 @@ def update_max_sequence_length(self, seq_length: int, device): def subsampling_factor(self) -> int: return self._subsampling_factor + class ParallelConvASREncoder(NeuralModule, Exportable): """ Convolutional encoder for ASR models with parallel blocks. CarneliNet can be implemented with this class. 
diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py index 598c661d12530..b2fdcf990caf7 100644 --- a/nemo/collections/asr/parts/mixins/transcription.py +++ b/nemo/collections/asr/parts/mixins/transcription.py @@ -62,8 +62,8 @@ class TranscribeConfig: num_workers: Optional[int] = None channel_selector: ChannelSelectorType = None augmentor: Optional[DictConfig] = None - timestamps: bool = False # returns timestamps for each word and segments if model supports punctuations - verbose: bool = True + timestamps: bool = False # returns timestamps for each word and segments if model supports punctuations + verbose: bool = True # Utility partial_hypothesis: Optional[List[Any]] = None diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index 0449b0e9362dd..63de4d244469a 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -586,7 +586,7 @@ def process_timestamp_outputs(outputs, subsampling_factor: int = 1, window_strid if isinstance(outputs, rnnt_utils.Hypothesis): outputs = [outputs] - + if not isinstance(outputs[0], rnnt_utils.Hypothesis): raise ValueError(f"Expected Hypothesis object, got {type(outputs[0])}") @@ -604,19 +604,23 @@ def process_timestamp(timestamp, subsampling_factor, window_stride): val['end'] = end return timestamp - + for idx, hyp in enumerate(outputs): if not hasattr(hyp, 'timestep'): - raise ValueError(f"Expected Hypothesis object to have 'timestep' attribute, when compute_timestamps is enabled but got {hyp}") + raise ValueError( + f"Expected Hypothesis object to have 'timestep' attribute, when compute_timestamps is enabled but got {hyp}" + ) timestep = hyp.timestep if 'word' in timestep: outputs[idx].timestep['word'] = process_timestamp(timestep['word'], subsampling_factor, window_stride) if 'char' in timestep: outputs[idx].timestep['char'] = 
process_timestamp(timestep['char'], subsampling_factor, window_stride) if 'segment' in timestep: - outputs[idx].timestep['segment'] = process_timestamp(timestep['segment'], subsampling_factor, window_stride) + outputs[idx].timestep['segment'] = process_timestamp( + timestep['segment'], subsampling_factor, window_stride + ) return outputs - + class PunctuationCapitalization: def __init__(self, punctuation_marks: str):