diff --git a/README.md b/README.md index fdfd826c978..2cff4bf0760 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ to use and feel like a natural extension. - Common audio transforms - [Spectrogram, AmplitudeToDB, MelScale, MelSpectrogram, MFCC, MuLawEncoding, MuLawDecoding, Resample](http://pytorch.org/audio/stable/transforms.html) - Compliance interfaces: Run code using PyTorch that align with other libraries - - [Kaldi: spectrogram, fbank, mfcc](https://pytorch.org/audio/stable/compliance.kaldi.html) + - [Kaldi: spectrogram, fbank, mfcc, resample_waveform](https://pytorch.org/audio/stable/compliance.kaldi.html) Dependencies ------------ diff --git a/docs/source/compliance.kaldi.rst b/docs/source/compliance.kaldi.rst index 72827ca3fbf..cc75021d698 100644 --- a/docs/source/compliance.kaldi.rst +++ b/docs/source/compliance.kaldi.rst @@ -29,3 +29,8 @@ Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: mfcc + +:hidden:`resample_waveform` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: resample_waveform diff --git a/test/torchaudio_unittest/compliance_kaldi_test.py b/test/torchaudio_unittest/compliance_kaldi_test.py index 00fdbfd1ddd..a98240a9b80 100644 --- a/test/torchaudio_unittest/compliance_kaldi_test.py +++ b/test/torchaudio_unittest/compliance_kaldi_test.py @@ -3,7 +3,6 @@ import torch import torchaudio -import torchaudio.functional as F import torchaudio.compliance.kaldi as kaldi from torchaudio_unittest import common_utils @@ -179,21 +178,21 @@ def test_mfcc_empty(self): def test_resample_waveform(self): def get_output_fn(sound, args): - output = F.resample(sound.to(torch.float32), args[1], args[2]) + output = kaldi.resample_waveform(sound.to(torch.float32), args[1], args[2]) return output self._compliance_test_helper(self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5) def test_resample_waveform_upsample_size(self): - upsample_sound = F.resample(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2) + 
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2) self.assertTrue(upsample_sound.size(-1) == self.test1_signal.size(-1) * 2) def test_resample_waveform_downsample_size(self): - downsample_sound = F.resample(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2) + downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2) self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1) // 2) def test_resample_waveform_identity_size(self): - downsample_sound = F.resample(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr) + downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr) self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1)) def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None, @@ -213,7 +212,7 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact original_timestamps = torch.arange(0, duration, 1.0 / sample_rate) sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0) - estimate = F.resample(sound, sample_rate, new_sample_rate).squeeze() + estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze() new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)] ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps) @@ -240,11 +239,11 @@ def test_resample_waveform_multi_channel(self): for i in range(num_channels): multi_sound[i, :] *= (i + 1) * 1.5 - multi_sound_sampled = F.resample(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2) + multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2) # check that sampling is same whether using separately or in a tensor of size (c, n) for i in range(num_channels): single_channel = self.test1_signal * (i + 1) * 1.5 - 
single_channel_sampled = F.resample(single_channel, self.test1_signal_sr, - self.test1_signal_sr // 2) + single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr, + self.test1_signal_sr // 2) self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index 583aaffd2b5..9cdab5acdd8 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -18,6 +18,7 @@ 'mfcc', 'vtln_warp_freq', 'vtln_warp_mel_freq', + 'resample_waveform', ] # numeric_limits::epsilon() 1.1920928955078125e-07 @@ -749,3 +750,24 @@ def mfcc( feature = _subtract_column_mean(feature, subtract_mean) return feature + + +def resample_waveform(waveform: Tensor, + orig_freq: float, + new_freq: float, + lowpass_filter_width: int = 6) -> Tensor: + r"""Resamples the waveform at the new frequency. + + This is a wrapper around ``torchaudio.functional.resample``. + + Args: + waveform (Tensor): The input signal of size (c, n) + orig_freq (float): The original frequency of the signal + new_freq (float): The desired frequency + lowpass_filter_width (int, optional): Controls the sharpness of the filter; larger values are sharper + but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``) + + Returns: + Tensor: The waveform at the new frequency + """ + return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width)