Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen committed Mar 22, 2021
1 parent 35bc582 commit ae85994
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ to use and feel like a natural extension.
- Common audio transforms
- [Spectrogram, AmplitudeToDB, MelScale, MelSpectrogram, MFCC, MuLawEncoding, MuLawDecoding, Resample](http://pytorch.org/audio/stable/transforms.html)
- Compliance interfaces: Run code using PyTorch that align with other libraries
- [Kaldi: spectrogram, fbank, mfcc](https://pytorch.org/audio/stable/compliance.kaldi.html)
- [Kaldi: spectrogram, fbank, mfcc, resample_waveform](https://pytorch.org/audio/stable/compliance.kaldi.html)

Dependencies
------------
Expand Down
5 changes: 5 additions & 0 deletions docs/source/compliance.kaldi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: mfcc

:hidden:`resample_waveform`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: resample_waveform
17 changes: 8 additions & 9 deletions test/torchaudio_unittest/compliance_kaldi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.compliance.kaldi as kaldi

from torchaudio_unittest import common_utils
Expand Down Expand Up @@ -179,21 +178,21 @@ def test_mfcc_empty(self):

def test_resample_waveform(self):
def get_output_fn(sound, args):
output = F.resample(sound.to(torch.float32), args[1], args[2])
output = kaldi.resample_waveform(sound.to(torch.float32), args[1], args[2])
return output

self._compliance_test_helper(self.test2_filepath, 'resample', 32, 3, get_output_fn, atol=1e-2, rtol=1e-5)

def test_resample_waveform_upsample_size(self):
upsample_sound = F.resample(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2)
upsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr * 2)
self.assertTrue(upsample_sound.size(-1) == self.test1_signal.size(-1) * 2)

def test_resample_waveform_downsample_size(self):
downsample_sound = F.resample(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2)
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr // 2)
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1) // 2)

def test_resample_waveform_identity_size(self):
downsample_sound = F.resample(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr)
downsample_sound = kaldi.resample_waveform(self.test1_signal, self.test1_signal_sr, self.test1_signal_sr)
self.assertTrue(downsample_sound.size(-1) == self.test1_signal.size(-1))

def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_factor=None,
Expand All @@ -213,7 +212,7 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact
original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)

sound = 123 * torch.cos(2 * math.pi * 3 * original_timestamps).unsqueeze(0)
estimate = F.resample(sound, sample_rate, new_sample_rate).squeeze()
estimate = kaldi.resample_waveform(sound, sample_rate, new_sample_rate).squeeze()

new_timestamps = torch.arange(0, duration, 1.0 / new_sample_rate)[:estimate.size(0)]
ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)
Expand All @@ -240,11 +239,11 @@ def test_resample_waveform_multi_channel(self):
for i in range(num_channels):
multi_sound[i, :] *= (i + 1) * 1.5

multi_sound_sampled = F.resample(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2)
multi_sound_sampled = kaldi.resample_waveform(multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2)

# check that sampling is same whether using separately or in a tensor of size (c, n)
for i in range(num_channels):
single_channel = self.test1_signal * (i + 1) * 1.5
single_channel_sampled = F.resample(single_channel, self.test1_signal_sr,
self.test1_signal_sr // 2)
single_channel_sampled = kaldi.resample_waveform(single_channel, self.test1_signal_sr,
self.test1_signal_sr // 2)
self.assertEqual(multi_sound_sampled[i, :], single_channel_sampled[0], rtol=1e-4, atol=1e-7)
22 changes: 22 additions & 0 deletions torchaudio/compliance/kaldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
'mfcc',
'vtln_warp_freq',
'vtln_warp_mel_freq',
'resample_waveform',
]

# numeric_limits<float>::epsilon() 1.1920928955078125e-07
Expand Down Expand Up @@ -749,3 +750,24 @@ def mfcc(

feature = _subtract_column_mean(feature, subtract_mean)
return feature


def resample_waveform(waveform: Tensor,
orig_freq: float,
new_freq: float,
lowpass_filter_width: int = 6) -> Tensor:
r"""Resamples the waveform at the new frequency.
This is a wrapper around ``torchaudio.functional.resample``.
Args:
waveform (Tensor): The input signal of size (..., time)
orig_freq (float): The original frequency of the signal
new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
Returns:
Tensor: The waveform at the new frequency
"""
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width)

0 comments on commit ae85994

Please sign in to comment.