Skip to content

Commit

Permalink
add unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
nateanl committed Feb 9, 2022
1 parent c5c04d1 commit 81c0870
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 1 deletion.
59 changes: 58 additions & 1 deletion test/torchaudio_unittest/functional/autograd_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial
from typing import Callable, Tuple
from unittest import expectedFailure

import torch
import torchaudio.functional as F
Expand All @@ -9,6 +10,7 @@
from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_whitenoise,
get_spectrogram,
rnnt_utils,
)

Expand All @@ -24,7 +26,7 @@ def assert_grad(
inputs_ = []
for i in inputs:
if torch.is_tensor(i):
i = i.to(dtype=self.dtype, device=self.device)
i = i.to(dtype=torch.cdouble if i.is_complex() else self.dtype, device=self.device)
if enable_all_grad:
i.requires_grad = True
inputs_.append(i)
Expand Down Expand Up @@ -250,6 +252,61 @@ def test_bandreject_biquad(self, central_freq, Q):
Q = torch.tensor(Q)
self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))

@parameterized.expand(
[
(True,),
(False,),
]
)
def test_compute_power_spectral_density_matrix(self, use_mask):
torch.random.manual_seed(2434)
sr = 8000
x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=6)
specgram = get_spectrogram(x, n_fft=400, hop_length=100)
if use_mask:
mask = torch.rand(specgram.size())
else:
mask = None
self.assert_grad(F.compute_power_spectral_density_matrix, (specgram, mask))

def test_compute_mvdr_weights_souden(self):
torch.random.manual_seed(2434)
channel = 4
psd_speech = torch.rand(129, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(129, channel, channel, dtype=torch.cfloat)
self.assert_grad(F.compute_mvdr_weights_souden, (psd_speech, psd_noise, 0))

def test_compute_mvdr_weights_rtf(self):
torch.random.manual_seed(2434)
channel = 4
rtf = torch.rand(129, channel, 1, dtype=torch.cfloat)
psd_noise = torch.rand(129, channel, channel, dtype=torch.cfloat)
self.assert_grad(F.compute_mvdr_weights_rtf, (rtf, psd_noise, 0))

# The eigenvector can be different in different runs, expected to fail
@expectedFailure
def test_compute_compute_rtf_evd(self):
torch.random.manual_seed(2434)
channel = 4
specgram = torch.rand(channel, 201, 100, dtype=torch.cfloat)
psd_speech = F.compute_power_spectral_density_matrix(specgram)
self.assert_grad(F.compute_rtf_evd, (psd_speech,))

def test_compute_compute_rtf_power(self):
torch.random.manual_seed(2434)
channel = 4
psd_speech = torch.rand(129, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(129, channel, channel, dtype=torch.cfloat)
self.assert_grad(F.compute_rtf_power, (psd_speech, psd_noise, 0))

def test_apply_beamforming(self):
torch.random.manual_seed(2434)
sr = 8000
x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=6)
specgram = get_spectrogram(x, n_fft=400, hop_length=100)
beamform_weights = torch.rand(201, 6, dtype=torch.cfloat)
self.assert_grad(F.apply_beamforming, (beamform_weights, specgram))


class AutogradFloat32(TestBaseMixin):
def assert_grad(
Expand Down
56 changes: 56 additions & 0 deletions test/torchaudio_unittest/functional/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,59 @@ def test_filtfilt(self):
itemwise_output = torch.stack([F.filtfilt(x[i], a[i], b[i]) for i in range(self.batch_size)])

self.assertEqual(batchwise_output, itemwise_output)

def test_compute_power_spectral_density_matrix(self):
sample_rate = 44100
waveform = common_utils.get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=6)
specgram = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=100)
specgram = specgram.view(2, 3, specgram.size(-2), specgram.size(-1))
batchwise_output = F.compute_power_spectral_density_matrix(specgram)
itemwise_output = torch.stack([F.compute_power_spectral_density_matrix(specgram[i]) for i in range(2)])
self.assertEqual(batchwise_output, itemwise_output)

def test_compute_mvdr_weights_souden(self):
torch.random.manual_seed(2434)
channel = 4
psd_speech = torch.rand(2, 129, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(2, 129, channel, channel, dtype=torch.cfloat)
batchwise_output = F.compute_mvdr_weights_souden(psd_speech, psd_noise, 0)
itemwise_output = torch.stack([F.compute_mvdr_weights_souden(psd_speech[i], psd_noise[i], 0) for i in range(2)])
self.assertEqual(batchwise_output, itemwise_output)

def test_compute_mvdr_weights_rtf(self):
torch.random.manual_seed(2434)
channel = 4
rtf = torch.rand(2, 129, channel, 1, dtype=torch.cfloat)
psd_noise = torch.rand(2, 129, channel, channel, dtype=torch.cfloat)
batchwise_output = F.compute_mvdr_weights_rtf(rtf, psd_noise, 0)
itemwise_output = torch.stack([F.compute_mvdr_weights_rtf(rtf[i], psd_noise[i], 0) for i in range(2)])
self.assertEqual(batchwise_output, itemwise_output)

def test_compute_compute_rtf_evd(self):
torch.random.manual_seed(2434)
channel = 4
specgram = torch.rand(2, channel, 201, 100, dtype=torch.cfloat)
psd_speech = F.compute_power_spectral_density_matrix(specgram)
batchwise_output = F.compute_rtf_evd(psd_speech)
itemwise_output = torch.stack([F.compute_rtf_evd(psd_speech[i]) for i in range(2)])
self.assertEqual(batchwise_output, itemwise_output)

def test_compute_compute_rtf_power(self):
torch.random.manual_seed(2434)
channel = 4
psd_speech = torch.rand(2, 129, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(2, 129, channel, channel, dtype=torch.cfloat)
batchwise_output = F.compute_rtf_power(psd_speech, psd_noise, 0)
itemwise_output = torch.stack([F.compute_rtf_power(psd_speech[i], psd_noise[i], 0) for i in range(2)])
self.assertEqual(batchwise_output, itemwise_output)

def test_apply_beamforming(self):
torch.random.manual_seed(2434)
sr = 8000
x = common_utils.get_whitenoise(sample_rate=sr, duration=0.05, n_channels=6)
specgram = common_utils.get_spectrogram(x, n_fft=400, hop_length=100)
specgram = specgram.view(2, 3, specgram.size(-2), specgram.size(-1))
beamform_weights = torch.rand(2, 201, 3, dtype=torch.cfloat)
batchwise_output = F.apply_beamforming(beamform_weights, specgram)
itemwise_output = torch.stack([F.apply_beamforming(beamform_weights[i], specgram[i]) for i in range(2)])
self.assertEqual(batchwise_output, itemwise_output)
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,59 @@ def func(tensor):
tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
self._assert_consistency_complex(func, tensor)

def test_compute_power_spectral_density_matrix(self):
def func(tensor):
return F.compute_power_spectral_density_matrix(tensor)

tensor = torch.rand(2, 201, 100, dtype=torch.cfloat)
self._assert_consistency(func, tensor)

def test_compute_mvdr_weights_souden(self):
def func(_):
channel = 4
psd_speech = torch.rand(129, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(129, channel, channel, dtype=torch.cfloat)
return F.compute_mvdr_weights_souden(psd_speech, psd_noise, 0)

dummy = torch.rand(1, 1)
self._assert_consistency(func, dummy)

def test_compute_mvdr_weights_rtf(self):
def func(_):
channel = 4
rtf = torch.rand(129, channel, 1, dtype=torch.cfloat)
psd_noise = torch.rand(129, channel, channel, dtype=torch.cfloat)
return F.compute_mvdr_weights_rtf(rtf, psd_noise, 0)

dummy = torch.rand(1, 1)
self._assert_consistency(func, dummy)

def test_compute_compute_rtf_evd(self):
def func(tensor):
return F.compute_rtf_evd(tensor)

tensor = torch.rand(129, 4, 4, dtype=torch.cfloat)
self._assert_consistency_complex(func, tensor)

def test_compute_compute_rtf_power(self):
def func(_):
channel = 4
psd_speech = torch.rand(129, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(129, channel, channel, dtype=torch.cfloat)
return F.compute_rtf_power(psd_speech, psd_noise, 0)

dummy = torch.rand(1, 1)
self._assert_consistency(func, dummy)

def test_apply_beamforming(self):
def func(_):
beamform_weights = torch.rand(201, 6, dtype=torch.cfloat)
specgram = torch.rand(6, 201, 100, dtype=torch.cfloat)
return F.apply_beamforming(beamform_weights, specgram)

dummy = torch.rand(1, 1)
self._assert_consistency(func, dummy)


class FunctionalFloat32Only(TestBaseMixin):
def test_rnnt_loss(self):
Expand Down

0 comments on commit 81c0870

Please sign in to comment.