Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed May 23, 2022
1 parent a8219e1 commit 12fc2ed
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions examples/tutorials/mvdr_tutorial.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
MVDR Beamforming with TorchAudio
Speech Enhancement with MVDR Beamforming
================================
**Author** `Zhaoheng Ni <zni@fb.com>`__
Expand Down Expand Up @@ -67,8 +67,8 @@
from torchaudio.utils import download_asset

SAMPLE_RATE = 16000
SAMPLE_MIX = download_asset("tutorial-assets/mvdr/mix.wav")
SAMPLE_CLEAN = download_asset("tutorial-assets/mvdr/reverb_clean.wav")
SAMPLE_CLEAN = download_asset("tutorial-assets/mvdr/clean.wav")
SAMPLE_NOISE = download_asset("tutorial-assets/mvdr/noise.wav")


######################################################################
Expand All @@ -81,7 +81,7 @@ def plot_spectrogram(stft, title="Spectrogram", xlim=None):
magnitude = stft.abs()
spectrogram = 20 * torch.log10(magnitude + 1e-8).numpy()
figure, axis = plt.subplots(1, 1)
img = axis.imshow(spectrogram, cmap="nipy_spectral", vmin=-100, vmax=0, origin="lower", aspect="auto")
img = axis.imshow(spectrogram, cmap="viridis", vmin=-100, vmax=0, origin="lower", aspect="auto")
figure.suptitle(title)
plt.colorbar(img, ax=axis)
plt.show()
Expand All @@ -90,7 +90,7 @@ def plot_spectrogram(stft, title="Spectrogram", xlim=None):
def plot_mask(mask, title="Mask", xlim=None):
mask = mask.numpy()
figure, axis = plt.subplots(1, 1)
img = axis.imshow(mask, cmap="jet", origin="lower", aspect="auto")
img = axis.imshow(mask, cmap="viridis", origin="lower", aspect="auto")
figure.suptitle(title)
plt.colorbar(img, ax=axis)
plt.show()
Expand Down Expand Up @@ -127,11 +127,11 @@ def si_snr(estimate, reference, epsilon=1e-8):
# ~~~~~~~~~~~~~~~~~~~~
#

waveform_mix, sr = torchaudio.load(SAMPLE_MIX)
waveform_clean, sr2 = torchaudio.load(SAMPLE_CLEAN)
waveform_clean, sr = torchaudio.load(SAMPLE_CLEAN)
waveform_noise, sr2 = torchaudio.load(SAMPLE_NOISE)
assert sr == sr2 == SAMPLE_RATE
# The mixture waveform is a combination of clean and noise waveforms
waveform_noise = waveform_mix - waveform_clean
waveform_mix = waveform_clean + waveform_noise


######################################################################
Expand Down

0 comments on commit 12fc2ed

Please sign in to comment.