Skip to content

Commit

Permalink
Adding a mask estimator which can process an arbitrary number of chan…
Browse files Browse the repository at this point in the history
…nels

Signed-off-by: Ante Jukić <ajukic@nvidia.com>
  • Loading branch information
anteju committed Sep 20, 2023
1 parent c90d4dd commit 1cb25f8
Show file tree
Hide file tree
Showing 12 changed files with 1,698 additions and 109 deletions.
146 changes: 146 additions & 0 deletions examples/audio_tasks/conf/beamforming_flex_channels.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# This configuration contains the exemplary values for training a multichannel speech enhancement model with a mask-based beamformer.
#
name: beamforming_flex_channels

model:
sample_rate: 16000
skip_nan_grad: false
num_outputs: 1

train_ds:
manifest_filepath: ???
input_key: audio_filepath # key of the input signal path in the manifest
input_channel_selector: null # load all channels from the input file
target_key: target_anechoic_filepath # key of the target signal path in the manifest
target_channel_selector: 0 # load only the first channel from the target file
audio_duration: 4.0 # in seconds, audio segment duration for training
random_offset: true # if the file is longer than audio_duration, use random offset to select a subsegment
min_duration: ${model.train_ds.audio_duration}
batch_size: 16 # batch size may be increased based on the available memory
shuffle: true
num_workers: 16
pin_memory: true

validation_ds:
manifest_filepath: ???
input_key: audio_filepath # key of the input signal path in the manifest
input_channel_selector: null # load all channels from the input file
target_key: target_anechoic_filepath # key of the target signal path in the manifest
target_channel_selector: 0 # load only the first channel from the target file
batch_size: 8
shuffle: false
num_workers: 8
pin_memory: true

channel_augment:
_target_: nemo.collections.asr.parts.submodules.multichannel_modules.ChannelAugment
num_channels_min: 2 # minimal number of channels selected for each batch
num_channels_max: null # max number of channels is determined by the batch size
permute_channels: true

encoder:
_target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram
fft_length: 512 # Length of the window and FFT for calculating spectrogram
hop_length: 256 # Hop length for calculating spectrogram

decoder:
_target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio
fft_length: ${model.encoder.fft_length}
hop_length: ${model.encoder.hop_length}

mask_estimator:
_target_: nemo.collections.asr.modules.audio_modules.MaskEstimatorFlexChannels
num_outputs: ${model.num_outputs} # number of output masks
num_subbands: 257 # number of subbands for the input spectrogram
num_blocks: 5 # number of blocks in the model
channel_reduction_position: 3 # 0-indexed, apply channel reduction before this block
channel_reduction_type: average # channel-wise reduction
channel_block_type: transform_average_concatenate # channel block
temporal_block_type: conformer_encoder # temporal block
temporal_block_num_layers: 5 # number of layers for the temporal block
temporal_block_num_heads: 4 # number of heads for the temporal block
temporal_block_dimension: 128 # the hidden size of the temporal block
mag_reduction: null # channel-wise reduction of magnitude
mag_normalization: mean_var # normalization using mean and variance
use_ipd: true # use inter-channel phase difference
ipd_normalization: mean # mean normalization

mask_processor:
# Mask-based multi-channel processor
_target_: nemo.collections.asr.modules.audio_modules.MaskBasedBeamformer
filter_type: pmwf # parametric multichannel wiener filter
filter_beta: 0.0 # mvdr
filter_rank: one
ref_channel: max_snr # select reference channel by maximizing estimated SNR
ref_hard: 1 # a one-hot reference. If false, a soft estimate across channels is used.
ref_hard_use_grad: false # use straight-through gradient when using hard reference
ref_subband_weighting: false # use subband weighting for reference estimation
num_subbands: ${model.mask_estimator.num_subbands}

loss:
_target_: nemo.collections.asr.losses.SDRLoss
convolution_invariant: true # convolution-invariant loss
sdr_max: 30 # soft threshold for SDR

metrics:
val:
sdr_0:
_target_: torchmetrics.audio.SignalDistortionRatio
channel: 0 # evaluate only on channel 0, if there are multiple outputs

optim:
name: adamw
lr: 1e-4
# optimizer arguments
betas: [0.9, 0.98]
weight_decay: 1e-3

# scheduler setup
sched:
name: CosineAnnealing
# scheduler config override
warmup_steps: 10000
warmup_ratio: null
min_lr: 1e-6

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: -1
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: null
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 25 # Interval of logging.
enable_progress_bar: true
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager

exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: "val_loss"
mode: "min"
save_top_k: 5
always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints

resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.pyth
# you need to set these two to true to continue the training
resume_if_exists: false
resume_ignore_no_checkpoint: false

# You may use this section to create a W&B logger
create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
9 changes: 9 additions & 0 deletions examples/audio_tasks/process_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
pretrained_name: name of a pretrained AudioToAudioModel model (from NGC registry)
audio_dir: path to directory with audio files
dataset_manifest: path to dataset JSON manifest file (in NeMo format)
max_utts: maximum number of utterances to process
input_channel_selector: list of channels to take from audio files, defaults to `None` and takes all available channels
input_key: key for audio filepath in the manifest file, defaults to `audio_filepath`
Expand Down Expand Up @@ -80,6 +81,7 @@ class ProcessConfig:
pretrained_name: Optional[str] = None # Name of a pretrained model
audio_dir: Optional[str] = None # Path to a directory which contains audio files
dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest
max_utts: Optional[int] = None # max number of utterances to process

# Audio configs
input_channel_selector: Optional[List] = None # Union types not supported Optional[Union[List, int]]
Expand Down Expand Up @@ -171,6 +173,10 @@ def main(cfg: ProcessConfig) -> ProcessConfig:
audio_file = manifest_dir / audio_file
filepaths.append(str(audio_file.absolute()))

if cfg.max_utts is not None:
# Limit the number of utterances to process
filepaths = filepaths[: cfg.max_utts]

logging.info(f"\nProcessing {len(filepaths)} files...\n")

# setup AMP (optional)
Expand Down Expand Up @@ -225,6 +231,9 @@ def autocast():
item = json.loads(line)
item['processed_audio_filepath'] = paths2processed_files[idx]
f.write(json.dumps(item) + "\n")

if cfg.max_utts is not None and idx >= cfg.max_utts - 1:
break
else:
for idx, processed_file in enumerate(paths2processed_files):
item = {'processed_audio_filepath': processed_file}
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/data/audio_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ def get_duration(audio_files: List[str]) -> List[float]:
Returns:
List of durations in seconds.
"""
duration = [librosa.get_duration(filename=f) for f in flatten(audio_files)]
duration = [librosa.get_duration(path=f) for f in flatten(audio_files)]
return duration

def load_embedding(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]:
Expand Down
8 changes: 4 additions & 4 deletions nemo/collections/asr/losses/audio_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ def convolution_invariant_target(
input_length: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
filter_length: int = 512,
diag_reg: float = 1e-8,
eps: float = 1e-10,
diag_reg: float = 1e-6,
eps: float = 1e-8,
) -> torch.Tensor:
"""Calculate optimal convolution-invariant target for a given estimate.
Assumes time dimension is the last dimension in the array.
Expand Down Expand Up @@ -222,7 +222,7 @@ def calculate_sdr_batch(
convolution_filter_length: Optional[int] = 512,
remove_mean: bool = True,
sdr_max: Optional[float] = None,
eps: float = 1e-10,
eps: float = 1e-8,
) -> torch.Tensor:
"""Calculate signal-to-distortion ratio per channel.
Expand Down Expand Up @@ -310,7 +310,7 @@ def __init__(
convolution_filter_length: Optional[int] = 512,
remove_mean: bool = True,
sdr_max: Optional[float] = None,
eps: float = 1e-10,
eps: float = 1e-8,
):
super().__init__()

Expand Down
44 changes: 27 additions & 17 deletions nemo/collections/asr/models/audio_to_audio_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import hydra
import torch
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer

from nemo.collections.asr.metrics.audio import AudioMetricWrapper
Expand Down Expand Up @@ -86,16 +86,19 @@ def _setup_metrics(self, tag: str = 'val'):
# Setup metrics for each dataloader
self.metrics[tag] = torch.nn.ModuleList()
for dataloader_idx in range(num_dataloaders):
metrics_dataloader_idx = torch.nn.ModuleDict(
{
name: AudioMetricWrapper(
metric=hydra.utils.instantiate(cfg),
channel=cfg.get('channel'),
metric_using_batch_averaging=cfg.get('metric_using_batch_averaging'),
)
for name, cfg in metrics_cfg.items()
}
)
metrics_dataloader_idx = {}
for name, cfg in metrics_cfg.items():
logging.debug('Initialize %s for dataloader_idx %s', name, dataloader_idx)
cfg_dict = OmegaConf.to_container(cfg)
cfg_channel = cfg_dict.pop('channel', None)
cfg_batch_averaging = cfg_dict.pop('metric_using_batch_averaging', None)
metrics_dataloader_idx[name] = AudioMetricWrapper(
metric=hydra.utils.instantiate(cfg_dict),
channel=cfg_channel,
metric_using_batch_averaging=cfg_batch_averaging,
)

metrics_dataloader_idx = torch.nn.ModuleDict(metrics_dataloader_idx)
self.metrics[tag].append(metrics_dataloader_idx.to(self.device))

logging.info(
Expand All @@ -115,15 +118,24 @@ def on_test_start(self):
return super().on_test_start()

def validation_step(self, batch, batch_idx, dataloader_idx: int = 0):
return self.evaluation_step(batch, batch_idx, dataloader_idx, 'val')
output_dict = self.evaluation_step(batch, batch_idx, dataloader_idx, 'val')
if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(output_dict)
else:
self.validation_step_outputs.append(output_dict)
return output_dict

def test_step(self, batch, batch_idx, dataloader_idx=0):
return self.evaluation_step(batch, batch_idx, dataloader_idx, 'test')
output_dict = self.evaluation_step(batch, batch_idx, dataloader_idx, 'test')
if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(output_dict)
else:
self.test_step_outputs.append(output_dict)
return output_dict

def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'):
# Handle loss
loss_mean = torch.stack([x[f'{tag}_loss'] for x in outputs]).mean()
output_dict = {f'{tag}_loss': loss_mean}
tensorboard_logs = {f'{tag}_loss': loss_mean}

# Handle metrics for this tag and dataloader_idx
Expand All @@ -135,9 +147,7 @@ def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str
# Store for logs
tensorboard_logs[f'{tag}_{name}'] = value

output_dict['log'] = tensorboard_logs

return output_dict
return {f'{tag}_loss': loss_mean, 'log': tensorboard_logs}

def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
return self.multi_evaluation_epoch_end(outputs, dataloader_idx, 'val')
Expand Down
35 changes: 26 additions & 9 deletions nemo/collections/asr/models/enhancement_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,24 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.decoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.decoder)

if 'mixture_consistency' in self._cfg:
logging.debug('Using mixture consistency')
self.mixture_consistency = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mixture_consistency)
else:
logging.debug('Mixture consistency not used')
self.mixture_consistency = None

# Future enhancement:
# If subclasses need to modify the config before calling super()
# Check ASRBPE* classes do with their mixin

# Setup augmentation
if hasattr(self._cfg, 'channel_augment') and self._cfg.channel_augment is not None:
logging.debug('Using channel augmentation')
self.channel_augmentation = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.channel_augment)
else:
logging.debug('Channel augmentation not used')
self.channel_augmentation = None

# Setup optional Optimization flags
self.setup_optimization_flags()

Expand Down Expand Up @@ -125,7 +135,7 @@ def process(
temporary_manifest_filepath = os.path.join(tmpdir, 'manifest.json')
with open(temporary_manifest_filepath, 'w', encoding='utf-8') as fp:
for audio_file in paths2audio_files:
entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(filename=audio_file)}
entry = {'input_filepath': audio_file, 'duration': librosa.get_duration(path=audio_file)}
fp.write(json.dumps(entry) + '\n')

config = {
Expand Down Expand Up @@ -397,17 +407,25 @@ def training_step(self, batch, batch_idx):
if target_signal.ndim == 2:
target_signal = target_signal.unsqueeze(1)

# Apply channel augmentation
if self.training and self.channel_augmentation is not None:
input_signal = self.channel_augmentation(input=input_signal)

# Process input
processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)

loss_value = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)
# Calculate the loss
loss = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)

# Prepare logs
tensorboard_logs = {
'train_loss': loss_value,
'train_loss': loss,
'learning_rate': self._optimizer.param_groups[0]['lr'],
'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32),
}

return {'loss': loss_value, 'log': tensorboard_logs}
self.log_dict(tensorboard_logs, on_step=True, sync_dist=True)
return loss

def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
input_signal, input_length, target_signal, target_length = batch
Expand All @@ -419,11 +437,11 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
if target_signal.ndim == 2:
target_signal = target_signal.unsqueeze(1)

# Process input
processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)

# Prepare output
loss_value = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)
output_dict = {f'{tag}_loss': loss_value}
# Calculate the loss
loss = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)

# Update metrics
if hasattr(self, 'metrics') and tag in self.metrics:
Expand All @@ -433,8 +451,7 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =

# Log global step
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32), sync_dist=True)

return output_dict
return {f'{tag}_loss': loss}

@classmethod
def list_available_models(cls) -> Optional[PretrainedModelInfo]:
Expand Down
Loading

0 comments on commit 1cb25f8

Please sign in to comment.