Hybrid conformer export
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
borisfom committed Jul 5, 2023
1 parent 2b57d47 commit 8d78af3
Showing 3 changed files with 18 additions and 10 deletions.
9 changes: 7 additions & 2 deletions nemo/collections/asr/models/asr_model.py
@@ -220,9 +220,9 @@ def forward_for_export(
                 ret = self.output_module.forward_for_export(encoder_output=decoder_input)
         else:
             if cache_last_channel is None and cache_last_time is None:
-                ret = self.output_module(encoder_output=decoder_input)
+                ret = self.output_module(decoder_input)
             else:
-                ret = self.output_module(encoder_output=decoder_input)
+                ret = self.output_module(decoder_input)
         if cache_last_channel is None and cache_last_time is None:
             pass
         else:
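The switch from a keyword to a positional call presumably lets the same line serve decoder heads whose forward signatures name their first parameter differently. A minimal standalone sketch of that failure mode (the classes are hypothetical, not NeMo code):

import torch

class CTCExportHead(torch.nn.Module):
    def forward(self, encoder_output):    # CTC-style parameter name
        return encoder_output.log_softmax(dim=-1)

class JointExportHead(torch.nn.Module):
    def forward(self, encoder_outputs):   # note: different parameter name
        return encoder_outputs.sum(dim=1)

x = torch.randn(2, 64, 128)
for head in (CTCExportHead(), JointExportHead()):
    out = head(x)                         # positional call works for both heads
# head(encoder_output=x) would raise TypeError on JointExportHead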
@@ -246,4 +246,9 @@ def set_export_config(self, **kwargs):
             self.encoder.export_cache_support = enable
             logging.info(f"Caching support enabled: {enable}")
             self.encoder.setup_streaming_params()
+        if 'decoder_type' in kwargs:
+            if hasattr(self, 'change_decoding_strategy'):
+                self.change_decoding_strategy(decoder_type=kwargs['decoder_type'])
+            else:
+                raise Exception("Model does not have decoder type option")
         super().set_export_config(**kwargs)
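With this addition, a hybrid model can be pointed at either decoder branch before export. A hedged usage sketch (the pretrained model name and output filenames are placeholders, and the exact ONNX artifact naming may differ):

from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel

model = EncDecHybridRNNTCTCBPEModel.from_pretrained("stt_en_fastconformer_hybrid_large_pc")

# Export the CTC branch as a single graph.
model.set_export_config(decoder_type='ctc')
model.export("hybrid_ctc.onnx")

# Export the RNNT branch; this path splits into encoder and decoder_joint parts.
model.set_export_config(decoder_type='rnnt')
model.export("hybrid_rnnt.onnx")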
7 changes: 7 additions & 0 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -645,6 +645,13 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
         self.finalize_interctc_metrics(metrics, outputs, prefix="test_")
         return metrics
 
+    # EncDecRNNTModel is exported in 2 parts
+    def list_export_subnets(self):
+        if self.cur_decoder == 'rnnt':
+            return ['encoder', 'decoder_joint']
+        else:
+            return ['encoder']
+
     @classmethod
     def list_available_models(cls) -> Optional[PretrainedModelInfo]:
         """
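The subnet split can be inspected directly. A sketch continuing from the export example above (it assumes `model` is that loaded hybrid instance and that `change_decoding_strategy` updates `cur_decoder`, which the new `set_export_config` path relies on):

model.change_decoding_strategy(decoder_type='rnnt')
print(model.list_export_subnets())  # expected: ['encoder', 'decoder_joint']

model.change_decoding_strategy(decoder_type='ctc')
print(model.list_export_subnets())  # expected: ['encoder']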
12 changes: 4 additions & 8 deletions nemo/collections/asr/models/rnnt_models.py
@@ -28,7 +28,7 @@
 from nemo.collections.asr.data.audio_to_text_dali import AudioToCharDALIDataset, DALIOutputs
 from nemo.collections.asr.losses.rnnt import RNNTLoss, resolve_rnnt_default_loss_name
 from nemo.collections.asr.metrics.rnnt_wer import RNNTWER, RNNTDecoding, RNNTDecodingConfig
-from nemo.collections.asr.models.asr_model import ASRModel
+from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
 from nemo.collections.asr.modules.rnnt import RNNTDecoderJoint
 from nemo.collections.asr.parts.mixins import ASRModuleMixin
 from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
@@ -39,7 +39,7 @@
 from nemo.utils import logging
 
 
-class EncDecRNNTModel(ASRModel, ASRModuleMixin, Exportable):
+class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecModel):
     """Base class for encoder decoder RNNT-based models."""
 
     def __init__(self, cfg: DictConfig, trainer: Trainer = None):
@@ -475,7 +475,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
             return dataset
 
         shuffle = config['shuffle']
-        if isinstance(dataset, torch.utils.data.IterableDataset):
+        if config.get('is_tarred', False):
             shuffle = False
 
         if hasattr(dataset, 'collate_fn'):
@@ -523,11 +523,7 @@ def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict
         # Need to set this because if using an IterableDataset, the length of the dataloader is the total number
         # of samples rather than the number of batches, and this messes up the tqdm progress bar.
         # So we set the number of steps manually (to the correct number) to fix this.
-        if (
-            self._train_dl is not None
-            and hasattr(self._train_dl, 'dataset')
-            and isinstance(self._train_dl.dataset, torch.utils.data.IterableDataset)
-        ):
+        if 'is_tarred' in train_data_config and train_data_config['is_tarred']:
             # We also need to check if limit_train_batches is already set.
             # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches,
             # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0).
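These two hunks key the tarred-dataset handling off the `is_tarred` config flag rather than the dataset's Python type. A hedged configuration sketch (keys follow NeMo's ASR dataset config; paths and values are placeholders):

train_ds = {
    'manifest_filepath': 'tarred_audio_manifest.json',
    'tarred_audio_filepaths': 'audio_{0..127}.tar',
    'is_tarred': True,
    'shuffle': True,   # forced to False at the DataLoader; tarred shards shuffle internally
    'shuffle_n': 2048,  # per-shard shuffle buffer used by the tarred dataset
    'batch_size': 32,
}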
