Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed May 20, 2022
1 parent 11b59c7 commit a8219e1
Showing 1 changed file with 76 additions and 88 deletions.
164 changes: 76 additions & 88 deletions examples/tutorials/mvdr_tutorial.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
MVDR Beamforming with torchaudio
MVDR Beamforming with TorchAudio
================================
**Author** `Zhaoheng Ni <zni@fb.com>`__
Expand All @@ -11,20 +11,19 @@
# 1. Overview
# -----------
#
# This is a tutorial on how to apply Minimum Variance Distortionless
# Response (MVDR) beamforming to estimate the enhanced speech with
# torchaudio.
# This is a tutorial on applying Minimum Variance Distortionless
# Response (MVDR) beamforming to estimate enhanced speech with
# TorchAudio.
#
# Steps
# Steps:
#
# - Ideal Ratio Mask (IRM) is generated by dividing the clean/noise
# - Generate an ideal ratio mask (IRM) by dividing the clean/noise
# magnitude by the mixture magnitude.
# - Power spectral density (PSD) matrices are estimated by
# :py:func:`torchaudio.transforms.PSD`.
# - The enhanced speech is estimated by using the two MVDR modules in
# torchaudio (:py:func:`torchaudio.transforms.SoudenMVDR` and
# - Estimate power spectral density (PSD) matrices using :py:func:`torchaudio.transforms.PSD`.
# - Estimate enhanced speech using MVDR modules
# (:py:func:`torchaudio.transforms.SoudenMVDR` and
# :py:func:`torchaudio.transforms.RTFMVDR`).
# - We benchmark the two methods
# - Benchmark the two methods
# (:py:func:`torchaudio.functional.rtf_evd` and
# :py:func:`torchaudio.functional.rtf_power`) for computing the
# relative transfer function (RTF) matrix of the reference microphone.
Expand Down Expand Up @@ -52,7 +51,7 @@
#
# ``SSB07200001\#noise-sound-bible-0038\#7.86_6.16_3.00_3.14_4.84_134.5285_191.7899_0.4735\#15217\#25.16333303751458\#0.2101221178590021.wav``
#
# which was generated with:
# which was generated with:
#
# - ``SSB07200001.wav`` from
# `AISHELL-3 <https://www.openslr.org/93/>`__ (Apache License
Expand All @@ -78,11 +77,11 @@
#


def plot_spectrogram(stft, title="Spectrogram", xlim=None):
    """Plot the magnitude of STFT coefficients on a decibel scale.

    Args:
        stft: Tensor of STFT coefficients. Its ``.abs()`` is converted to dB
            and drawn with ``imshow``, so it is presumably a single-channel
            ``(freq, time)`` tensor -- confirm against callers.
        title: Title placed above the figure.
        xlim: Optional x-axis limits forwarded to ``axis.set_xlim``. ``None``
            (the default) keeps the limits chosen by ``imshow``.
    """
    magnitude = stft.abs()
    # The small offset keeps log10 finite for bins whose magnitude is zero.
    spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
    figure, axis = plt.subplots(1, 1)
    img = axis.imshow(spectrogram, cmap="nipy_spectral", vmin=-100, vmax=0, origin="lower", aspect="auto")
    # Previously ``xlim`` was accepted but silently ignored.
    if xlim is not None:
        axis.set_xlim(xlim)
    figure.suptitle(title)
    plt.colorbar(img, ax=axis)
    plt.show()
Expand Down Expand Up @@ -118,8 +117,8 @@ def si_snr(estimate, reference, epsilon=1e-8):


######################################################################
# 3. Generate the Ideal Ratio Mask (IRM)
# --------------------------------------
# 3. Generate Ideal Ratio Masks (IRMs)
# ------------------------------------
#


Expand All @@ -136,8 +135,8 @@ def si_snr(estimate, reference, epsilon=1e-8):


######################################################################
# Note: To imrove the robustness of computation, it is recommended to use the
# double precision (``torch.float64`` or ``torch.double`` for the waveforms.
# Note: To improve computational robustness, it is recommended to represent
# the waveforms as double-precision floating point (``torch.float64`` or ``torch.double``) values.
#

waveform_mix = waveform_mix.to(torch.double)
Expand All @@ -146,8 +145,8 @@ def si_snr(estimate, reference, epsilon=1e-8):


######################################################################
# 3.2. Compute complex-valued Spectrums
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 3.2. Compute STFT coefficients
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

N_FFT = 1024
Expand All @@ -159,17 +158,17 @@ def si_snr(estimate, reference, epsilon=1e-8):
)
istft = torchaudio.transforms.InverseSpectrogram(n_fft=N_FFT, hop_length=N_HOP)

spectrum_mix = stft(waveform_mix)
spectrum_clean = stft(waveform_clean)
spectrum_noise = stft(waveform_noise)
stft_mix = stft(waveform_mix)
stft_clean = stft(waveform_clean)
stft_noise = stft(waveform_noise)


######################################################################
# 3.2.1. Visualize mixture speech
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#

plot_specgram(spectrum_mix[0], "Spectrogram of Mixture Speech (dB)")
plot_spectrogram(stft_mix[0], "Spectrogram of Mixture Speech (dB)")
Audio(waveform_mix[0], rate=SAMPLE_RATE)


Expand All @@ -178,7 +177,7 @@ def si_snr(estimate, reference, epsilon=1e-8):
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#

plot_specgram(spectrum_clean[0], "Spectrogram of Clean Speech (dB)")
plot_spectrogram(stft_clean[0], "Spectrogram of Clean Speech (dB)")
Audio(waveform_clean[0], rate=SAMPLE_RATE)


Expand All @@ -187,38 +186,38 @@ def si_snr(estimate, reference, epsilon=1e-8):
# ^^^^^^^^^^^^^^^^^^^^^^
#

plot_specgram(spectrum_noise[0], "Spectrogram of Noise (dB)")
plot_spectrogram(stft_noise[0], "Spectrogram of Noise (dB)")
Audio(waveform_noise[0], rate=SAMPLE_RATE)


######################################################################
# 3.3. Define the reference microphone
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We choose the first microphone in the array as the reference channel for demonstration.
# The selection of the reference channel may depend on the design of the microphone array.
#
# You can also apply a neural network to estimate the reference channel and
# pass it to the MVDR module in an end-to-end speech enhancement model.
# You can also apply an end-to-end neural network that estimates both the reference channel and
# the PSD matrices, and then obtain the enhanced STFT coefficients with the MVDR module.

REFERENCE_CHANNEL = 0


######################################################################
# 3.4. Compute IRMs for target speech and noise
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 3.4. Compute IRMs
# ~~~~~~~~~~~~~~~~~
#


def get_irms(stft_clean, stft_noise):
    """Compute ideal ratio masks (IRMs) for the target speech and the noise.

    Each mask is the squared magnitude of one source divided by the sum of
    the squared magnitudes of both sources, evaluated per time-frequency bin.

    Args:
        stft_clean: STFT coefficients of the clean speech. Indexed by
            ``REFERENCE_CHANNEL`` along the first dimension, so presumably
            shaped ``(channel, freq, time)`` -- confirm against callers.
        stft_noise: STFT coefficients of the noise, same layout as
            ``stft_clean``.

    Returns:
        Tuple ``(irm_speech, irm_noise)`` taken at ``REFERENCE_CHANNEL``.
    """
    power_clean = stft_clean.abs() ** 2
    power_noise = stft_noise.abs() ** 2
    # Bins where both sources are exactly zero (e.g. digital silence) would
    # otherwise evaluate 0 / 0 = NaN and poison the PSD estimates downstream.
    # Flooring the denominator at the smallest normal float leaves every other
    # bin unchanged and yields a mask of 0 for the silent ones.
    total_power = (power_clean + power_noise).clamp(min=torch.finfo(power_clean.dtype).tiny)
    irm_speech = power_clean / total_power
    irm_noise = power_noise / total_power
    return irm_speech[REFERENCE_CHANNEL], irm_noise[REFERENCE_CHANNEL]


irm_speech, irm_noise = get_irms(spectrum_clean, spectrum_noise)
irm_speech, irm_noise = get_irms(stft_clean, stft_noise)


######################################################################
Expand All @@ -236,131 +235,120 @@ def get_irms(spec_clean, spec_noise):

plot_mask(irm_noise, "IRM of the Noise")


######################################################################
# 4. Beamforming using SoudenMVDR
# -------------------------------
#


######################################################################
# 4.1. Compute PSD matrices
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 4. Compute PSD matrices
# -----------------------
#
# :py:func:`torchaudio.transforms.PSD` computes the time-invariant PSD matrix given
# the multi-channel complex-valued spectrum and the Time-Frequency mask.
# the multi-channel complex-valued STFT coefficients of the mixture speech
# and the time-frequency mask.
#
# The shape of the PSD matrix is `(..., freq, channel, channel)`.

psd_transform = torchaudio.transforms.PSD()

psd_speech = psd_transform(spectrum_mix, irm_speech)
psd_noise = psd_transform(spectrum_mix, irm_noise)
psd_speech = psd_transform(stft_mix, irm_speech)
psd_noise = psd_transform(stft_mix, irm_noise)


######################################################################
# 5. Beamforming using SoudenMVDR
# -------------------------------
#


######################################################################
# 4.2. Apply Beamforming
# 5.1. Apply beamforming
# ~~~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`torchaudio.transforms.SoudenMVDR` takes the multi-channel
# complexed-valued spectrum of the mixture speech, PSD matrices of
# target speech and noise, and the reference channel as the inputs.
# complex-valued STFT coefficients of the mixture speech, PSD matrices of
# target speech and noise, and the reference channel as inputs.
#
# The output is a single-channel complex-valued spectrum of the enhanced speech.
# Then we can obtain the enhanced wavefrom by passing it to the
# The output is the single-channel complex-valued STFT coefficients of the enhanced speech.
# We can then obtain the enhanced waveform by passing this output to the
# :py:func:`torchaudio.transforms.InverseSpectrogram` module.

mvdr_transform = torchaudio.transforms.SoudenMVDR()
spectrum_souden = mvdr_transform(spectrum_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_souden = istft(spectrum_souden, length=waveform_mix.shape[-1])
stft_souden = mvdr_transform(stft_mix, psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_souden = istft(stft_souden, length=waveform_mix.shape[-1])


######################################################################
# 4.3. Result for SoudenMVDR
# 5.2. Result for SoudenMVDR
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#

plot_specgram(spectrum_souden, "Enhanced Spectrogram by SoudenMVDR (dB)")
plot_spectrogram(stft_souden, "Enhanced Spectrogram by SoudenMVDR (dB)")
waveform_souden = waveform_souden.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_souden, waveform_clean[0:1])}")
Audio(waveform_souden, rate=SAMPLE_RATE)


######################################################################
# 5. Beamforming using RTFMVDR
# 6. Beamforming using RTFMVDR
# ----------------------------
#


######################################################################
# 5.1. Compute PSD matrices
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

psd_transform = torchaudio.transforms.PSD()

psd_speech = psd_transform(spectrum_mix, irm_speech)
psd_noise = psd_transform(spectrum_mix, irm_noise)


######################################################################
# 5.2. Compute RTF
# 6.1. Compute RTF
# ~~~~~~~~~~~~~~~~
#
# There are two methods in torchaudio to compute the RTF matrix of the
# TorchAudio offers two methods for computing the RTF matrix of the
# target speech:
#
# :py:func:`torchaudio.functional.rtf_evd`, which applies eigenvalue
# - :py:func:`torchaudio.functional.rtf_evd`, which applies eigenvalue
# decomposition to the PSD matrix of target speech to get the RTF matrix.
#
# :py:func:`torchaudio.functional.rtf_power`, which applies the power iteration
# method. You can tune the number of iterations by changing ``n_iter`` argument.
# - :py:func:`torchaudio.functional.rtf_power`, which applies the power iteration
# method. You can specify the number of iterations with argument ``n_iter``.
#

rtf_evd = F.rtf_evd(psd_speech)
rtf_power = F.rtf_power(psd_speech, psd_noise, reference_channel=REFERENCE_CHANNEL)


######################################################################
# 5.3. Apply Beamforming
# 6.2. Apply beamforming
# ~~~~~~~~~~~~~~~~~~~~~~
#
# :py:func:`torchaudio.transforms.RTFMVDR` takes the multi-channel
# complexed-valued spectrum of the mixture speech, RTF matrix of target speech,
# PSD matrix of noise, and the reference channel as the inputs.
# complex-valued STFT coefficients of the mixture speech, RTF matrix of target speech,
# PSD matrix of noise, and the reference channel as inputs.
#
# The output is a single-channel complex-valued spectrum of the enhanced speech.
# Then we can obtain the enhanced wavefrom by passing it to the
# The output is the single-channel complex-valued STFT coefficients of the enhanced speech.
# We can then obtain the enhanced waveform by passing this output to the
# :py:func:`torchaudio.transforms.InverseSpectrogram` module.

mvdr_transform = torchaudio.transforms.RTFMVDR()

# compute the enhanced speech based on F.rtf_evd
spectrum_rtf_evd = mvdr_transform(spectrum_mix, rtf_evd, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_evd = istft(spectrum_rtf_evd, length=waveform_mix.shape[-1])
stft_rtf_evd = mvdr_transform(stft_mix, rtf_evd, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_evd = istft(stft_rtf_evd, length=waveform_mix.shape[-1])

# compute the enhanced speech based on F.rtf_power
spectrum_rtf_power = mvdr_transform(spectrum_mix, rtf_power, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_power = istft(spectrum_rtf_power, length=waveform_mix.shape[-1])
stft_rtf_power = mvdr_transform(stft_mix, rtf_power, psd_noise, reference_channel=REFERENCE_CHANNEL)
waveform_rtf_power = istft(stft_rtf_power, length=waveform_mix.shape[-1])


######################################################################
# 5.4. Result for RTFMVDR with `rtf_evd`
# 6.3. Result for RTFMVDR with `rtf_evd`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

plot_specgram(spectrum_rtf_evd, "Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)")
plot_spectrogram(stft_rtf_evd, "Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)")
waveform_rtf_evd = waveform_rtf_evd.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_rtf_evd, waveform_clean[0:1])}")
Audio(waveform_rtf_evd, rate=SAMPLE_RATE)


######################################################################
# 5.5. Result for RTFMVDR with `rtf_power`
# 6.4. Result for RTFMVDR with `rtf_power`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#

plot_specgram(spectrum_rtf_power, "Enhanced Spectrogram by RTFMVDR and F.rtf_evd (dB)")
# The title names F.rtf_power: this spectrogram comes from the power-iteration
# RTF estimate, not the eigenvalue-decomposition one shown in the previous section.
plot_spectrogram(stft_rtf_power, "Enhanced Spectrogram by RTFMVDR and F.rtf_power (dB)")
# Reshape to one row so the estimate has a leading channel dimension like ``waveform_clean[0:1]``.
waveform_rtf_power = waveform_rtf_power.reshape(1, -1)
print(f"Si-SNR score: {si_snr(waveform_rtf_power, waveform_clean[0:1])}")
Audio(waveform_rtf_power, rate=SAMPLE_RATE)

0 comments on commit a8219e1

Please sign in to comment.