Functionalgputest #1475

Merged: 7 commits, Apr 26, 2021
6 changes: 3 additions & 3 deletions test/torchaudio_unittest/functional/functional_cpu_test.py
@@ -16,18 +16,18 @@ def test_lfilter_9th_order_filter_stability(self):
         super().test_lfilter_9th_order_filter_stability()
 
 
-class TestFunctionalFloat64(Functional, FunctionalCPUOnly, PytorchTestCase):
+class TestFunctionalFloat64(Functional, PytorchTestCase):
     dtype = torch.float64
     device = torch.device('cpu')
 
 
-class TestFunctionalComplex64(FunctionalComplex, FunctionalCPUOnly, PytorchTestCase):
+class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
     complex_dtype = torch.complex64
     real_dtype = torch.float32
     device = torch.device('cpu')
 
 
-class TestFunctionalComplex128(FunctionalComplex, FunctionalCPUOnly, PytorchTestCase):
+class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
     complex_dtype = torch.complex128
     real_dtype = torch.float64
     device = torch.device('cpu')
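
Dropping the FunctionalCPUOnly mixin from these three classes leaves Functional and FunctionalComplex as device-agnostic mixins, which is what lets a CUDA test file instantiate the same suites on a GPU. A minimal sketch of such a counterpart, assuming the suite's skipIfNoCuda helper; the file, class names, and dtype choices here are illustrative assumptions, not part of this diff:

import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .functional_impl import Functional, FunctionalComplex


@skipIfNoCuda
class TestFunctionalFloat32CUDA(Functional, PytorchTestCase):  # hypothetical
    dtype = torch.float32
    device = torch.device('cuda')


@skipIfNoCuda
class TestFunctionalComplex64CUDA(FunctionalComplex, PytorchTestCase):  # hypothetical
    complex_dtype = torch.complex64
    real_dtype = torch.float32
    device = torch.device('cuda')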
130 changes: 65 additions & 65 deletions test/torchaudio_unittest/functional/functional_impl.py
@@ -93,73 +93,17 @@ def test_spectogram_grad_at_zero(self, power):
         spec.sum().backward()
         assert not x.grad.isnan().sum()
 
-
-class FunctionalComplex(TestBaseMixin):
-    complex_dtype = None
-    real_dtype = None
-    device = None
-
-    @nested_params(
-        [0.5, 1.01, 1.3],
-        [True, False],
-    )
-    def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
-        """Verify the output shape of phase vocoder"""
-        hop_length = 256
-        num_freq = 1025
-        num_frames = 400
-        batch_size = 2
-
-        torch.random.manual_seed(42)
-        spec = torch.randn(
-            batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
-        if test_pseudo_complex:
-            spec = torch.view_as_real(spec)
-
-        phase_advance = torch.linspace(
-            0,
-            np.pi * hop_length,
-            num_freq,
-            dtype=self.real_dtype, device=self.device)[..., None]
-
-        spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
-
-        assert spec.dim() == spec_stretch.dim()
-        expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
-        output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
-        assert output_shape == expected_shape
-
-
-class FunctionalCPUOnly(TestBaseMixin):
-    def test_create_fb_matrix_no_warning_high_n_freq(self):
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always")
-            F.create_fb_matrix(288, 0, 8000, 128, 16000)
-            assert len(w) == 0
-
-    def test_create_fb_matrix_no_warning_low_n_mels(self):
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always")
-            F.create_fb_matrix(201, 0, 8000, 89, 16000)
-            assert len(w) == 0
-
-    def test_create_fb_matrix_warning(self):
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always")
-            F.create_fb_matrix(201, 0, 8000, 128, 16000)
-            assert len(w) == 1
-
     def test_compute_deltas_one_channel(self):
-        specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
-        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
+        specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device)
+        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device)
         computed = F.compute_deltas(specgram, win_length=3)
         self.assertEqual(computed, expected)
 
     def test_compute_deltas_two_channels(self):
         specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
-                                  [1.0, 2.0, 3.0, 4.0]]])
+                                  [1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device)
         expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
-                                  [0.5, 1.0, 1.0, 0.5]]])
+                                  [0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device)
         computed = F.compute_deltas(specgram, win_length=3)
         self.assertEqual(computed, expected)
 
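The compute_deltas expectations above can be hand-checked: with win_length=3 the delta kernel reduces to d_t = (c_{t+1} - c_{t-1}) / 2, and compute_deltas uses replicate padding at the edges by default, so [1, 2, 3, 4] pads to [1, 1, 2, 3, 4, 4] and yields exactly the expected tensor:

import torch
import torchaudio.functional as F

# win_length=3 -> d_t = (c_{t+1} - c_{t-1}) / 2, replicate edge padding
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
print(F.compute_deltas(specgram, win_length=3))
# tensor([[[0.5000, 1.0000, 1.0000, 0.5000]]])
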
@@ -190,7 +134,7 @@ def test_amplitude_to_DB_reversible(self, shape):
         db_mult = math.log10(max(amin, ref))
 
         torch.manual_seed(0)
-        spec = torch.rand(*shape) * 200
+        spec = torch.rand(*shape, dtype=self.dtype, device=self.device) * 200
 
         # Spectrogram amplitude -> DB -> amplitude
         db = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=None)
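
For reference, the round trip this test exercises pairs amplitude_to_DB with DB_to_amplitude. A minimal sketch, assuming power-spectrogram conventions (multiplier 10, power 1); the variable names mirror the test but the tolerances are illustrative:

import math
import torch
import torchaudio.functional as F

amin, ref = 1e-10, 1.0
spec = torch.rand(2, 100, 100) * 200
db = F.amplitude_to_DB(spec, 10., amin, math.log10(max(amin, ref)), top_db=None)
recovered = F.DB_to_amplitude(db, ref=ref, power=1.)
print(torch.allclose(spec, recovered, rtol=1e-4, atol=1e-4))  # True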
@@ -218,7 +162,7 @@ def test_amplitude_to_DB_top_db_clamp(self, shape):
         # each spectrogram still need to be predictable. The max determines the
         # decibel cutoff, and the distance from the min must be large enough
         # that it triggers a clamp.
-        spec = torch.rand(*shape)
+        spec = torch.rand(*shape, dtype=self.dtype, device=self.device)
         # Ensure each spectrogram has a min of 0 and a max of 1.
         spec -= spec.amin([-2, -1])[..., None, None]
         spec /= spec.amax([-2, -1])[..., None, None]
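
The clamp under test: amplitude_to_DB computes multiplier * log10(clamp(x, min=amin)) - multiplier * db_multiplier, then clamps the result to at least max - top_db. With the inputs normalized to a max of 1 (0 dB), anything quieter than -top_db lands exactly on the cutoff. A small sketch of that behavior:

import torch
import torchaudio.functional as F

spec = torch.tensor([[1.0, 1e-8]])  # max is 0 dB; 1e-8 is -80 dB
db = F.amplitude_to_DB(spec, 10., 1e-10, 0., top_db=40.)
print(db)  # tensor([[  0., -40.]]) -- the -80 dB entry is clamped up to -40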
@@ -245,7 +189,7 @@ def test_amplitude_to_DB_top_db_clamp(self, shape):
     )
     def test_complex_norm(self, shape, power):
         torch.random.manual_seed(42)
-        complex_tensor = torch.randn(*shape)
+        complex_tensor = torch.randn(*shape, dtype=self.dtype, device=self.device)
         expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2)
         norm_tensor = F.complex_norm(complex_tensor, power)
         self.assertEqual(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5)
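
The expected value here is just |z|^power computed from the pseudo-complex (..., 2) view; a quick hand-check on a single (real, imag) pair:

import torch
import torchaudio.functional as F

z = torch.tensor([3.0, 4.0])         # last dimension holds (real, imag)
print(F.complex_norm(z, power=1.0))  # tensor(5.) : sqrt(3^2 + 4^2)
print(F.complex_norm(z, power=2.0))  # tensor(25.)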
@@ -255,7 +199,7 @@ def test_complex_norm(self, shape, power):
     )
     def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
         torch.random.manual_seed(42)
-        specgram = torch.randn(*shape)
+        specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
         mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)
 
         other_axis = 1 if axis == 2 else 2
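
mask_along_axis replaces one contiguous block of at most mask_param rows or columns along axis with mask_value, which is why the test can assert that every slice along the other axis is either fully masked or fully untouched. A sketch of that property (shapes chosen for illustration):

import torch
import torchaudio.functional as F

specgram = torch.randn(1, 30, 40)
masked = F.mask_along_axis(specgram, mask_param=10, mask_value=0., axis=2)
fully_zero = (masked[0] == 0.).all(dim=0)  # per time column: all freqs zero?
print(fully_zero.sum() < 10)               # mask width is drawn from [0, 10)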
@@ -271,7 +215,7 @@ def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
     @parameterized.expand(list(itertools.product([100], [0., 30.], [2, 3])))
     def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
         torch.random.manual_seed(42)
-        specgrams = torch.randn(4, 2, 1025, 400)
+        specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)
 
         mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)
 
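mask_along_axis_iid differs from mask_along_axis in that every example and channel in the batch gets its own independently drawn (start, width) pair, so the per-element mask counts generally differ. A sketch of that difference:

import torch
import torchaudio.functional as F

specgrams = torch.randn(4, 2, 1025, 400)
masked = F.mask_along_axis_iid(specgrams, mask_param=100, mask_value=0., axis=2)
# per (batch, channel): how many frequency rows were zeroed end to end
num_masked_rows = (masked == 0.).all(dim=-1).sum(dim=-1)
print(num_masked_rows)  # shape (4, 2); entries generally differ, all < 100
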
@@ -282,3 +226,59 @@ def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
 
         assert mask_specgrams.size() == specgrams.size()
         assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
+
+
+class FunctionalComplex(TestBaseMixin):
+    complex_dtype = None
+    real_dtype = None
+    device = None
+
+    @nested_params(
+        [0.5, 1.01, 1.3],
+        [True, False],
+    )
+    def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
+        """Verify the output shape of phase vocoder"""
+        hop_length = 256
+        num_freq = 1025
+        num_frames = 400
+        batch_size = 2
+
+        torch.random.manual_seed(42)
+        spec = torch.randn(
+            batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
+        if test_pseudo_complex:
+            spec = torch.view_as_real(spec)
+
+        phase_advance = torch.linspace(
+            0,
+            np.pi * hop_length,
+            num_freq,
+            dtype=self.real_dtype, device=self.device)[..., None]
+
+        spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
+
+        assert spec.dim() == spec_stretch.dim()
+        expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
+        output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
+        assert output_shape == expected_shape
+
+
+class FunctionalCPUOnly(TestBaseMixin):
+    def test_create_fb_matrix_no_warning_high_n_freq(self):
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            F.create_fb_matrix(288, 0, 8000, 128, 16000)
+            assert len(w) == 0
+
+    def test_create_fb_matrix_no_warning_low_n_mels(self):
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            F.create_fb_matrix(201, 0, 8000, 89, 16000)
+            assert len(w) == 0
+
+    def test_create_fb_matrix_warning(self):
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            F.create_fb_matrix(201, 0, 8000, 128, 16000)
+            assert len(w) == 1
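
Two hand-checks on the relocated classes. First, the phase vocoder shape arithmetic: with num_frames=400, int(np.ceil(400 / rate)) gives 800, 397, and 308 output frames for the three tested rates:

import numpy as np

for rate in (0.5, 1.01, 1.3):
    print(rate, int(np.ceil(400 / rate)))  # 0.5 -> 800, 1.01 -> 397, 1.3 -> 308

Second, the create_fb_matrix warning fires when the requested mel resolution exceeds what n_freqs can resolve, leaving some triangular filters with no nonzero bin. A sketch of that condition; the all-zero-filter check mirrors what the warning reports, though the exact internal test is an assumption here:

import torchaudio.functional as F

fb = F.create_fb_matrix(201, 0., 8000., 128, 16000)  # shape (n_freqs, n_mels)
empty = (fb.max(dim=0).values == 0.)
print(empty.sum())  # > 0: some of the 128 mel filters are all-zero at 201 freqs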