-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding a mask estimator which can process an arbitrary number of chan…
…nels Signed-off-by: Ante Jukić <ajukic@nvidia.com>
- Loading branch information
Showing
12 changed files
with
1,698 additions
and
109 deletions.
There are no files selected for viewing
146 changes: 146 additions & 0 deletions
146
examples/audio_tasks/conf/beamforming_flex_channels.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# This configuration contains the exemplary values for training a multichannel speech enhancement model with a mask-based beamformer. | ||
# | ||
name: beamforming_flex_channels | ||
|
||
model: | ||
sample_rate: 16000 | ||
skip_nan_grad: false | ||
num_outputs: 1 | ||
|
||
train_ds: | ||
manifest_filepath: ??? | ||
input_key: audio_filepath # key of the input signal path in the manifest | ||
input_channel_selector: null # load all channels from the input file | ||
target_key: target_anechoic_filepath # key of the target signal path in the manifest | ||
target_channel_selector: 0 # load only the first channel from the target file | ||
audio_duration: 4.0 # in seconds, audio segment duration for training | ||
random_offset: true # if the file is longer than audio_duration, use random offset to select a subsegment | ||
min_duration: ${model.train_ds.audio_duration} | ||
batch_size: 16 # batch size may be increased based on the available memory | ||
shuffle: true | ||
num_workers: 16 | ||
pin_memory: true | ||
|
||
validation_ds: | ||
manifest_filepath: ??? | ||
input_key: audio_filepath # key of the input signal path in the manifest | ||
input_channel_selector: null # load all channels from the input file | ||
target_key: target_anechoic_filepath # key of the target signal path in the manifest | ||
target_channel_selector: 0 # load only the first channel from the target file | ||
batch_size: 8 | ||
shuffle: false | ||
num_workers: 8 | ||
pin_memory: true | ||
|
||
channel_augment: | ||
_target_: nemo.collections.asr.parts.submodules.multichannel_modules.ChannelAugment | ||
num_channels_min: 2 # minimal number of channels selected for each batch | ||
num_channels_max: null # max number of channels is determined by the batch size | ||
permute_channels: true | ||
|
||
encoder: | ||
_target_: nemo.collections.asr.modules.audio_preprocessing.AudioToSpectrogram | ||
fft_length: 512 # Length of the window and FFT for calculating spectrogram | ||
hop_length: 256 # Hop length for calculating spectrogram | ||
|
||
decoder: | ||
_target_: nemo.collections.asr.modules.audio_preprocessing.SpectrogramToAudio | ||
fft_length: ${model.encoder.fft_length} | ||
hop_length: ${model.encoder.hop_length} | ||
|
||
mask_estimator: | ||
_target_: nemo.collections.asr.modules.audio_modules.MaskEstimatorFlexChannels | ||
num_outputs: ${model.num_outputs} # number of output masks | ||
num_subbands: 257 # number of subbands for the input spectrogram | ||
num_blocks: 5 # number of blocks in the model | ||
channel_reduction_position: 3 # 0-indexed, apply channel reduction before this block | ||
channel_reduction_type: average # channel-wise reduction | ||
channel_block_type: transform_average_concatenate # channel block | ||
temporal_block_type: conformer_encoder # temporal block | ||
temporal_block_num_layers: 5 # number of layers for the temporal block | ||
temporal_block_num_heads: 4 # number of heads for the temporal block | ||
temporal_block_dimension: 128 # the hidden size of the temporal block | ||
mag_reduction: null # channel-wise reduction of magnitude | ||
mag_normalization: mean_var # normalization using mean and variance | ||
use_ipd: true # use inter-channel phase difference | ||
ipd_normalization: mean # mean normalization | ||
|
||
mask_processor: | ||
# Mask-based multi-channel processor | ||
_target_: nemo.collections.asr.modules.audio_modules.MaskBasedBeamformer | ||
filter_type: pmwf # parametric multichannel wiener filter | ||
filter_beta: 0.0 # mvdr | ||
filter_rank: one | ||
ref_channel: max_snr # select reference channel by maximizing estimated SNR | ||
ref_hard: 1 # a one-hot reference. If false, a soft estimate across channels is used. | ||
ref_hard_use_grad: false # use straight-through gradient when using hard reference | ||
ref_subband_weighting: false # use subband weighting for reference estimation | ||
num_subbands: ${model.mask_estimator.num_subbands} | ||
|
||
loss: | ||
_target_: nemo.collections.asr.losses.SDRLoss | ||
convolution_invariant: true # convolution-invariant loss | ||
sdr_max: 30 # soft threshold for SDR | ||
|
||
metrics: | ||
val: | ||
sdr_0: | ||
_target_: torchmetrics.audio.SignalDistortionRatio | ||
channel: 0 # evaluate only on channel 0, if there are multiple outputs | ||
|
||
optim: | ||
name: adamw | ||
lr: 1e-4 | ||
# optimizer arguments | ||
betas: [0.9, 0.98] | ||
weight_decay: 1e-3 | ||
|
||
# scheduler setup | ||
sched: | ||
name: CosineAnnealing | ||
# scheduler config override | ||
warmup_steps: 10000 | ||
warmup_ratio: null | ||
min_lr: 1e-6 | ||
|
||
trainer: | ||
devices: -1 # number of GPUs, -1 would use all available GPUs | ||
num_nodes: 1 | ||
max_epochs: -1 | ||
max_steps: -1 # computed at runtime if not set | ||
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations | ||
accelerator: auto | ||
strategy: ddp | ||
accumulate_grad_batches: 1 | ||
gradient_clip_val: null | ||
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. | ||
log_every_n_steps: 25 # Interval of logging. | ||
enable_progress_bar: true | ||
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it | ||
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs | ||
sync_batchnorm: true | ||
enable_checkpointing: False # Provided by exp_manager | ||
logger: false # Provided by exp_manager | ||
|
||
exp_manager: | ||
exp_dir: null | ||
name: ${name} | ||
create_tensorboard_logger: true | ||
create_checkpoint_callback: true | ||
checkpoint_callback_params: | ||
# in case of multiple validation sets, first one is used | ||
monitor: "val_loss" | ||
mode: "min" | ||
save_top_k: 5 | ||
always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints | ||
|
||
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.pyth | ||
# you need to set these two to true to continue the training | ||
resume_if_exists: false | ||
resume_ignore_no_checkpoint: false | ||
|
||
# You may use this section to create a W&B logger | ||
create_wandb_logger: false | ||
wandb_logger_kwargs: | ||
name: null | ||
project: null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.