From 8d78af35fcef1ffbeb1fb94f321488f37af5a79f Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Wed, 5 Jul 2023 11:02:34 -0700
Subject: [PATCH] Hybrid conformer export

Signed-off-by: Boris Fomitchev
---
 nemo/collections/asr/models/asr_model.py              |  9 +++++++--
 .../collections/asr/models/hybrid_rnnt_ctc_models.py  |  7 +++++++
 nemo/collections/asr/models/rnnt_models.py            | 12 ++++--------
 3 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py
index 1e38b8b93062..991a204f6240 100644
--- a/nemo/collections/asr/models/asr_model.py
+++ b/nemo/collections/asr/models/asr_model.py
@@ -220,9 +220,9 @@ def forward_for_export(
             ret = self.output_module.forward_for_export(encoder_output=decoder_input)
         else:
             if cache_last_channel is None and cache_last_time is None:
-                ret = self.output_module(encoder_output=decoder_input)
+                ret = self.output_module(decoder_input)
             else:
-                ret = self.output_module(encoder_output=decoder_input)
+                ret = self.output_module(decoder_input)
         if cache_last_channel is None and cache_last_time is None:
             pass
         else:
@@ -246,4 +246,9 @@ def set_export_config(self, **kwargs):
             self.encoder.export_cache_support = enable
             logging.info(f"Caching support enabled: {enable}")
             self.encoder.setup_streaming_params()
+        if 'decoder_type' in kwargs:
+            if hasattr(self, 'change_decoding_strategy'):
+                self.change_decoding_strategy(decoder_type=kwargs['decoder_type'])
+            else:
+                raise Exception("Model does not have decoder type option")
         super().set_export_config(**kwargs)
diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
index 5ca6124ecfd7..2b13f02ab657 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -645,6 +645,13 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
         self.finalize_interctc_metrics(metrics, outputs, prefix="test_")
         return metrics
 
+    # EncDecRNNTModel is exported in 2 parts
+    def list_export_subnets(self):
+        if self.cur_decoder == 'rnnt':
+            return ['encoder', 'decoder_joint']
+        else:
+            return ['encoder']
+
     @classmethod
     def list_available_models(cls) -> Optional[PretrainedModelInfo]:
         """
diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py
index 92bb04fd2a3e..1a3bd2b46cfc 100644
--- a/nemo/collections/asr/models/rnnt_models.py
+++ b/nemo/collections/asr/models/rnnt_models.py
@@ -28,7 +28,7 @@
 from nemo.collections.asr.data.audio_to_text_dali import AudioToCharDALIDataset, DALIOutputs
 from nemo.collections.asr.losses.rnnt import RNNTLoss, resolve_rnnt_default_loss_name
 from nemo.collections.asr.metrics.rnnt_wer import RNNTWER, RNNTDecoding, RNNTDecodingConfig
-from nemo.collections.asr.models.asr_model import ASRModel
+from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
 from nemo.collections.asr.modules.rnnt import RNNTDecoderJoint
 from nemo.collections.asr.parts.mixins import ASRModuleMixin
 from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
@@ -39,7 +39,7 @@
 from nemo.utils import logging
 
 
-class EncDecRNNTModel(ASRModel, ASRModuleMixin, Exportable):
+class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecModel):
     """Base class for encoder decoder RNNT-based models."""
 
     def __init__(self, cfg: DictConfig, trainer: Trainer = None):
@@ -475,7 +475,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
             return dataset
 
         shuffle = config['shuffle']
-        if isinstance(dataset, torch.utils.data.IterableDataset):
+        if config.get('is_tarred', False):
             shuffle = False
 
         if hasattr(dataset, 'collate_fn'):
@@ -523,11 +523,7 @@ def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
         # Need to set this because if using an IterableDataset, the length of the dataloader is the total number
         # of samples rather than the number of batches, and this messes up the tqdm progress bar.
        # So we set the number of steps manually (to the correct number) to fix this.
-        if (
-            self._train_dl is not None
-            and hasattr(self._train_dl, 'dataset')
-            and isinstance(self._train_dl.dataset, torch.utils.data.IterableDataset)
-        ):
+        if 'is_tarred' in train_data_config and train_data_config['is_tarred']:
             # We also need to check if limit_train_batches is already set.
             # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches,
             # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0).
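
Usage note: a minimal sketch of how the export hooks touched by this patch might be exercised. The model class, checkpoint path, and output file names below are illustrative assumptions, not taken from the patch; only set_export_config(**kwargs), change_decoding_strategy(decoder_type=...), list_export_subnets(), and export() are the hooks shown above.

    # Hypothetical example: choose which decoder branch of a hybrid RNNT/CTC model to export.
    # set_export_config() forwards 'decoder_type' to change_decoding_strategy() (asr_model.py hunk),
    # and list_export_subnets() (hybrid_rnnt_ctc_models.py hunk) decides which subnets get exported.
    import nemo.collections.asr as nemo_asr

    model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from("hybrid_conformer.nemo")  # hypothetical path
    model.eval()

    model.set_export_config(decoder_type='ctc')   # select the CTC branch
    model.export("hybrid_ctc.onnx")

    model.set_export_config(decoder_type='rnnt')  # RNNT branch: exported as encoder + decoder_joint subnets
    model.export("hybrid_rnnt.onnx")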