From 81c0870abf58533d84a822401a3c2c750b2e8a56 Mon Sep 17 00:00:00 2001
From: Zhaoheng Ni
Date: Wed, 9 Feb 2022 16:47:21 +0000
Subject: [PATCH] add unittests

---
 .../functional/autograd_impl.py               | 59 ++++++++++++++++++-
 .../functional/batch_consistency_test.py      | 56 ++++++++++++++++++
 .../torchscript_consistency_impl.py           | 53 +++++++++++++++++
 3 files changed, 167 insertions(+), 1 deletion(-)

diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py
index 6d942d1e92f..695c9a5d985 100644
--- a/test/torchaudio_unittest/functional/autograd_impl.py
+++ b/test/torchaudio_unittest/functional/autograd_impl.py
@@ -1,5 +1,6 @@
 from functools import partial
 from typing import Callable, Tuple
+from unittest import expectedFailure
 
 import torch
 import torchaudio.functional as F
@@ -9,6 +10,7 @@
 from torchaudio_unittest.common_utils import (
     TestBaseMixin,
     get_whitenoise,
+    get_spectrogram,
     rnnt_utils,
 )
 
@@ -24,7 +26,7 @@ def assert_grad(
         inputs_ = []
         for i in inputs:
             if torch.is_tensor(i):
-                i = i.to(dtype=self.dtype, device=self.device)
+                i = i.to(dtype=torch.cdouble if i.is_complex() else self.dtype, device=self.device)
                 if enable_all_grad:
                     i.requires_grad = True
             inputs_.append(i)
@@ -250,6 +252,61 @@ def test_bandreject_biquad(self, central_freq, Q):
         Q = torch.tensor(Q)
         self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))
 
+    @parameterized.expand(
+        [
+            (True,),
+            (False,),
+        ]
+    )
+    def test_compute_power_spectral_density_matrix(self, use_mask):
+        torch.random.manual_seed(2434)
+        sr = 8000
+        x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=6)
+        specgram = get_spectrogram(x, n_fft=400, hop_length=100)
+        if use_mask:
+            mask = torch.rand(specgram.size())
+        else:
+            mask = None
+        self.assert_grad(F.compute_power_spectral_density_matrix, (specgram, mask))
+
+    def test_compute_mvdr_weights_souden(self):
+        torch.random.manual_seed(2434)
+        channel = 4
+        psd_speech = torch.rand(129, channel, channel, dtype=torch.cfloat)
+        psd_noise = torch.rand(129, channel, channel, dtype=torch.cfloat)
+        self.assert_grad(F.compute_mvdr_weights_souden, (psd_speech, psd_noise, 0))
+
+    def test_compute_mvdr_weights_rtf(self):
+        torch.random.manual_seed(2434)
+        channel = 4
+        rtf = torch.rand(129, channel, 1, dtype=torch.cfloat)
+        psd_noise = torch.rand(129, channel, channel, dtype=torch.cfloat)
+        self.assert_grad(F.compute_mvdr_weights_rtf, (rtf, psd_noise, 0))
+
+    # The eigenvector can be different in different runs, expected to fail
+    @expectedFailure
+    def test_compute_compute_rtf_evd(self):
+        torch.random.manual_seed(2434)
+        channel = 4
+        specgram = torch.rand(channel, 201, 100, dtype=torch.cfloat)
+        psd_speech = F.compute_power_spectral_density_matrix(specgram)
+        self.assert_grad(F.compute_rtf_evd, (psd_speech,))
+
+    def test_compute_compute_rtf_power(self):
+        torch.random.manual_seed(2434)
+        channel = 4
+        psd_speech = torch.rand(129, channel, channel, dtype=torch.cfloat)
+        psd_noise = torch.rand(129, channel, channel, dtype=torch.cfloat)
+        self.assert_grad(F.compute_rtf_power, (psd_speech, psd_noise, 0))
+
+    def test_apply_beamforming(self):
+        torch.random.manual_seed(2434)
+        sr = 8000
+        x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=6)
+        specgram = get_spectrogram(x, n_fft=400, hop_length=100)
+        beamform_weights = torch.rand(201, 6, dtype=torch.cfloat)
+        self.assert_grad(F.apply_beamforming, (beamform_weights, specgram))
+
 
 class AutogradFloat32(TestBaseMixin):
     def assert_grad(
diff --git a/test/torchaudio_unittest/functional/batch_consistency_test.py b/test/torchaudio_unittest/functional/batch_consistency_test.py
index 5beac512a6c..f459a26e30a 100644
--- a/test/torchaudio_unittest/functional/batch_consistency_test.py
+++ b/test/torchaudio_unittest/functional/batch_consistency_test.py
@@ -241,3 +241,59 @@ def test_filtfilt(self):
         itemwise_output = torch.stack([F.filtfilt(x[i], a[i], b[i]) for i in range(self.batch_size)])
 
         self.assertEqual(batchwise_output, itemwise_output)
+
+    def test_compute_power_spectral_density_matrix(self):
+        sample_rate = 44100
+        waveform = common_utils.get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=6)
+        specgram = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=100)
+        specgram = specgram.view(2, 3, specgram.size(-2), specgram.size(-1))
+        batchwise_output = F.compute_power_spectral_density_matrix(specgram)
+        itemwise_output = torch.stack([F.compute_power_spectral_density_matrix(specgram[i]) for i in range(2)])
+        self.assertEqual(batchwise_output, itemwise_output)
+
+    def test_compute_mvdr_weights_souden(self):
+        torch.random.manual_seed(2434)
+        channel = 4
+        psd_speech = torch.rand(2, 129, channel, channel, dtype=torch.cfloat)
+        psd_noise = torch.rand(2, 129, channel, channel, dtype=torch.cfloat)
+        batchwise_output = F.compute_mvdr_weights_souden(psd_speech, psd_noise, 0)
+        itemwise_output = torch.stack([F.compute_mvdr_weights_souden(psd_speech[i], psd_noise[i], 0) for i in range(2)])
+        self.assertEqual(batchwise_output, itemwise_output)
+
+    def test_compute_mvdr_weights_rtf(self):
+        torch.random.manual_seed(2434)
+        channel = 4
+        rtf = torch.rand(2, 129, channel, 1, dtype=torch.cfloat)
+        psd_noise = torch.rand(2, 129, channel, channel, dtype=torch.cfloat)
+        batchwise_output = F.compute_mvdr_weights_rtf(rtf, psd_noise, 0)
+        itemwise_output = torch.stack([F.compute_mvdr_weights_rtf(rtf[i], psd_noise[i], 0) for i in range(2)])
+        self.assertEqual(batchwise_output, itemwise_output)
+
+    def test_compute_compute_rtf_evd(self):
+        torch.random.manual_seed(2434)
+        channel = 4
+        specgram = torch.rand(2, channel, 201, 100, dtype=torch.cfloat)
+        psd_speech = F.compute_power_spectral_density_matrix(specgram)
+        batchwise_output = F.compute_rtf_evd(psd_speech)
+        itemwise_output = torch.stack([F.compute_rtf_evd(psd_speech[i]) for i in range(2)])
+        self.assertEqual(batchwise_output, itemwise_output)
+
+    def test_compute_compute_rtf_power(self):
+        torch.random.manual_seed(2434)
+        channel = 4
+        psd_speech = torch.rand(2, 129, channel, channel, dtype=torch.cfloat)
+        psd_noise = torch.rand(2, 129, channel, channel, dtype=torch.cfloat)
+        batchwise_output = F.compute_rtf_power(psd_speech, psd_noise, 0)
+        itemwise_output = torch.stack([F.compute_rtf_power(psd_speech[i], psd_noise[i], 0) for i in range(2)])
+        self.assertEqual(batchwise_output, itemwise_output)
+
+    def test_apply_beamforming(self):
+        torch.random.manual_seed(2434)
+        sr = 8000
+        x = common_utils.get_whitenoise(sample_rate=sr, duration=0.05, n_channels=6)
+        specgram = common_utils.get_spectrogram(x, n_fft=400, hop_length=100)
+        specgram = specgram.view(2, 3, specgram.size(-2), specgram.size(-1))
+        beamform_weights = torch.rand(2, 201, 3, dtype=torch.cfloat)
+        batchwise_output = F.apply_beamforming(beamform_weights, specgram)
+        itemwise_output = torch.stack([F.apply_beamforming(beamform_weights[i], specgram[i]) for i in range(2)])
+        self.assertEqual(batchwise_output, itemwise_output)
diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py
index 5b882946c40..7d2dd839ae8 100644
--- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py
+++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py
@@ -644,6 +644,59 @@ def func(tensor):
         tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
         self._assert_consistency_complex(func, tensor)
 
+    def test_compute_power_spectral_density_matrix(self):
+        def func(tensor):
+            return F.compute_power_spectral_density_matrix(tensor)
+
+        tensor = torch.rand(2, 201, 100, dtype=torch.cfloat)
+        self._assert_consistency(func, tensor)
+
+    def test_compute_mvdr_weights_souden(self):
+        def func(_):
+            channel = 4
+            psd_speech = torch.rand(129, channel, channel, dtype=torch.cfloat)
+            psd_noise = torch.rand(129, channel, channel, dtype=torch.cfloat)
+            return F.compute_mvdr_weights_souden(psd_speech, psd_noise, 0)
+
+        dummy = torch.rand(1, 1)
+        self._assert_consistency(func, dummy)
+
+    def test_compute_mvdr_weights_rtf(self):
+        def func(_):
+            channel = 4
+            rtf = torch.rand(129, channel, 1, dtype=torch.cfloat)
+            psd_noise = torch.rand(129, channel, channel, dtype=torch.cfloat)
+            return F.compute_mvdr_weights_rtf(rtf, psd_noise, 0)
+
+        dummy = torch.rand(1, 1)
+        self._assert_consistency(func, dummy)
+
+    def test_compute_compute_rtf_evd(self):
+        def func(tensor):
+            return F.compute_rtf_evd(tensor)
+
+        tensor = torch.rand(129, 4, 4, dtype=torch.cfloat)
+        self._assert_consistency_complex(func, tensor)
+
+    def test_compute_compute_rtf_power(self):
+        def func(_):
+            channel = 4
+            psd_speech = torch.rand(129, channel, channel, dtype=torch.cfloat)
+            psd_noise = torch.rand(129, channel, channel, dtype=torch.cfloat)
+            return F.compute_rtf_power(psd_speech, psd_noise, 0)
+
+        dummy = torch.rand(1, 1)
+        self._assert_consistency(func, dummy)
+
+    def test_apply_beamforming(self):
+        def func(_):
+            beamform_weights = torch.rand(201, 6, dtype=torch.cfloat)
+            specgram = torch.rand(6, 201, 100, dtype=torch.cfloat)
+            return F.apply_beamforming(beamform_weights, specgram)
+
+        dummy = torch.rand(1, 1)
+        self._assert_consistency(func, dummy)
+
 
 class FunctionalFloat32Only(TestBaseMixin):
     def test_rnnt_loss(self):