From 1cb25f83a2a287ba2f208c9b0e78e7cac0e2460c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ante=20Jukic=CC=81?= Date: Tue, 8 Nov 2022 09:37:54 -0800 Subject: [PATCH] Adding a mask estimator which can process an arbitrary number of channels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ante Jukić --- .../conf/beamforming_flex_channels.yaml | 146 ++++ examples/audio_tasks/process_audio.py | 9 + nemo/collections/asr/data/audio_to_audio.py | 2 +- nemo/collections/asr/losses/audio_losses.py | 8 +- .../asr/models/audio_to_audio_model.py | 44 +- .../asr/models/enhancement_models.py | 35 +- nemo/collections/asr/modules/__init__.py | 7 +- nemo/collections/asr/modules/audio_modules.py | 613 ++++++++++++-- .../parts/submodules/multichannel_modules.py | 778 ++++++++++++++++++ .../asr/parts/utils/audio_utils.py | 6 +- requirements/requirements_asr.txt | 2 +- .../test_asr_part_submodules_multichannel.py | 157 ++++ 12 files changed, 1698 insertions(+), 109 deletions(-) create mode 100644 examples/audio_tasks/conf/beamforming_flex_channels.yaml create mode 100644 nemo/collections/asr/parts/submodules/multichannel_modules.py create mode 100644 tests/collections/asr/test_asr_part_submodules_multichannel.py diff --git a/examples/audio_tasks/conf/beamforming_flex_channels.yaml b/examples/audio_tasks/conf/beamforming_flex_channels.yaml new file mode 100644 index 0000000000000..29fc87acf93d5 --- /dev/null +++ b/examples/audio_tasks/conf/beamforming_flex_channels.yaml @@ -0,0 +1,146 @@ +# This configuration contains the exemplary values for training a multichannel speech enhancement model with a mask-based beamformer. +# +name: beamforming_flex_channels + +model: + sample_rate: 16000 + skip_nan_grad: false + num_outputs: 1 + + train_ds: + manifest_filepath: ??? 
+ input_key: audio_filepath # key of the input signal path in the manifest + input_channel_selector: null # load all channels from the input file + target_key: target_anechoic_filepath # key of the target signal path in the manifest + target_channel_selector: 0 # load only the first channel from the target file + audio_duration: 4.0 # in seconds, audio segment duration for training + random_offset: true # if the file is longer than audio_duration, use random offset to select a subsegment + min_duration: ${model.train_ds.audio_duration} + batch_size: 16 # batch size may be increased based on the available memory + shuffle: true + num_workers: 16 + pin_memory: true + + validation_ds: + manifest_filepath: ??? + input_key: audio_filepath # key of the input signal path in the manifest + input_channel_selector: null # load all channels from the input file + target_key: target_anechoic_filepath # key of the target signal path in the manifest + target_channel_selector: 0 # load only the first channel from the target file + batch_size: 8 + shuffle: false + num_workers: 8 + pin_memory: true + + channel_augment: + _target_: nemo.collections.asr.parts.submodules.multichannel_modules.ChannelAugment + num_channels_min: 2 # minimal number of channels selected for each batch + num_channels_max: null # max number of channels is determined by the batch size + permute_channels: true + + encoder: + _target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram + fft_length: 512 # Length of the window and FFT for calculating spectrogram + hop_length: 256 # Hop length for calculating spectrogram + + decoder: + _target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio + fft_length: ${model.encoder.fft_length} + hop_length: ${model.encoder.hop_length} + + mask_estimator: + _target_: nemo.collections.asr.modules.audio_modules.MaskEstimatorFlexChannels + num_outputs: ${model.num_outputs} # number of output masks + num_subbands: 257 # number of subbands for 
the input spectrogram + num_blocks: 5 # number of blocks in the model + channel_reduction_position: 3 # 0-indexed, apply channel reduction before this block + channel_reduction_type: average # channel-wise reduction + channel_block_type: transform_average_concatenate # channel block + temporal_block_type: conformer_encoder # temporal block + temporal_block_num_layers: 5 # number of layers for the temporal block + temporal_block_num_heads: 4 # number of heads for the temporal block + temporal_block_dimension: 128 # the hidden size of the temporal block + mag_reduction: null # channel-wise reduction of magnitude + mag_normalization: mean_var # normalization using mean and variance + use_ipd: true # use inter-channel phase difference + ipd_normalization: mean # mean normalization + + mask_processor: + # Mask-based multi-channel processor + _target_: nemo.collections.asr.modules.audio_modules.MaskBasedBeamformer + filter_type: pmwf # parametric multichannel wiener filter + filter_beta: 0.0 # mvdr + filter_rank: one + ref_channel: max_snr # select reference channel by maximizing estimated SNR + ref_hard: 1 # a one-hot reference. If false, a soft estimate across channels is used. 
+ ref_hard_use_grad: false # use straight-through gradient when using hard reference + ref_subband_weighting: false # use subband weighting for reference estimation + num_subbands: ${model.mask_estimator.num_subbands} + + loss: + _target_: nemo.collections.asr.losses.SDRLoss + convolution_invariant: true # convolution-invariant loss + sdr_max: 30 # soft threshold for SDR + + metrics: + val: + sdr_0: + _target_: torchmetrics.audio.SignalDistortionRatio + channel: 0 # evaluate only on channel 0, if there are multiple outputs + + optim: + name: adamw + lr: 1e-4 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: CosineAnnealing + # scheduler config override + warmup_steps: 10000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: null + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 25 # Interval of logging. 
+ enable_progress_bar: true + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_loss" + mode: "min" + save_top_k: 5 + always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to true to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/audio_tasks/process_audio.py b/examples/audio_tasks/process_audio.py index 20650d8a8c3c0..e73831fe7a5f5 100644 --- a/examples/audio_tasks/process_audio.py +++ b/examples/audio_tasks/process_audio.py @@ -37,6 +37,7 @@ pretrained_name: name of a pretrained AudioToAudioModel model (from NGC registry) audio_dir: path to directory with audio files dataset_manifest: path to dataset JSON manifest file (in NeMo format) + max_utts: maximum number of utterances to process input_channel_selector: list of channels to take from audio files, defaults to `None` and takes all available channels input_key: key for audio filepath in the manifest file, defaults to `audio_filepath` @@ -80,6 +81,7 @@ class ProcessConfig: pretrained_name: Optional[str] = None # Name of a pretrained model audio_dir:
Optional[str] = None # Path to a directory which contains audio files dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest + max_utts: Optional[int] = None # max number of utterances to process # Audio configs input_channel_selector: Optional[List] = None # Union types not supported Optional[Union[List, int]] @@ -171,6 +173,10 @@ def main(cfg: ProcessConfig) -> ProcessConfig: audio_file = manifest_dir / audio_file filepaths.append(str(audio_file.absolute())) + if cfg.max_utts is not None: + # Limit the number of utterances to process + filepaths = filepaths[: cfg.max_utts] + logging.info(f"\nProcessing {len(filepaths)} files...\n") # setup AMP (optional) @@ -225,6 +231,9 @@ def autocast(): item = json.loads(line) item['processed_audio_filepath'] = paths2processed_files[idx] f.write(json.dumps(item) + "\n") + + if cfg.max_utts is not None and idx >= cfg.max_utts - 1: + break else: for idx, processed_file in enumerate(paths2processed_files): item = {'processed_audio_filepath': processed_file} diff --git a/nemo/collections/asr/data/audio_to_audio.py b/nemo/collections/asr/data/audio_to_audio.py index 9f9eda7c865ae..a3c6dd0cc1b3f 100644 --- a/nemo/collections/asr/data/audio_to_audio.py +++ b/nemo/collections/asr/data/audio_to_audio.py @@ -636,7 +636,7 @@ def get_duration(audio_files: List[str]) -> List[float]: Returns: List of durations in seconds. 
""" - duration = [librosa.get_duration(filename=f) for f in flatten(audio_files)] + duration = [librosa.get_duration(path=f) for f in flatten(audio_files)] return duration def load_embedding(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]: diff --git a/nemo/collections/asr/losses/audio_losses.py b/nemo/collections/asr/losses/audio_losses.py index 34c73a23d7b85..62ce4a9f7edde 100644 --- a/nemo/collections/asr/losses/audio_losses.py +++ b/nemo/collections/asr/losses/audio_losses.py @@ -121,8 +121,8 @@ def convolution_invariant_target( input_length: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, filter_length: int = 512, - diag_reg: float = 1e-8, - eps: float = 1e-10, + diag_reg: float = 1e-6, + eps: float = 1e-8, ) -> torch.Tensor: """Calculate optimal convolution-invariant target for a given estimate. Assumes time dimension is the last dimension in the array. @@ -222,7 +222,7 @@ def calculate_sdr_batch( convolution_filter_length: Optional[int] = 512, remove_mean: bool = True, sdr_max: Optional[float] = None, - eps: float = 1e-10, + eps: float = 1e-8, ) -> torch.Tensor: """Calculate signal-to-distortion ratio per channel. 
@@ -310,7 +310,7 @@ def __init__( convolution_filter_length: Optional[int] = 512, remove_mean: bool = True, sdr_max: Optional[float] = None, - eps: float = 1e-10, + eps: float = 1e-8, ): super().__init__() diff --git a/nemo/collections/asr/models/audio_to_audio_model.py b/nemo/collections/asr/models/audio_to_audio_model.py index b48cd0c14e625..94b299f6b0529 100644 --- a/nemo/collections/asr/models/audio_to_audio_model.py +++ b/nemo/collections/asr/models/audio_to_audio_model.py @@ -17,7 +17,7 @@ import hydra import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Trainer from nemo.collections.asr.metrics.audio import AudioMetricWrapper @@ -86,16 +86,19 @@ def _setup_metrics(self, tag: str = 'val'): # Setup metrics for each dataloader self.metrics[tag] = torch.nn.ModuleList() for dataloader_idx in range(num_dataloaders): - metrics_dataloader_idx = torch.nn.ModuleDict( - { - name: AudioMetricWrapper( - metric=hydra.utils.instantiate(cfg), - channel=cfg.get('channel'), - metric_using_batch_averaging=cfg.get('metric_using_batch_averaging'), - ) - for name, cfg in metrics_cfg.items() - } - ) + metrics_dataloader_idx = {} + for name, cfg in metrics_cfg.items(): + logging.debug('Initialize %s for dataloader_idx %s', name, dataloader_idx) + cfg_dict = OmegaConf.to_container(cfg) + cfg_channel = cfg_dict.pop('channel', None) + cfg_batch_averaging = cfg_dict.pop('metric_using_batch_averaging', None) + metrics_dataloader_idx[name] = AudioMetricWrapper( + metric=hydra.utils.instantiate(cfg_dict), + channel=cfg_channel, + metric_using_batch_averaging=cfg_batch_averaging, + ) + + metrics_dataloader_idx = torch.nn.ModuleDict(metrics_dataloader_idx) self.metrics[tag].append(metrics_dataloader_idx.to(self.device)) logging.info( @@ -115,15 +118,24 @@ def on_test_start(self): return super().on_test_start() def validation_step(self, batch, batch_idx, dataloader_idx: int = 0): - return self.evaluation_step(batch, 
batch_idx, dataloader_idx, 'val') + output_dict = self.evaluation_step(batch, batch_idx, dataloader_idx, 'val') + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(output_dict) + else: + self.validation_step_outputs.append(output_dict) + return output_dict def test_step(self, batch, batch_idx, dataloader_idx=0): - return self.evaluation_step(batch, batch_idx, dataloader_idx, 'test') + output_dict = self.evaluation_step(batch, batch_idx, dataloader_idx, 'test') + if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: + self.test_step_outputs[dataloader_idx].append(output_dict) + else: + self.test_step_outputs.append(output_dict) + return output_dict def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'): # Handle loss loss_mean = torch.stack([x[f'{tag}_loss'] for x in outputs]).mean() - output_dict = {f'{tag}_loss': loss_mean} tensorboard_logs = {f'{tag}_loss': loss_mean} # Handle metrics for this tag and dataloader_idx @@ -135,9 +147,7 @@ def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str # Store for logs tensorboard_logs[f'{tag}_{name}'] = value - output_dict['log'] = tensorboard_logs - - return output_dict + return {f'{tag}_loss': loss_mean, 'log': tensorboard_logs} def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): return self.multi_evaluation_epoch_end(outputs, dataloader_idx, 'val') diff --git a/nemo/collections/asr/models/enhancement_models.py b/nemo/collections/asr/models/enhancement_models.py index a25bf882a23bf..e110d177cfda6 100644 --- a/nemo/collections/asr/models/enhancement_models.py +++ b/nemo/collections/asr/models/enhancement_models.py @@ -61,14 +61,24 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.decoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.decoder) if 'mixture_consistency' in self._cfg: + 
logging.debug('Using mixture consistency') self.mixture_consistency = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mixture_consistency) else: + logging.debug('Mixture consistency not used') self.mixture_consistency = None # Future enhancement: # If subclasses need to modify the config before calling super() # Check ASRBPE* classes do with their mixin + # Setup augmentation + if hasattr(self._cfg, 'channel_augment') and self._cfg.channel_augment is not None: + logging.debug('Using channel augmentation') + self.channel_augmentation = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.channel_augment) + else: + logging.debug('Channel augmentation not used') + self.channel_augmentation = None + # Setup optional Optimization flags self.setup_optimization_flags() @@ -125,7 +135,7 @@ def process( temporary_manifest_filepath = os.path.join(tmpdir, 'manifest.json') with open(temporary_manifest_filepath, 'w', encoding='utf-8') as fp: for audio_file in paths2audio_files: - entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(filename=audio_file)} + entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(path=audio_file)} fp.write(json.dumps(entry) + '\n') config = { @@ -397,17 +407,25 @@ def training_step(self, batch, batch_idx): if target_signal.ndim == 2: target_signal = target_signal.unsqueeze(1) + # Apply channel augmentation + if self.training and self.channel_augmentation is not None: + input_signal = self.channel_augmentation(input=input_signal) + + # Process input processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) - loss_value = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length) + # Calculate the loss + loss = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length) + # Prepare logs tensorboard_logs = { - 'train_loss': loss_value, + 'train_loss': loss, 'learning_rate': self._optimizer.param_groups[0]['lr'], 
'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32), } - return {'loss': loss_value, 'log': tensorboard_logs} + self.log_dict(tensorboard_logs, on_step=True, sync_dist=True) + return loss def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): input_signal, input_length, target_signal, target_length = batch @@ -419,11 +437,11 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = if target_signal.ndim == 2: target_signal = target_signal.unsqueeze(1) + # Process input processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length) - # Prepare output - loss_value = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length) - output_dict = {f'{tag}_loss': loss_value} + # Calculate the loss + loss = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length) # Update metrics if hasattr(self, 'metrics') and tag in self.metrics: @@ -433,8 +451,7 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = # Log global step self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32), sync_dist=True) - - return output_dict + return {f'{tag}_loss': loss} @classmethod def list_available_models(cls) -> Optional[PretrainedModelInfo]: diff --git a/nemo/collections/asr/modules/__init__.py b/nemo/collections/asr/modules/__init__.py index ecd430b56e6cd..0265d9e306878 100644 --- a/nemo/collections/asr/modules/__init__.py +++ b/nemo/collections/asr/modules/__init__.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from nemo.collections.asr.modules.audio_modules import MaskBasedBeamformer, MaskEstimatorRNN, MaskReferenceChannel +from nemo.collections.asr.modules.audio_modules import ( + MaskBasedBeamformer, + MaskEstimatorFlexChannels, + MaskEstimatorRNN, + MaskReferenceChannel, +) from nemo.collections.asr.modules.audio_preprocessing import ( AudioToMelSpectrogramPreprocessor, AudioToMFCCPreprocessor, diff --git a/nemo/collections/asr/modules/audio_modules.py b/nemo/collections/asr/modules/audio_modules.py index e2218d2118cf2..89fac14e8214f 100644 --- a/nemo/collections/asr/modules/audio_modules.py +++ b/nemo/collections/asr/modules/audio_modules.py @@ -12,35 +12,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import torch +from nemo.collections.asr.losses.audio_losses import temporal_mean +from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like +from nemo.collections.asr.parts.submodules.multichannel_modules import ( + ChannelAttentionPool, + ChannelAveragePool, + ParametricMultichannelWienerFilter, + TransformAttendConcatenate, + TransformAverageConcatenate, +) from nemo.collections.asr.parts.utils.audio_utils import db2mag, wrap_to_pi from nemo.core.classes import NeuralModule, typecheck from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType from nemo.utils import logging from nemo.utils.decorators import experimental -try: - import torchaudio - - HAVE_TORCHAUDIO = True -except ModuleNotFoundError: - HAVE_TORCHAUDIO = False - - __all__ = [ 'MaskEstimatorRNN', + 'MaskEstimatorFlexChannels', 'MaskReferenceChannel', 'MaskBasedBeamformer', 'MaskBasedDereverbWPE', ] -@experimental class SpectrogramToMultichannelFeatures(NeuralModule): """Convert a complex-valued multi-channel 
spectrogram to multichannel features. @@ -50,32 +51,36 @@ class SpectrogramToMultichannelFeatures(NeuralModule): num_input_channels: Optional, provides the number of channels of the input signal. Used to infer the number of output channels. - magnitude_reduction: Reduction across channels. Default `None`, will calculate - magnitude of each channel. + mag_reduction: Reduction across channels. Default `None`, will calculate + magnitude of each channel. + mag_power: Optional, apply power on the magnitude. use_ipd: Use inter-channel phase difference (IPD). mag_normalization: Normalization for magnitude features ipd_normalization: Normalization for IPD features + eps: Small regularization constant. """ def __init__( self, num_subbands: int, num_input_channels: Optional[int] = None, - mag_reduction: Optional[str] = 'rms', + mag_reduction: Optional[str] = None, + mag_power: Optional[float] = None, use_ipd: bool = False, mag_normalization: Optional[str] = None, ipd_normalization: Optional[str] = None, + eps: float = 1e-8, ): super().__init__() self.mag_reduction = mag_reduction + self.mag_power = mag_power self.use_ipd = use_ipd - # TODO: normalization - if mag_normalization is not None: + if mag_normalization not in [None, 'mean', 'mean_var']: raise NotImplementedError(f'Unknown magnitude normalization {mag_normalization}') self.mag_normalization = mag_normalization - if ipd_normalization is not None: + if ipd_normalization not in [None, 'mean', 'mean_var']: raise NotImplementedError(f'Unknown ipd normalization {ipd_normalization}') self.ipd_normalization = ipd_normalization @@ -86,6 +91,19 @@ def __init__( self._num_features = num_subbands self._num_channels = num_input_channels if self.mag_reduction is None else 1 + self.eps = eps + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tnum_subbands: %d', num_subbands) + logging.debug('\tmag_reduction: %s', self.mag_reduction) + logging.debug('\tmag_power: %s', self.mag_power) + 
logging.debug('\tuse_ipd: %s', self.use_ipd) + logging.debug('\tmag_normalization: %s', self.mag_normalization) + logging.debug('\tipd_normalization: %s', self.ipd_normalization) + logging.debug('\teps: %f', self.eps) + logging.debug('\t_num_features: %s', self._num_features) + logging.debug('\t_num_channels: %s', self._num_channels) + @property def input_types(self) -> Dict[str, NeuralType]: """Returns definitions of module output ports. @@ -122,6 +140,98 @@ def num_channels(self) -> int: 'must be provided when constructing the object.' ) + @staticmethod + def get_mean_time_channel(input: torch.Tensor, input_length: Optional[torch.Tensor] = None) -> torch.Tensor: + """Calculate mean across time and channel dimensions. + + Args: + input: tensor with shape (B, C, F, T) + input_length: tensor with shape (B,) + + Returns: + Mean of `input` calculated across time and channel dimension + with shape (B, 1, F, 1) + """ + if input_length is None: + mean = torch.mean(input, dim=(-1, -3), keepdim=True) + else: + # temporal mean + mean = temporal_mean(input, input_length, keepdim=True) + # channel mean + mean = torch.mean(mean, dim=-3, keepdim=True) + + return mean + + @classmethod + def get_mean_std_time_channel( + cls, input: torch.Tensor, input_length: Optional[torch.Tensor] = None, eps: float = 1e-10 + ) -> torch.Tensor: + """Calculate mean and standard deviation across time and channel dimensions. + + Args: + input: tensor with shape (B, C, F, T) + input_length: tensor with shape (B,) + + Returns: + Mean and standard deviation of the `input` calculated across time and + channel dimension, each with shape (B, 1, F, 1). 
+ """ + if input_length is None: + std, mean = torch.std_mean(input, dim=(-1, -3), unbiased=False, keepdim=True) + else: + mean = cls.get_mean_time_channel(input, input_length) + std = (input - mean).pow(2) + # temporal mean + std = temporal_mean(std, input_length, keepdim=True) + # channel mean + std = torch.mean(std, dim=-3, keepdim=True) + # final value + std = torch.sqrt(std.clamp(eps)) + + return mean, std + + @typecheck( + input_types={ + 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "input_length": NeuralType(tuple('B'), LengthsType()), + }, + output_types={'output': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),}, + ) + def normalize_mean(self, input: torch.Tensor, input_length: torch.Tensor) -> torch.Tensor: + """Mean normalization for the input tensor. + + Args: + input: input tensor + input_length: valid length for each example + + Returns: + Mean normalized input. + """ + mean = self.get_mean_time_channel(input=input, input_length=input_length) + output = input - mean + return output + + @typecheck( + input_types={ + 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "input_length": NeuralType(tuple('B'), LengthsType()), + }, + output_types={'output': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),}, + ) + def normalize_mean_var(self, input: torch.Tensor, input_length: torch.Tensor) -> torch.Tensor: + """Mean and variance normalization for the input tensor. + + Args: + input: input tensor + input_length: valid length for each example + + Returns: + Mean and variance normalized input. 
+ """ + mean, std = self.get_mean_std_time_channel(input=input, input_length=input_length, eps=self.eps) + output = (input - mean) / std + return output + @typecheck() def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> torch.Tensor: """Convert input batch of C-channel spectrograms into @@ -148,20 +258,30 @@ def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> torch.Tens else: raise ValueError(f'Unexpected magnitude reduction {self.mag_reduction}') - if self.mag_normalization is not None: - mag = self.mag_normalization(mag) + if self.mag_power is not None: + mag = torch.pow(mag, self.mag_power) + + if self.mag_normalization == 'mean': + # normalize mean across channels and time steps + mag = self.normalize_mean(input=mag, input_length=input_length) + elif self.mag_normalization == 'mean_var': + mag = self.normalize_mean_var(input=mag, input_length=input_length) features = mag if self.use_ipd: - # Calculate IPD relative to average spec - spec_mean = torch.mean(input, axis=1, keepdim=True) + # Calculate IPD relative to the average spec + spec_mean = torch.mean(input, axis=1, keepdim=True) # channel average ipd = torch.angle(input) - torch.angle(spec_mean) # Modulo to [-pi, pi] ipd = wrap_to_pi(ipd) - if self.ipd_normalization is not None: - ipd = self.ipd_normalization(ipd) + if self.ipd_normalization == 'mean': + # normalize mean across channels and time steps + # mean across time + ipd = self.normalize_mean(input=ipd, input_length=input_length) + elif self.ipd_normalization == 'mean_var': + ipd = self.normalize_mean_var(input=ipd, input_length=input_length) # Concatenate to existing features features = torch.cat([features.expand(ipd.shape), ipd], axis=2) @@ -342,6 +462,245 @@ def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torc return masks, output_length +class MaskEstimatorFlexChannels(NeuralModule): + """Estimate `num_outputs` masks from the input spectrogram + using stacked channel-wise and temporal 
+ layers. + + This model uses interleaved channel blocks and temporal blocks, and + it can process an arbitrary number of input channels. + Default channel block is the transform-attend-concatenate layer. + Default temporal block is the Conformer encoder. + Reduction from multichannel signal to single-channel signal is performed + after `channel_reduction_position` blocks. Only temporal blocks are used afterwards. + After the sequence of blocks, the output mask is computed using an additional + output temporal layer and a nonlinearity. + + References: + - Yoshioka et al, VarArray: Array-Geometry-Agnostic Continuous Speech Separation, 2022 + - Jukić et al, Flexible multichannel speech enhancement for noise-robust frontend, 2023 + + Args: + num_outputs: Number of output masks. + num_subbands: Number of subbands on the input spectrogram. + num_blocks: Number of blocks in the model. + channel_reduction_position: Number of blocks after which the signal is reduced across channels (0 applies the reduction before the first block, -1 applies it after the last block). + channel_reduction_type: Reduction across channels: 'average' or 'attention' + channel_block_type: Block for channel processing: 'transform_average_concatenate' or 'transform_attend_concatenate' + temporal_block_type: Block for temporal processing: 'conformer_encoder' + temporal_block_num_layers: Number of layers for the temporal block + temporal_block_num_heads: Number of heads for the temporal block + temporal_block_dimension: The hidden size of the model + temporal_block_self_attention_model: Self attention model for the temporal block + temporal_block_att_context_size: Attention context size for the temporal block + mag_reduction: Channel-wise reduction for magnitude features + mag_power: Power to apply on magnitude features + use_ipd: Use inter-channel phase difference (IPD) features + mag_normalization: Normalize using mean ('mean') or mean and variance ('mean_var') + ipd_normalization: Normalize using mean ('mean') or mean and variance ('mean_var') + estimate_ref_channel: Estimate the output
reference channel automatically + """ + + def __init__( + self, + num_outputs: int, + num_subbands: int, + num_blocks: int, + channel_reduction_position: int = -1, # if 0, apply before layer 0, if -1 apply at the end + channel_reduction_type: str = 'attention', + channel_block_type: str = 'transform_attend_concatenate', + temporal_block_type: str = 'conformer_encoder', + temporal_block_num_layers: int = 5, + temporal_block_num_heads: int = 4, + temporal_block_dimension: int = 128, + temporal_block_self_attention_model: str = 'rel_pos', + temporal_block_att_context_size: Optional[List[int]] = None, + num_input_channels: Optional[int] = None, + mag_reduction: str = 'abs_mean', + mag_power: Optional[float] = None, + use_ipd: bool = True, + mag_normalization: Optional[str] = None, + ipd_normalization: Optional[str] = None, + estimate_ref_channel: Optional[bool] = False, + ): + super().__init__() + + self.features = SpectrogramToMultichannelFeatures( + num_subbands=num_subbands, + num_input_channels=num_input_channels, + mag_reduction=mag_reduction, + mag_power=mag_power, + use_ipd=use_ipd, + mag_normalization=mag_normalization, + ipd_normalization=ipd_normalization, + ) + self.num_blocks = num_blocks + logging.debug('Total number of blocks: %d', self.num_blocks) + + # Channel reduction + if channel_reduction_position == -1: + # Apply reduction after the last layer + channel_reduction_position = num_blocks + + if channel_reduction_position > num_blocks: + raise ValueError( + f'Channel reduction position {channel_reduction_position} exceeds the number of blocks {num_blocks}' + ) + self.channel_reduction_position = channel_reduction_position + logging.debug('Channel reduction will be applied after block %d', self.channel_reduction_position) + + # Prepare processing blocks + self.channel_blocks = torch.nn.ModuleList() + self.temporal_blocks = torch.nn.ModuleList() + + for n in range(num_blocks): + logging.debug('Prepare block %d', n) + + channel_in_features = 
self.features.num_features if n == 0 else temporal_block_dimension + temporal_in_features = ( + self.features.num_features if n == channel_reduction_position == 0 else temporal_block_dimension + ) + + if n < channel_reduction_position: + logging.debug('Setup channel block %s', channel_block_type) + if channel_block_type == 'transform_average_concatenate': + channel_block = TransformAverageConcatenate( + in_features=channel_in_features, out_features=temporal_block_dimension + ) + elif channel_block_type == 'transform_attend_concatenate': + channel_block = TransformAttendConcatenate( + in_features=channel_in_features, out_features=temporal_block_dimension + ) + else: + raise ValueError(f'Unknown channel layer type: {channel_block_type}') + self.channel_blocks.append(channel_block) + + logging.debug('Setup temporal block %s', temporal_block_type) + if temporal_block_type == 'conformer_encoder': + temporal_block = ConformerEncoder( + feat_in=temporal_in_features, + n_layers=temporal_block_num_layers, + d_model=temporal_block_dimension, + subsampling_factor=1, + self_attention_model=temporal_block_self_attention_model, + att_context_size=temporal_block_att_context_size, + n_heads=temporal_block_num_heads, + ) + else: + raise ValueError(f'Unknown temporal block {temporal_block_type}.') + + self.temporal_blocks.append(temporal_block) + + logging.debug('Setup channel reduction %s', channel_reduction_type) + if channel_reduction_type == 'average': + # Mean across channel dimension + self.channel_reduction = ChannelAveragePool() + elif channel_reduction_type == 'attention': + # Attention across channel dimension + self.channel_reduction = ChannelAttentionPool(in_features=channel_in_features) + else: + raise ValueError(f'Unknown channel reduction type: {channel_reduction_type}') + + logging.debug('Setup %d output layers', num_outputs) + self.output_layers = torch.nn.ModuleList( + [ + ConformerEncoder( + feat_in=temporal_block_dimension, + n_layers=1,
d_model=temporal_block_dimension, + feat_out=num_subbands, + subsampling_factor=1, + self_attention_model=temporal_block_self_attention_model, + att_context_size=temporal_block_att_context_size, + n_heads=temporal_block_num_heads, + ) + for _ in range(num_outputs) + ] + ) + + # Output nonlinearity + self.output_nonlinearity = torch.nn.Sigmoid() + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports. + """ + return { + "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "input_length": NeuralType(('B',), LengthsType()), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports. + """ + return { + "output": NeuralType(('B', 'C', 'D', 'T'), FloatType()), + "output_length": NeuralType(('B',), LengthsType()), + } + + @typecheck() + def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Estimate `num_outputs` masks from the input spectrogram. 
+ """ + # get input features from a complex-valued spectrogram, (B, C, F, T) + output, output_length = self.features(input=input, input_length=input_length) + + # batch and num channels + B, M = input.size(0), input.size(1) + + # process all blocks + for n in range(self.num_blocks): + if n < self.channel_reduction_position: + # apply multichannel block + output = self.channel_blocks[n](input=output) + # change to a single-stream format + F, T = output.size(-2), output.size(-1) + # (B, M, F, T) -> (B * M, F, T) + output = output.reshape(-1, F, T) + # adjust the lengths accordingly + output_length = output_length.repeat_interleave(M) + + elif n == self.channel_reduction_position: + # apply channel reduction + # (B, M, F, T) -> (B, F, T) + output = self.channel_reduction(input=output) + + # apply temporal model on each channel independently + with typecheck.disable_checks(): + # output is AcousticEncodedRepresentation, conformer encoder requires SpectrogramType + output, output_length = self.temporal_blocks[n](audio_signal=output, length=output_length) + + # if channel reduction has not been applied yet, go back to multichannel layout + if n < self.channel_reduction_position: + # back to multi-channel format with possibly a different number of features + T = output.size(-1) + # (B * M, F, T) -> (B, M, F, T) + output = output.reshape(B, M, -1, T) + # convert lengths from single-stream format to original multichannel + output_length = output_length[0:-1:M] + + if self.channel_reduction_position == self.num_blocks: + # apply channel reduction after the last layer + # (B, M, F, T) -> (B, F, T) + output = self.channel_reduction(input=output) + + # final mask for each output + masks = [] + for output_layer in self.output_layers: + # calculate mask + with typecheck.disable_checks(): + # output is AcousticEncodedRepresentation, conformer encoder requires SpectrogramType + mask, mask_length = output_layer(audio_signal=output, length=output_length) + mask = 
self.output_nonlinearity(mask) + # append to all masks + masks.append(mask) + + # stack masks along channel dimensions + masks = torch.stack(masks, dim=1) + + return masks, mask_length + + class MaskReferenceChannel(NeuralModule): """A simple mask processor which applies mask on ref_channel of the input signal. @@ -359,6 +718,11 @@ def __init__(self, ref_channel: int = 0, mask_min_db: float = -200, mask_max_db: self.mask_min = db2mag(mask_min_db) self.mask_max = db2mag(mask_max_db) + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tref_channel: %d', self.ref_channel) + logging.debug('\tmask_min: %f', self.mask_min) + logging.debug('\tmask_max: %f', self.mask_max) + @property def input_types(self) -> Dict[str, NeuralType]: """Returns definitions of module output ports. @@ -380,7 +744,7 @@ def output_types(self) -> Dict[str, NeuralType]: @typecheck() def forward( - self, input: torch.Tensor, input_length: torch.Tensor, mask: torch.Tensor + self, input: torch.Tensor, input_length: torch.Tensor, mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Apply mask on `ref_channel` of the input signal. This can be used to generate multi-channel output. @@ -407,36 +771,86 @@ class MaskBasedBeamformer(NeuralModule): Args: filter_type: string denoting the type of the filter. Defaults to `mvdr` - ref_channel: reference channel for processing + filter_beta: Parameter of the parameteric multichannel Wiener filter + filter_rank: Parameter of the parametric multichannel Wiener filter + filter_postfilter: Optional, postprocessing of the filter + ref_channel: Optional, reference channel. If None, it will be estimated automatically + ref_hard: If true, hard (one-hot) reference. 
If false, a soft reference + ref_hard_use_grad: If true, use straight-through gradient when using the hard reference + ref_subband_weighting: If true, use subband weighting when estimating reference channel + num_subbands: Optional, used to determine the parameter size for reference estimation mask_min_db: Threshold mask to a minimal value before applying it, defaults to -200dB mask_max_db: Threshold mask to a maximal value before applying it, defaults to 0dB + diag_reg: Optional, diagonal regularization for the multichannel filter + eps: Small regularization constant to avoid division by zero """ def __init__( self, filter_type: str = 'mvdr_souden', - ref_channel: int = 0, + filter_beta: float = 0.0, + filter_rank: str = 'one', + filter_postfilter: Optional[str] = None, + ref_channel: Optional[int] = 0, + ref_hard: bool = True, + ref_hard_use_grad: bool = False, + ref_subband_weighting: bool = False, + num_subbands: Optional[int] = None, mask_min_db: float = -200, mask_max_db: float = 0, + postmask_min_db: float = 0, + postmask_max_db: float = 0, + diag_reg: Optional[float] = 1e-6, + eps: float = 1e-8, ): - if not HAVE_TORCHAUDIO: - logging.error('Could not import torchaudio. 
Some features might not work.') - - raise ModuleNotFoundError( - "torchaudio is not installed but is necessary to instantiate a {self.__class__.__name__}" - ) - super().__init__() - self.ref_channel = ref_channel - self.filter_type = filter_type - if self.filter_type == 'mvdr_souden': - self.psd = torchaudio.transforms.PSD() - self.filter = torchaudio.transforms.SoudenMVDR() - else: + if filter_type not in ['pmwf', 'mvdr_souden']: raise ValueError(f'Unknown filter type {filter_type}') + + self.filter_type = filter_type + if self.filter_type == 'mvdr_souden' and filter_beta != 0: + logging.warning( + 'Using filter type %s: beta will be automatically set to zero (current beta %f) and rank to one (current rank %s).', + self.filter_type, + filter_beta, + filter_rank, + ) + filter_beta = 0.0 + filter_rank = 'one' + # Prepare filter + self.filter = ParametricMultichannelWienerFilter( + beta=filter_beta, + rank=filter_rank, + postfilter=filter_postfilter, + ref_channel=ref_channel, + ref_hard=ref_hard, + ref_hard_use_grad=ref_hard_use_grad, + ref_subband_weighting=ref_subband_weighting, + num_subbands=num_subbands, + diag_reg=diag_reg, + eps=eps, + ) # Mask thresholding + if mask_min_db >= mask_max_db: + raise ValueError( + f'Lower bound for the mask {mask_min_db}dB must be smaller than the upper bound {mask_max_db}dB' + ) self.mask_min = db2mag(mask_min_db) self.mask_max = db2mag(mask_max_db) + # Postmask thresholding + if postmask_min_db > postmask_max_db: + raise ValueError( + f'Lower bound for the postmask {postmask_min_db}dB must be smaller or equal to the upper bound {postmask_max_db}dB' + ) + self.postmask_min = db2mag(postmask_min_db) + self.postmask_max = db2mag(postmask_max_db) + + logging.debug('Initialized %s', self.__class__.__name__) + logging.debug('\tfilter_type: %s', self.filter_type) + logging.debug('\tmask_min: %e', self.mask_min) + logging.debug('\tmask_max: %e', self.mask_max) + logging.debug('\tpostmask_min: %e', self.postmask_min) + 
logging.debug('\tpostmask_max: %e', self.postmask_max) @property def input_types(self) -> Dict[str, NeuralType]: @@ -444,8 +858,9 @@ def input_types(self) -> Dict[str, NeuralType]: """ return { "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - "input_length": NeuralType(('B',), LengthsType()), "mask": NeuralType(('B', 'C', 'D', 'T'), FloatType()), + "mask_undesired": NeuralType(('B', 'C', 'D', 'T'), FloatType(), optional=True), + "input_length": NeuralType(('B',), LengthsType(), optional=True), } @property @@ -454,45 +869,79 @@ def output_types(self) -> Dict[str, NeuralType]: """ return { "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), - "output_length": NeuralType(('B',), LengthsType()), + "output_length": NeuralType(('B',), LengthsType(), optional=True), } @typecheck() - def forward(self, input: torch.Tensor, input_length: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + def forward( + self, + input: torch.Tensor, + mask: torch.Tensor, + mask_undesired: Optional[torch.Tensor] = None, + input_length: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """Apply a mask-based beamformer to the input spectrogram. This can be used to generate multi-channel output. - If `mask` has `M` channels, the output will have `M` channels as well. + If `mask` has multiple channels, a multichannel filter is created for each mask, + and the output is concatenation of individual outputs along the channel dimension. + The total number of outputs is `num_masks * M`, where `M` is the number of channels + at the filter output. 
Args: input: Input signal complex-valued spectrogram, shape (B, C, F, N) + mask: Mask for M output signals, shape (B, num_masks, F, N) input_length: Length of valid entries along the time dimension, shape (B,) - mask: Mask for M output signals, shape (B, M, F, N) Returns: - M-channel output signal complex-valued spectrogram, shape (B, M, F, N) + Multichannel output signal complex-valued spectrogram, shape (B, num_masks * M, F, N) """ - # Apply threshold on the mask - mask = torch.clamp(mask, min=self.mask_min, max=self.mask_max) # Length mask - length_mask: torch.Tensor = make_seq_mask_like( - lengths=input_length, like=mask[:, 0, ...], time_dim=-1, valid_ones=False - ) - # Use each mask to generate an output at ref_channel - output = [] - for m in range(mask.size(1)): - # Prepare mask for the desired and the undesired signal - mask_desired = mask[:, m, ...].masked_fill(length_mask, 0.0) - mask_undesired = (1 - mask_desired).masked_fill(length_mask, 0.0) - # Calculate PSDs - psd_desired = self.psd(input, mask_desired) - psd_undesired = self.psd(input, mask_undesired) + if input_length is not None: + length_mask: torch.Tensor = make_seq_mask_like( + lengths=input_length, like=mask[:, 0, ...], time_dim=-1, valid_ones=False + ) + + # Use each mask to generate an output + output, num_masks = [], mask.size(1) + for m in range(num_masks): + # Desired signal mask + mask_d = mask[:, m, ...] + # Undesired signal mask + if mask_undesired is not None: + mask_u = mask_undesired[:, m, ...] 
+ elif num_masks == 1: + # If a single mask is estimated, use the complement + mask_u = 1 - mask_d + else: + # Use sum of all other sources + mask_u = torch.sum(mask, dim=1) - mask_d + + # Threshold masks + mask_d = torch.clamp(mask_d, min=self.mask_min, max=self.mask_max) + mask_u = torch.clamp(mask_u, min=self.mask_min, max=self.mask_max) + + if input_length is not None: + mask_d = mask_d.masked_fill(length_mask, 0.0) + mask_u = mask_u.masked_fill(length_mask, 0.0) + # Apply filter - output_m = self.filter(input, psd_desired, psd_undesired, reference_channel=self.ref_channel) - output_m = output_m.masked_fill(length_mask, 0.0) - # Save the current output (B, F, N) + output_m = self.filter(input=input, mask_s=mask_d, mask_n=mask_u) + + # Optional: apply a postmask with min and max thresholds + if self.postmask_min < self.postmask_max: + postmask_m = torch.clamp(mask[:, m, ...], min=self.postmask_min, max=self.postmask_max) + output_m = output_m * postmask_m.unsqueeze(1) + + # Save the current output (B, M, F, T) output.append(output_m) - output = torch.stack(output, axis=1) + # Combine outputs along the channel dimension + # Each output is (B, M, F, T) + output = torch.concatenate(output, axis=1) + + # Apply masking + if input_length is not None: + output = output.masked_fill(length_mask[:, None, ...], 0.0) return output, input_length @@ -516,15 +965,19 @@ class estimates a multiple-input multiple-output prediction filter - Jukić et al, Group sparsity for MIMO speech dereverberation, 2015 """ - def __init__( - self, filter_length: int, prediction_delay: int, diag_reg: Optional[float] = 1e-8, eps: float = 1e-10 - ): + def __init__(self, filter_length: int, prediction_delay: int, diag_reg: Optional[float] = 1e-6, eps: float = 1e-8): super().__init__() self.filter_length = filter_length self.prediction_delay = prediction_delay self.diag_reg = diag_reg self.eps = eps + logging.debug('Initialized %s', self.__class__.__name__) + logging.debug('\tfilter_length: %d', 
self.filter_length) + logging.debug('\tprediction_delay: %d', self.prediction_delay) + logging.debug('\tdiag_reg: %g', self.diag_reg) + logging.debug('\teps: %g', self.eps) + @property def input_types(self) -> Dict[str, NeuralType]: """Returns definitions of module output ports. @@ -562,7 +1015,7 @@ def forward( as the input length. """ # Temporal weighting: average power over channels, shape (B, F, N) - weight = torch.mean(power, dim=1) + weight = torch.mean(power, dim=-3) # Use inverse power as the weight weight = 1 / (weight + self.eps) @@ -799,6 +1252,7 @@ class MaskBasedDereverbWPE(NeuralModule): mask_max_db: Threshold mask to a minimal value before applying it, defaults to 0dB diag_reg: Diagonal regularization for WPE eps: Small regularization constant + dtype: Data type for internal computations References: - Kinoshita et al, Neural network-based spectrum estimation for online WPE dereverberation, 2017 @@ -812,8 +1266,9 @@ def __init__( num_iterations: int = 1, mask_min_db: float = -200, mask_max_db: float = 0, - diag_reg: Optional[float] = 1e-8, - eps: float = 1e-10, + diag_reg: Optional[float] = 1e-6, + eps: float = 1e-8, + dtype: torch.dtype = torch.cdouble, ): super().__init__() # Filter setup @@ -824,6 +1279,16 @@ def __init__( # Mask thresholding self.mask_min = db2mag(mask_min_db) self.mask_max = db2mag(mask_max_db) + # Internal calculations + if dtype not in [torch.cfloat, torch.cdouble]: + raise ValueError(f'Unsupported dtype {dtype}, expecting torch.cfloat or torch.cdouble') + self.dtype = dtype + + logging.debug('Initialized %s', self.__class__.__name__) + logging.debug('\tnum_iterations: %s', self.num_iterations) + logging.debug('\tmask_min: %g', self.mask_min) + logging.debug('\tmask_max: %g', self.mask_max) + logging.debug('\tdtype: %s', self.dtype) @property def input_types(self) -> Dict[str, NeuralType]: @@ -851,19 +1316,21 @@ def forward( """Given an input signal `input`, apply the WPE dereverberation algoritm. 
Args: - input: C-channel complex-valued spectrogram, shape (B, C, F, N) + input: C-channel complex-valued spectrogram, shape (B, C, F, T) input_length: Optional length for each signal in the batch, shape (B,) - mask: Optional mask, shape (B, 1, F, N) or (B, C, F, N) + mask: Optional mask, shape (B, 1, F, N) or (B, C, F, T) Returns: Processed tensor with the same number of channels as the input, - shape (B, C, F, N). + shape (B, C, F, T). """ io_dtype = input.dtype with torch.cuda.amp.autocast(enabled=False): + output = input.to(dtype=self.dtype) - output = input.cdouble() + if not output.is_complex(): + raise RuntimeError(f'Expecting complex input, got {output.dtype}') for i in range(self.num_iterations): magnitude = torch.abs(output) @@ -891,7 +1358,7 @@ class MixtureConsistencyProjection(NeuralModule): eps: Small positive value for regularization Reference: - Wisdom et al., Differentiable consistency constraints for improved deep speech enhancement, 2018 + Wisdom et al, Differentiable consistency constraints for improved deep speech enhancement, 2018 """ def __init__(self, weighting: Optional[str] = None, eps: float = 1e-8): diff --git a/nemo/collections/asr/parts/submodules/multichannel_modules.py b/nemo/collections/asr/parts/submodules/multichannel_modules.py new file mode 100644 index 0000000000000..2d7d251d0cecb --- /dev/null +++ b/nemo/collections/asr/parts/submodules/multichannel_modules.py @@ -0,0 +1,778 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from typing import Callable, Optional + +import torch + +from nemo.collections.asr.parts.submodules.multi_head_attention import MultiHeadAttention +from nemo.core.classes import NeuralModule, typecheck +from nemo.core.neural_types import AudioSignal, FloatType, NeuralType, SpectrogramType +from nemo.utils import logging + +try: + import torchaudio + + HAVE_TORCHAUDIO = True +except ModuleNotFoundError: + HAVE_TORCHAUDIO = False + + +class ChannelAugment(NeuralModule): + """Randomly permute and selects a subset of channels. + + Args: + permute_channels (bool): Apply a random permutation of channels. + num_channels_min (int): Minimum number of channels to select. + num_channels_max (int): Max number of channels to select. + rng: Optional, random generator. + """ + + def __init__( + self, + permute_channels: bool = True, + num_channels_min: int = 1, + num_channels_max: Optional[int] = None, + rng: Optional[Callable] = None, + ): + super().__init__() + + self._rng = random.Random() if rng is None else rng + self.permute_channels = permute_channels + self.num_channels_min = num_channels_min + self.num_channels_max = num_channels_max + + if num_channels_max is not None and num_channels_min > num_channels_max: + raise ValueError( + f'Min number of channels {num_channels_min} cannot be greater than max number of channels {num_channels_max}' + ) + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tpermute_channels: %s', self.permute_channels) + logging.debug('\tnum_channels_min: %s', self.num_channels_min) + logging.debug('\tnum_channels_max: %s', self.num_channels_max) + + @property + def input_types(self): + """Returns definitions of module input types + """ + return { + 'input': NeuralType(('B', 'C', 'T'), AudioSignal()), + } + + @property + def output_types(self): + """Returns definitions of module output types + """ 
+ return { + 'output': NeuralType(('B', 'C', 'T'), AudioSignal()), + } + + @typecheck() + @torch.no_grad() + def forward(self, input: torch.Tensor) -> torch.Tensor: + # Expecting (B, C, T) + assert input.ndim == 3, f'Expecting input with shape (B, C, T)' + num_channels_in = input.size(1) + + if num_channels_in < self.num_channels_min: + raise RuntimeError( + f'Number of input channels ({num_channels_in}) is smaller than the min number of output channels ({self.num_channels_min})' + ) + + num_channels_max = num_channels_in if self.num_channels_max is None else self.num_channels_max + num_channels_out = self._rng.randint(self.num_channels_min, num_channels_max) + + channels = list(range(num_channels_in)) + + if self.permute_channels: + self._rng.shuffle(channels) + + channels = channels[:num_channels_out] + + return input[:, channels, :] + + +class TransformAverageConcatenate(NeuralModule): + """Apply transform-average-concatenate across channels. + We're using a version from [2]. + + Args: + in_features: Number of input features + out_features: Number of output features + + References: + [1] Luo et al, End-to-end Microphone Permutation and Number Invariant Multi-channel Speech Separation, 2019 + [2] Yoshioka et al, VarArray: Array-Geometry-Agnostic Continuous Speech Separation, 2022 + """ + + def __init__(self, in_features: int, out_features: Optional[int] = None): + super().__init__() + + if out_features is None: + out_features = in_features + + # Parametrize with the total number of features (needs to be divisible by two due to stacking) + if out_features % 2 != 0: + raise ValueError(f'Number of output features should be divisible by two, currently set to {out_features}') + + self.transform_channel = torch.nn.Sequential( + torch.nn.Linear(in_features, out_features // 2, bias=False), torch.nn.ReLU() + ) + self.transform_average = torch.nn.Sequential( + torch.nn.Linear(in_features, out_features // 2, bias=False), torch.nn.ReLU() + ) + + logging.debug('Initialized %s 
with', self.__class__.__name__) + logging.debug('\tin_features: %d', in_features) + logging.debug('\tout_features: %d', out_features) + + @property + def input_types(self): + """Returns definitions of module input types + """ + return { + 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + } + + @property + def output_types(self): + """Returns definitions of module output types + """ + return { + 'output': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + } + + @typecheck() + def forward(self, input: torch.Tensor) -> torch.Tensor: + """ + Args: + input: shape (B, M, in_features, T) + + Returns: + Output tensor with shape shape (B, M, out_features, T) + """ + B, M, F, T = input.shape + + # (B, M, F, T) -> (B, T, M, F) + input = input.permute(0, 3, 1, 2) + + # transform and average across channels + average = self.transform_average(input) + average = torch.mean(average, dim=-2, keepdim=True) + # view with the number of channels expanded to M + average = average.expand(-1, -1, M, -1) + + # transform each channel + transform = self.transform_channel(input) + + # concatenate along feature dimension + output = torch.cat([transform, average], dim=-1) + + # Return to the original layout + # (B, T, M, F) -> (B, M, F, T) + output = output.permute(0, 2, 3, 1) + + return output + + +class TransformAttendConcatenate(NeuralModule): + """Apply transform-attend-concatenate across channels. + The output is a concatenation of transformed channel and MHA + over channels. 
+ + Args: + in_features: Number of input features + out_features: Number of output features + n_head: Number of heads for the MHA module + dropout_rate: Dropout rate for the MHA module + + References: + - Jukić et al, Flexible multichannel speech enhancement for noise-robust frontend, 2023 + """ + + def __init__(self, in_features: int, out_features: Optional[int] = None, n_head: int = 4, dropout_rate: float = 0): + super().__init__() + + if out_features is None: + out_features = in_features + + # Parametrize with the total number of features (needs to be divisible by two due to stacking) + if out_features % 2 != 0: + raise ValueError(f'Number of output features should be divisible by two, currently set to {out_features}') + + self.transform_channel = torch.nn.Sequential( + torch.nn.Linear(in_features, out_features // 2, bias=False), torch.nn.ReLU() + ) + self.transform_attend = torch.nn.Sequential( + torch.nn.Linear(in_features, out_features // 2, bias=False), torch.nn.ReLU() + ) + self.attention = MultiHeadAttention(n_head=n_head, n_feat=out_features // 2, dropout_rate=dropout_rate) + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tin_features: %d', in_features) + logging.debug('\tout_features: %d', out_features) + logging.debug('\tn_head: %d', n_head) + logging.debug('\tdropout_rate: %f', dropout_rate) + + @property + def input_types(self): + """Returns definitions of module input types + """ + return { + 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + } + + @property + def output_types(self): + """Returns definitions of module output types + """ + return { + 'output': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + } + + @typecheck() + def forward(self, input: torch.Tensor) -> torch.Tensor: + """ + Args: + input: shape (B, M, in_features, T) + + Returns: + Output tensor with shape shape (B, M, out_features, T) + """ + B, M, F, T = input.shape + + # (B, M, F, T) -> (B, T, M, F) + input = input.permute(0, 
3, 1, 2) + input = input.reshape(B * T, M, F) + + # transform each channel + transform = self.transform_channel(input) + + # attend + attend = self.transform_attend(input) + # attention across channels + attend = self.attention(query=attend, key=attend, value=attend, mask=None) + + # concatenate along feature dimension + output = torch.cat([transform, attend], dim=-1) + + # return to the original layout + output = output.view(B, T, M, -1) + + # (B, T, M, num_features) -> (B, M, num_features, T) + output = output.permute(0, 2, 3, 1) + + return output + + +class ChannelAveragePool(NeuralModule): + """Apply average pooling across channels. + """ + + def __init__(self): + super().__init__() + logging.debug('Initialized %s', self.__class__.__name__) + + @property + def input_types(self): + """Returns definitions of module input types + """ + return { + 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + } + + @property + def output_types(self): + """Returns definitions of module output types + """ + return { + 'output': NeuralType(('B', 'D', 'T'), SpectrogramType()), + } + + @typecheck() + def forward(self, input: torch.Tensor) -> torch.Tensor: + """ + Args: + input: shape (B, M, F, T) + + Returns: + Output tensor with shape shape (B, F, T) + """ + return torch.mean(input, dim=-3) + + +class ChannelAttentionPool(NeuralModule): + """Use attention pooling to aggregate information across channels. + First apply MHA across channels and then apply averaging. 
+ + Args: + in_features: Number of input features + out_features: Number of output features + n_head: Number of heads for the MHA module + dropout_rate: Dropout rate for the MHA module + + References: + - Wang et al, Neural speech separation using sparially distributed microphones, 2020 + - Jukić et al, Flexible multichannel speech enhancement for noise-robust frontend, 2023 + """ + + def __init__(self, in_features: int, n_head: int = 1, dropout_rate: float = 0): + super().__init__() + self.in_features = in_features + self.attention = MultiHeadAttention(n_head=n_head, n_feat=in_features, dropout_rate=dropout_rate) + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tin_features: %d', in_features) + logging.debug('\tnum_heads: %d', n_head) + logging.debug('\tdropout_rate: %d', dropout_rate) + + @property + def input_types(self): + """Returns definitions of module input types + """ + return { + 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + } + + @property + def output_types(self): + """Returns definitions of module output types + """ + return { + 'output': NeuralType(('B', 'D', 'T'), SpectrogramType()), + } + + @typecheck() + def forward(self, input: torch.Tensor) -> torch.Tensor: + """ + Args: + input: shape (B, M, F, T) + + Returns: + Output tensor with shape shape (B, F, T) + """ + B, M, F, T = input.shape + + # (B, M, F, T) -> (B, T, M, F) + input = input.permute(0, 3, 1, 2) + input = input.reshape(B * T, M, F) + + # attention across channels + output = self.attention(query=input, key=input, value=input, mask=None) + + # return to the original layout + output = output.view(B, T, M, -1) + + # (B, T, M, num_features) -> (B, M, out_features, T) + output = output.permute(0, 2, 3, 1) + + # average across channels + output = torch.mean(output, axis=-3) + + return output + + +class ParametricMultichannelWienerFilter(NeuralModule): + """Parametric multichannel Wiener filter, with an adjustable + tradeoff between noise 
reduction and speech distortion. + It supports automatic reference channel selection based + on the estimated output SNR. + + Args: + beta: Parameter of the parameteric filter, tradeoff between noise reduction + and speech distortion (0: MVDR, 1: MWF). + rank: Rank assumption for the speech covariance matrix. + postfilter: Optional postfilter. If None, no postfilter is applied. + ref_channel: Optional, reference channel. If None, it will be estimated automatically. + ref_hard: If true, estimate a hard (one-hot) reference. If false, a soft reference. + ref_hard_use_grad: If true, use straight-through gradient when using the hard reference + ref_subband_weighting: If true, use subband weighting when estimating reference channel + num_subbands: Optional, used to determine the parameter size for reference estimation + diag_reg: Optional, diagonal regularization for the multichannel filter + eps: Small regularization constant to avoid division by zero + + References: + - Souden et al, On Optimal Frequency-Domain Multichannel Linear Filtering for Noise Reduction, 2010 + """ + + def __init__( + self, + beta: float = 1.0, + rank: str = 'one', + postfilter: Optional[str] = None, + ref_channel: Optional[int] = None, + ref_hard: bool = True, + ref_hard_use_grad: bool = True, + ref_subband_weighting: bool = False, + num_subbands: Optional[int] = None, + diag_reg: Optional[float] = 1e-6, + eps: float = 1e-8, + ): + if not HAVE_TORCHAUDIO: + logging.error('Could not import torchaudio. 
Some features might not work.') + + raise ModuleNotFoundError( + "torchaudio is not installed but is necessary to instantiate a {self.__class__.__name__}" + ) + + super().__init__() + + # Parametric filter + # 0=MVDR, 1=MWF + self.beta = beta + + # Rank + # Assumed rank for the signal covariance matrix (psd_s) + self.rank = rank + + if self.rank == 'full' and self.beta == 0: + raise ValueError(f'Rank {self.rank} is not compatible with beta {self.beta}.') + + # Postfilter, applied on the output of the multichannel filter + if postfilter not in [None, 'ban']: + raise ValueError(f'Postfilter {postfilter} is not supported.') + self.postfilter = postfilter + + # Regularization + if diag_reg is not None and diag_reg < 0: + raise ValueError(f'Diagonal regularization {diag_reg} must be positive.') + self.diag_reg = diag_reg + + if eps <= 0: + raise ValueError(f'Epsilon {eps} must be positive.') + self.eps = eps + + # PSD estimator + self.psd = torchaudio.transforms.PSD() + + # Reference channel + self.ref_channel = ref_channel + if self.ref_channel == 'max_snr': + self.ref_estimator = ReferenceChannelEstimatorSNR( + hard=ref_hard, + hard_use_grad=ref_hard_use_grad, + subband_weighting=ref_subband_weighting, + num_subbands=num_subbands, + eps=eps, + ) + else: + self.ref_estimator = None + # Flag to determine if the filter is MISO or MIMO + self.is_mimo = self.ref_channel is None + + logging.debug('Initialized %s', self.__class__.__name__) + logging.debug('\tbeta: %f', self.beta) + logging.debug('\trank: %s', self.rank) + logging.debug('\tpostfilter: %s', self.postfilter) + logging.debug('\tdiag_reg: %g', self.diag_reg) + logging.debug('\teps: %g', self.eps) + logging.debug('\tref_channel: %s', self.ref_channel) + logging.debug('\tis_mimo: %s', self.is_mimo) + + @staticmethod + def trace(x: torch.Tensor, keepdim: bool = False) -> torch.Tensor: + """Calculate trace of matrix slices over the last + two dimensions in the input tensor. 
+ + Args: + x: tensor, shape (..., C, C) + + Returns: + Trace for each (C, C) matrix. shape (...) + """ + trace = torch.diagonal(x, dim1=-2, dim2=-1).sum(-1) + if keepdim: + trace = trace.unsqueeze(-1).unsqueeze(-1) + return trace + + def apply_diag_reg(self, psd: torch.Tensor) -> torch.Tensor: + """Apply diagonal regularization on psd. + + Args: + psd: tensor, shape (..., C, C) + + Returns: + Tensor, same shape as input. + """ + # Regularization: diag_reg * trace(psd) + eps + diag_reg = self.diag_reg * self.trace(psd).real + self.eps + + # Apply regularization + psd = psd + torch.diag_embed(diag_reg.unsqueeze(-1) * torch.ones(psd.shape[-1], device=psd.device)) + + return psd + + def apply_filter(self, input: torch.Tensor, filter: torch.Tensor) -> torch.Tensor: + """Apply the MIMO filter on the input. + + Args: + input: batch with C input channels, shape (B, C, F, T) + filter: batch of C-input, M-output filters, shape (B, F, C, M) + + Returns: + M-channel filter output, shape (B, M, F, T) + """ + if not filter.is_complex(): + raise TypeError(f'Expecting complex-valued filter, found {filter.dtype}') + + if not input.is_complex(): + raise TypeError(f'Expecting complex-valued input, found {input.dtype}') + + if filter.ndim != 4 or filter.size(-2) != input.size(-3) or filter.size(-3) != input.size(-2): + raise ValueError(f'Filter shape {filter.shape}, not compatible with input shape {input.shape}') + + output = torch.einsum('bfcm,bcft->bmft', filter.conj(), input) + + return output + + def apply_ban(self, input: torch.Tensor, filter: torch.Tensor, psd_n: torch.Tensor) -> torch.Tensor: + """Apply blind analytic normalization postfilter. Note that this normalization has been + derived for the GEV beamformer in [1]. More specifically, the BAN postfilter aims to scale GEV + to satisfy the distortionless constraint and the final analytical expression is derived using + an assumption on the norm of the transfer function. + However, this may still be useful in some instances. 
+ + Args: + input: batch with M output channels (B, M, F, T) + filter: batch of C-input, M-output filters, shape (B, F, C, M) + psd_n: batch of noise PSDs, shape (B, F, C, C) + + Returns: + Filtered input, shape (B, M, F, T) + + References: + - Warsitz and Haeb-Umbach, Blind Acoustic Beamforming Based on Generalized Eigenvalue Decomposition, 2007 + """ + # number of input channels, used to normalize the numerator + num_inputs = filter.size(-2) + numerator = torch.einsum('bfcm,bfci,bfij,bfjm->bmf', filter.conj(), psd_n, psd_n, filter) + numerator = torch.sqrt(numerator.abs() / num_inputs) + + denominator = torch.einsum('bfcm,bfci,bfim->bmf', filter.conj(), psd_n, filter) + denominator = denominator.abs() + + # Scalar filter per output channel, frequency and batch + # shape (B, M, F) + ban = numerator / (denominator + self.eps) + + input = ban[..., None] * input + + return input + + @property + def input_types(self): + """Returns definitions of module input types + """ + return { + 'input': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + 'mask_s': NeuralType(('B', 'D', 'T'), FloatType()), + 'mask_n': NeuralType(('B', 'D', 'T'), FloatType()), + } + + @property + def output_types(self): + """Returns definitions of module output types + """ + return { + 'output': NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + } + + @typecheck() + def forward(self, input: torch.Tensor, mask_s: torch.Tensor, mask_n: torch.Tensor) -> torch.Tensor: + """Return processed signal. + The output has either one channel (M=1) if a ref_channel is selected, + or the same number of channels as the input (M=C) if ref_channel is None.
+ + Args: + input: Input signal, complex tensor with shape (B, C, F, T) + mask_s: Mask for the desired signal, shape (B, F, T) + mask_n: Mask for the undesired noise, shape (B, F, T) + + Returns: + Processed signal, shape (B, M, F, T) + """ + iodtype = input.dtype + + with torch.cuda.amp.autocast(enabled=False): + # Convert to double + input = input.cdouble() + mask_s = mask_s.double() + mask_n = mask_n.double() + + # Calculate signal statistics + psd_s = self.psd(input, mask_s) + psd_n = self.psd(input, mask_n) + + if self.rank == 'one': + # Calculate filter W using (18) in [1] + # Diagonal regularization + if self.diag_reg: + psd_n = self.apply_diag_reg(psd_n) + + # MIMO filter + # (B, F, C, C) + W = torch.linalg.solve(psd_n, psd_s) + lam = self.trace(W, keepdim=True).real + W = W / (self.beta + lam + self.eps) + elif self.rank == 'full': + # Calculate filter W using (15) in [1] + psd_sn = psd_s + self.beta * psd_n + + if self.diag_reg: + psd_sn = self.apply_diag_reg(psd_sn) + + # MIMO filter + # (B, F, C, C) + W = torch.linalg.solve(psd_sn, psd_s) + else: + raise RuntimeError(f'Unexpected rank {self.rank}') + + if torch.jit.isinstance(self.ref_channel, int): + # Fixed ref channel + # (B, F, C, 1) + W = W[..., self.ref_channel].unsqueeze(-1) + elif self.ref_estimator is not None: + # Estimate ref channel tensor (one-hot or soft across C) + # (B, C) + ref_channel_tensor = self.ref_estimator(W=W, psd_s=psd_s, psd_n=psd_n).to(W.dtype) + # Weighting across channels + # (B, F, C, 1) + W = torch.sum(W * ref_channel_tensor[:, None, None, :], dim=-1, keepdim=True) + + output = self.apply_filter(input=input, filter=W) + + # Optional: postfilter + if self.postfilter == 'ban': + output = self.apply_ban(input=output, filter=W, psd_n=psd_n) + + return output.to(iodtype) + + +class ReferenceChannelEstimatorSNR(NeuralModule): + """Estimate a reference channel by selecting the reference + that maximizes the output SNR. It returns one-hot encoded + vector or a soft reference. 
+ + A straight-through estimator is used for gradient when using + hard reference. + + Args: + hard: If true, use hard estimate of ref channel. + If false, use a soft estimate across channels. + hard_use_grad: Use straight-through estimator for + the gradient. + subband_weighting: If true, use subband weighting when + adding across subband SNRs. If false, use average + across subbands. + + References: + Boeddeker et al, Front-End Processing for the CHiME-5 Dinner Party Scenario, 2018 + """ + + def __init__( + self, + hard: bool = True, + hard_use_grad: bool = True, + subband_weighting: bool = False, + num_subbands: Optional[int] = None, + eps: float = 1e-8, + ): + super().__init__() + + self.hard = hard + self.hard_use_grad = hard_use_grad + self.subband_weighting = subband_weighting + self.eps = eps + + if subband_weighting and num_subbands is None: + raise ValueError(f'Number of subbands must be provided when using subband_weighting={subband_weighting}.') + # Subband weighting + self.weight_s = torch.nn.Parameter(torch.ones(num_subbands)) if subband_weighting else None + self.weight_n = torch.nn.Parameter(torch.ones(num_subbands)) if subband_weighting else None + + logging.debug('Initialized %s', self.__class__.__name__) + logging.debug('\thard: %d', self.hard) + logging.debug('\thard_use_grad: %d', self.hard_use_grad) + logging.debug('\tsubband_weighting: %d', self.subband_weighting) + logging.debug('\tnum_subbands: %s', num_subbands) + logging.debug('\teps: %e', self.eps) + + @property + def input_types(self): + """Returns definitions of module input types + """ + return { + 'W': NeuralType(('B', 'D', 'C', 'C'), SpectrogramType()), + 'psd_s': NeuralType(('B', 'D', 'C', 'C'), SpectrogramType()), + 'psd_n': NeuralType(('B', 'D', 'C', 'C'), SpectrogramType()), + } + + @property + def output_types(self): + """Returns definitions of module output types + """ + return { + 'output': NeuralType(('B', 'C'), FloatType()), + } + + @typecheck() + def forward(self, W: 
torch.Tensor, psd_s: torch.Tensor, psd_n: torch.Tensor) -> torch.Tensor: + """ + Args: + W: Multichannel input multichannel output filter, shape (B, F, C, M), where + C is the number of input channels and M is the number of output channels + psd_s: Covariance for the signal, shape (B, F, C, C) + psd_n: Covariance for the noise, shape (B, F, C, C) + + Returns: + One-hot or soft reference channel, shape (B, M) + """ + if self.subband_weighting: + # (B, F, M) + pow_s = torch.einsum('...jm,...jk,...km->...m', W.conj(), psd_s, W).abs() + pow_n = torch.einsum('...jm,...jk,...km->...m', W.conj(), psd_n, W).abs() + + # Subband-weighting + # (B, F, M) -> (B, M) + pow_s = torch.sum(pow_s * self.weight_s.softmax(dim=0).unsqueeze(1), dim=-2) + pow_n = torch.sum(pow_n * self.weight_n.softmax(dim=0).unsqueeze(1), dim=-2) + else: + # Sum across f as well + # (B, F, C, M), (B, F, C, C), (B, F, C, M) -> (B, M) + pow_s = torch.einsum('...fjm,...fjk,...fkm->...m', W.conj(), psd_s, W).abs() + pow_n = torch.einsum('...fjm,...fjk,...fkm->...m', W.conj(), psd_n, W).abs() + + # Estimated SNR per channel (B, C) + snr = pow_s / (pow_n + self.eps) + snr = 10 * torch.log10(snr + self.eps) + + # Soft reference + ref_soft = snr.softmax(dim=-1) + + if self.hard: + _, idx = ref_soft.max(dim=-1, keepdim=True) + ref_hard = torch.zeros_like(snr).scatter(-1, idx, 1.0) + if self.hard_use_grad: + # Straight-through for gradient + # Propagate ref_soft gradient, as if thresholding is identity + ref = ref_hard - ref_soft.detach() + ref_soft + else: + # No gradient + ref = ref_hard + else: + ref = ref_soft + + return ref diff --git a/nemo/collections/asr/parts/utils/audio_utils.py b/nemo/collections/asr/parts/utils/audio_utils.py index 80dfc74950a5f..8188dbed003b5 100644 --- a/nemo/collections/asr/parts/utils/audio_utils.py +++ b/nemo/collections/asr/parts/utils/audio_utils.py @@ -412,7 +412,7 @@ def calculate_sdr_numpy( convolution_filter_length: Optional[int] = None, remove_mean: bool = True, sdr_max: 
Optional[float] = None, - eps: float = 1e-10, + eps: float = 1e-8, ) -> float: """Calculate signal-to-distortion ratio. @@ -519,7 +519,7 @@ def convmtx_mc_numpy(x: np.ndarray, filter_length: int, delay: int = 0, n_steps: return np.hstack(mc_mtx) -def scale_invariant_target_numpy(estimate: np.ndarray, target: np.ndarray, eps: float = 1e-10) -> np.ndarray: +def scale_invariant_target_numpy(estimate: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> np.ndarray: """Calculate convolution-invariant target for a given estimated signal. Calculate scaled target obtained by solving @@ -543,7 +543,7 @@ def scale_invariant_target_numpy(estimate: np.ndarray, target: np.ndarray, eps: def convolution_invariant_target_numpy( - estimate: np.ndarray, target: np.ndarray, filter_length, diag_reg: float = 1e-8, eps: float = 1e-10 + estimate: np.ndarray, target: np.ndarray, filter_length, diag_reg: float = 1e-6, eps: float = 1e-8 ) -> np.ndarray: """Calculate convolution-invariant target for a given estimated signal. diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt index 011862ad723ba..8df86fa4679ab 100644 --- a/requirements/requirements_asr.txt +++ b/requirements/requirements_asr.txt @@ -5,7 +5,7 @@ ipywidgets jiwer kaldi-python-io kaldiio -librosa>=0.9.0 +librosa>=0.10.0 marshmallow matplotlib packaging diff --git a/tests/collections/asr/test_asr_part_submodules_multichannel.py b/tests/collections/asr/test_asr_part_submodules_multichannel.py new file mode 100644 index 0000000000000..f53d140277319 --- /dev/null +++ b/tests/collections/asr/test_asr_part_submodules_multichannel.py @@ -0,0 +1,157 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from nemo.collections.asr.parts.submodules.multichannel_modules import ( + ChannelAttentionPool, + ChannelAugment, + ChannelAveragePool, + TransformAttendConcatenate, + TransformAverageConcatenate, +) + + +class TestChannelAugment: + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 2, 6]) + def test_channel_selection(self, num_channels): + """Test getting a fixed number of channels without randomization. + The first few channels will always be selected. + """ + num_examples = 100 + batch_size = 4 + num_samples = 100 + + uut = ChannelAugment(permute_channels=False, num_channels_min=1, num_channels_max=num_channels) + + for n in range(num_examples): + input = torch.rand(batch_size, num_channels, num_samples) + output = uut(input=input) + + num_channels_out = output.size(-2) + + assert torch.allclose( + output, input[:, :num_channels_out, :] + ), f'Failed for num_channels_out {num_channels_out}, example {n}' + + +class TestTAC: + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 2, 6]) + def test_average(self, num_channels): + """Test transform-average-concatenate. 
+ """ + num_examples = 10 + batch_size = 4 + in_features = 128 + out_features = 96 + num_frames = 20 + + uut = TransformAverageConcatenate(in_features=in_features, out_features=out_features) + + for n in range(num_examples): + input = torch.rand(batch_size, num_channels, in_features, num_frames) + output = uut(input=input) + + # Dimensions must match + assert output.shape == ( + batch_size, + num_channels, + out_features, + num_frames, + ), f'Example {n}: output shape {output.shape} not matching the expected ({batch_size}, {num_channels}, {out_features}, {num_frames})' + + # Second half of features must be the same for all channels (concatenated average) + if num_channels > 1: + # reference + avg_ref = output[:, 0, out_features // 2 :, :] + for m in range(1, num_channels): + assert torch.allclose( + output[:, m, out_features // 2 :, :], avg_ref + ), f'Example {n}: average not matching' + + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 2, 6]) + def test_attend(self, num_channels): + """Test transform-attend-concatenate. + Second half of features is different across channels, since we're using attention, so + we check only for shape. + """ + num_examples = 10 + batch_size = 4 + in_features = 128 + out_features = 96 + num_frames = 20 + + uut = TransformAttendConcatenate(in_features=in_features, out_features=out_features) + + for n in range(num_examples): + input = torch.rand(batch_size, num_channels, in_features, num_frames) + output = uut(input=input) + + # Dimensions must match + assert output.shape == ( + batch_size, + num_channels, + out_features, + num_frames, + ), f'Example {n}: output shape {output.shape} not matching the expected ({batch_size}, {num_channels}, {out_features}, {num_frames})' + + +class TestChannelPool: + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 2, 6]) + def test_average(self, num_channels): + """Test average channel pooling. 
+ """ + num_examples = 10 + batch_size = 4 + in_features = 128 + num_frames = 20 + + uut = ChannelAveragePool() + + for n in range(num_examples): + input = torch.rand(batch_size, num_channels, in_features, num_frames) + output = uut(input=input) + + # Output must match the average across channels + assert torch.allclose( + output, torch.mean(input, dim=1) + ), f'Example {n}: output not matching the expected average' + + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [2, 6]) + def test_attention(self, num_channels): + """Test attention for channel pooling. + """ + num_examples = 10 + batch_size = 4 + in_features = 128 + num_frames = 20 + + uut = ChannelAttentionPool(in_features=in_features) + + for n in range(num_examples): + input = torch.rand(batch_size, num_channels, in_features, num_frames) + output = uut(input=input) + + # Dimensions must match + assert output.shape == ( + batch_size, + in_features, + num_frames, + ), f'Example {n}: output shape {output.shape} not matching the expected ({batch_size}, {in_features}, {num_frames})'