diff --git a/nemo/collections/asr/models/audio_to_audio_model.py b/nemo/collections/asr/models/audio_to_audio_model.py
index b48cd0c14e62..21860cf8ab56 100644
--- a/nemo/collections/asr/models/audio_to_audio_model.py
+++ b/nemo/collections/asr/models/audio_to_audio_model.py
@@ -17,7 +17,7 @@
 
 import hydra
 import torch
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning import Trainer
 
 from nemo.collections.asr.metrics.audio import AudioMetricWrapper
@@ -67,12 +67,15 @@ def _setup_metrics(self, tag: str = 'val'):
                 logging.debug('Found %d metrics for tag %s, not necesary to initialize again', num_dataloaders, tag)
                 return
 
-        if 'metrics' not in self._cfg or tag not in self._cfg['metrics']:
+        if self.cfg.get('metrics') is None:
             # Metrics are not available in the configuration, nothing to do
-            logging.debug('No metrics configured for %s in model.metrics.%s', tag, tag)
+            logging.debug('No metrics configured in model.metrics')
             return
 
-        metrics_cfg = self._cfg['metrics'][tag]
+        if (metrics_cfg := self.cfg['metrics'].get(tag)) is None:
+            # Metrics configuration is not available in the configuration, nothing to do
+            logging.debug('No metrics configured for %s in model.metrics', tag)
+            return
 
         if 'loss' in metrics_cfg:
             raise ValueError(
@@ -86,16 +89,19 @@ def _setup_metrics(self, tag: str = 'val'):
         # Setup metrics for each dataloader
         self.metrics[tag] = torch.nn.ModuleList()
         for dataloader_idx in range(num_dataloaders):
-            metrics_dataloader_idx = torch.nn.ModuleDict(
-                {
-                    name: AudioMetricWrapper(
-                        metric=hydra.utils.instantiate(cfg),
-                        channel=cfg.get('channel'),
-                        metric_using_batch_averaging=cfg.get('metric_using_batch_averaging'),
-                    )
-                    for name, cfg in metrics_cfg.items()
-                }
-            )
+            metrics_dataloader_idx = {}
+            for name, cfg in metrics_cfg.items():
+                logging.debug('Initialize %s for dataloader_idx %s', name, dataloader_idx)
+                cfg_dict = OmegaConf.to_container(cfg)
+                cfg_channel = cfg_dict.pop('channel', None)
+                cfg_batch_averaging = cfg_dict.pop('metric_using_batch_averaging', None)
+                metrics_dataloader_idx[name] = AudioMetricWrapper(
+                    metric=hydra.utils.instantiate(cfg_dict),
+                    channel=cfg_channel,
+                    metric_using_batch_averaging=cfg_batch_averaging,
+                )
+
+            metrics_dataloader_idx = torch.nn.ModuleDict(metrics_dataloader_idx)
             self.metrics[tag].append(metrics_dataloader_idx.to(self.device))
 
         logging.info(
diff --git a/tutorials/audio_tasks/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb b/tutorials/audio_tasks/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb
index 7dd1e796fa9a..786c41303a40 100644
--- a/tutorials/audio_tasks/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb
+++ b/tutorials/audio_tasks/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb
@@ -102,11 +102,6 @@
     "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest\n",
     "\n",
     "\n",
-    "# Used to download data processing scripts\n",
-    "USER = 'anteju' # TODO: change to 'NVIDIA'\n",
-    "BRANCH = 'dev/se-tutorial' # TODO: change to 'r1.21.0'\n",
-    "\n",
-    "\n",
     "# Utility functions for displaying signals and metrics\n",
     "def show_signal(signal: np.ndarray, sample_rate: int = 16000, tag: str = 'Signal'):\n",
     "    \"\"\"Show the time-domain signal and its spectrogram.\n",
@@ -607,7 +602,7 @@
     "        '_target_': 'torchmetrics.audio.SignalDistortionRatio',\n",
     "    }\n",
     "})\n",
-    "config.model.metrics.validation = metrics\n",
+    "config.model.metrics.val = metrics\n",
     "config.model.metrics.test = metrics\n",
     "\n",
     "print(\"Metrics config:\")\n",
@@ -1112,7 +1107,7 @@
     "        'channel': 1,\n",
     "    },\n",
     "})\n",
-    "config_dual_output.model.metrics.validation = metrics\n",
+    "config_dual_output.model.metrics.val = metrics\n",
     "config_dual_output.model.metrics.test = metrics"
    ]
   },