Skip to content

Commit

Permalink
Move resample to functional and add librosa comparison (#1402)
Browse files Browse the repository at this point in the history
This PR additionally adds batching to the Kaldi compliance resample interface.
  • Loading branch information
Caroline Chen authored Mar 22, 2021
1 parent dd76e9d commit 14dd917
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 91 deletions.
7 changes: 6 additions & 1 deletion docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ apply_codec
-----------

.. autofunction:: apply_codec

:hidden:`Complex Utility`
~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down Expand Up @@ -230,3 +230,8 @@ vad
---------------------------

.. autofunction:: spectral_centroid

:hidden:`resample`
---------------------------

.. autofunction:: resample
19 changes: 19 additions & 0 deletions test/torchaudio_unittest/functional/librosa_compatibility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,25 @@ def test_amplitude_to_DB(self):

self.assertEqual(ta_out, lr_out, atol=5e-5, rtol=1e-5)

def test_resample(self):
    """Compare ``F.resample`` with ``librosa.resample`` for 2x up- and down-sampling.

    The sample rates are passed to librosa by keyword (``orig_sr`` / ``target_sr``):
    recent librosa releases no longer accept them positionally, while older
    releases accept the keyword form as well.
    """
    input_path = common_utils.get_asset_path('sinewave.wav')
    waveform, sample_rate = common_utils.load_wav(input_path)
    # librosa receives the waveform with its leading dimension dropped;
    # presumably the asset is mono, i.e. (1, time) -- confirm against the asset.
    reference_input = waveform.squeeze(0).numpy()

    # First the doubled rate, then the halved rate (same order as before).
    for target_rate in (sample_rate * 2, sample_rate // 2):
        ta_resampled = F.resample(waveform, sample_rate, target_rate)
        lr_resampled = librosa.resample(reference_input, orig_sr=sample_rate, target_sr=target_rate)
        # Restore the leading dimension so both results have the same shape.
        lr_resampled = torch.from_numpy(lr_resampled).unsqueeze(0)

        self.assertEqual(ta_resampled, lr_resampled, atol=1e-2, rtol=1e-5)


@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestPhaseVocoder(common_utils.TorchaudioTestCase):
Expand Down
81 changes: 4 additions & 77 deletions torchaudio/compliance/kaldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import math
import torch
from torch import Tensor
from torch.nn import functional as F

import torchaudio
import torchaudio._internal.fft
Expand Down Expand Up @@ -753,71 +752,16 @@ def mfcc(
return feature


def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
                              device: torch.device, dtype: torch.dtype):
    """Build the bank of Hann-windowed sinc filters used for resampling.

    Args:
        orig_freq (int): Source rate (callers reduce it by the GCD beforehand).
        new_freq (int): Target rate (likewise reduced); one filter is built per
            output phase ``0 .. new_freq - 1``.
        lowpass_filter_width (int): Number of sinc zero crossings kept on each
            side of the filter. Must be positive.
        device (torch.device): Device on which the filters are created.
        dtype (torch.dtype): Dtype of the filters.

    Returns:
        (Tensor, int): The filters with shape ``(new_freq, 1, kernel_length)``
        and the one-sided padding ``width`` that was used to size them.
    """
    assert lowpass_filter_width > 0

    # Anti-aliasing: the cutoff sits slightly below the lower of the two rates.
    # This is applied for upsampling as well, since the signal edge behaves
    # like zero padding and would otherwise add high-frequency artifacts.
    cutoff = min(orig_freq, new_freq)
    cutoff *= 0.99

    # Sinc interpolation gives
    #     y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq)),
    # truncated here to an FIR filter spanning `lowpass_filter_width` zero
    # crossings. y[j + new_freq] reuses the filter of y[j] on the input shifted
    # by `orig_freq` samples, so only `new_freq` distinct filters are needed and
    # they are applied with a convolution of stride `orig_freq`.
    half_width = math.ceil(lowpass_filter_width * orig_freq / cutoff)

    # When orig_freq is still large after GCD reduction most filters are
    # lopsided (many near-zero taps on one side); a more efficient evaluation
    # is left as future work.
    sample_pos = torch.arange(-half_width, half_width + orig_freq, device=device, dtype=dtype)

    filter_bank = []
    phase = 0
    while phase < new_freq:
        arg = (-phase / new_freq + sample_pos / orig_freq) * cutoff
        arg = arg.clamp_(-lowpass_filter_width, lowpass_filter_width)
        arg *= math.pi
        # torch.hann_window only evaluates on a regular grid, so the window is
        # computed explicitly at the (irregular) filter positions.
        hann = torch.cos(arg / lowpass_filter_width / 2)**2
        taps = torch.where(arg == 0, torch.tensor(1.).to(arg), torch.sin(arg) / arg)
        taps.mul_(hann)
        filter_bank.append(taps)
        phase += 1

    gain = cutoff / orig_freq
    stacked = torch.stack(filter_bank).view(new_freq, 1, -1)
    return stacked.mul_(gain), half_width


def resample_waveform(waveform: Tensor,
orig_freq: float,
new_freq: float,
lowpass_filter_width: int = 6) -> Tensor:
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e
the output signal has a frequency of ``new_freq``). It uses sinc/bandlimited interpolation to
upsample/downsample the signal.
r"""Resamples the waveform at the new frequency.
https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html
https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56
This is a wrapper around ``torchaudio.functional.resample``.
Args:
waveform (Tensor): The input signal of size (c, n)
waveform (Tensor): The input signal of size (..., time)
orig_freq (float): The original frequency of the signal
new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
Expand All @@ -826,21 +770,4 @@ def resample_waveform(waveform: Tensor,
Returns:
Tensor: The waveform at the new frequency
"""
assert waveform.dim() == 2
assert orig_freq > 0.0 and new_freq > 0.0

orig_freq = int(orig_freq)
new_freq = int(new_freq)
gcd = math.gcd(orig_freq, new_freq)
orig_freq = orig_freq // gcd
new_freq = new_freq // gcd

kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
waveform.device, waveform.dtype)

num_wavs, length = waveform.shape
waveform = F.pad(waveform, (width, width + orig_freq))
resampled = F.conv1d(waveform[:, None], kernel, stride=orig_freq)
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
target_length = int(math.ceil(new_freq * length / orig_freq))
return resampled[..., :target_length]
return torchaudio.functional.resample(waveform, orig_freq, new_freq, lowpass_filter_width)
4 changes: 3 additions & 1 deletion torchaudio/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
spectrogram,
spectral_centroid,
apply_codec,
resample,
)
from .filtering import (
allpass_biquad,
Expand Down Expand Up @@ -85,5 +86,6 @@
'riaa_biquad',
'treble_biquad',
'vad',
'apply_codec'
'apply_codec',
'resample',
]
103 changes: 103 additions & 0 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
'sliding_window_cmn',
"spectral_centroid",
"apply_codec",
"resample",
]


Expand Down Expand Up @@ -1209,3 +1210,105 @@ def compute_kaldi_pitch(
)
result = result.reshape(shape[:-1] + result.shape[-2:])
return result


def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
                              device: torch.device, dtype: torch.dtype):
    """Create the windowed-sinc filters consumed by ``resample``.

    One filter is produced for each of the ``new_freq`` output phases; the
    caller applies them with a strided 1-D convolution.

    Args:
        orig_freq (int): Source rate, already divided by the GCD of both rates.
        new_freq (int): Target rate, already divided by the GCD of both rates.
        lowpass_filter_width (int): Number of sinc zero crossings retained on
            each side of the filter. Must be positive.
        device (torch.device): Device on which the filters are created.
        dtype (torch.dtype): Dtype of the filters.

    Returns:
        (Tensor, int): Filters of shape ``(new_freq, 1, kernel_length)`` and the
        one-sided padding ``width`` the caller must add to the signal.
    """
    assert lowpass_filter_width > 0

    # The low-pass cutoff is 99% of the smaller rate. Besides anti-aliasing on
    # downsampling, this avoids edge artifacts on upsampling, where the signal
    # boundary acts like zero padding and introduces high frequencies.
    base_rate = min(orig_freq, new_freq)
    base_rate *= 0.99

    # Theory: x(t) is reconstructed from its samples by sinc interpolation and
    # then re-sampled at j / new_freq, i.e.
    #     y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq)).
    # The filter is truncated after `lowpass_filter_width` zero crossings.
    # Because y[j + new_freq] equals the filter of y[j] applied to x shifted by
    # `orig_freq`, only `new_freq` filters exist; this is why the caller uses a
    # convolution with stride `orig_freq`.
    pad = math.ceil(lowpass_filter_width * orig_freq / base_rate)

    # NOTE: for a large reduced orig_freq the filters are very unbalanced (lots
    # of near-zero values left or right); evaluating them more efficiently is
    # kept for future work.
    grid = torch.arange(-pad, pad + orig_freq, device=device, dtype=dtype)

    def _filter_for_phase(phase: int):
        # Filter argument in units of zero crossings, limited to the window.
        x = (-phase / new_freq + grid / orig_freq) * base_rate
        x = x.clamp_(-lowpass_filter_width, lowpass_filter_width)
        x *= math.pi
        # The Hann window is evaluated at these specific positions rather than
        # on a regular grid, hence no torch.hann_window.
        window = torch.cos(x / lowpass_filter_width / 2)**2
        sinc = torch.where(x == 0, torch.tensor(1.).to(x), torch.sin(x) / x)
        sinc.mul_(window)
        return sinc

    bank = torch.stack([_filter_for_phase(phase) for phase in range(new_freq)])
    return bank.view(new_freq, 1, -1).mul_(base_rate / orig_freq), pad


def resample(
        waveform: Tensor,
        orig_freq: float,
        new_freq: float,
        lowpass_filter_width: int = 6
) -> Tensor:
    r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
    which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
    a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e
    the output signal has a frequency of ``new_freq``). It uses sinc/bandlimited interpolation to
    upsample/downsample the signal.

    https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html
    https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56

    Args:
        waveform (Tensor): The input signal of dimension (..., time). It does not
            need to be contiguous.
        orig_freq (float): The original frequency of the signal. Must be positive;
            it is truncated to an integer before use.
        new_freq (float): The desired frequency. Must be positive; it is truncated
            to an integer before use.
        lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
            but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)

    Returns:
        Tensor: The waveform at the new frequency of dimension (..., time).
    """
    # pack batch: collapse all leading dimensions into one.
    # ``reshape`` (not ``view``) so that non-contiguous inputs, e.g. a transposed
    # batch, are accepted; for contiguous inputs the two are equivalent.
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    assert orig_freq > 0.0 and new_freq > 0.0

    # Only the ratio of the two rates matters, so reduce both by their GCD to
    # keep the number of filters and the convolution stride small.
    orig_freq = int(orig_freq)
    new_freq = int(new_freq)
    gcd = math.gcd(orig_freq, new_freq)
    orig_freq = orig_freq // gcd
    new_freq = new_freq // gcd

    kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
                                              waveform.device, waveform.dtype)

    num_wavs, length = waveform.shape
    # Zero-pad both ends so that the filters can be applied at the signal edges.
    waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
    # Each of the ``new_freq`` output channels is one output phase; a stride of
    # ``orig_freq`` advances the input by one period of the reduced rate ratio.
    resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
    # Interleave the phases along time: (batch, phase, frame) -> (batch, frame * phase).
    resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
    # Drop the surplus samples produced by the padding.
    target_length = int(math.ceil(new_freq * length / orig_freq))
    resampled = resampled[..., :target_length]

    # unpack batch: restore the original leading dimensions.
    resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:])
    return resampled
13 changes: 1 addition & 12 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch
from torch import Tensor
from torchaudio import functional as F
from torchaudio.compliance import kaldi


__all__ = [
Expand Down Expand Up @@ -649,17 +648,7 @@ def forward(self, waveform: Tensor) -> Tensor:
Tensor: Output signal of dimension (..., time).
"""
if self.resampling_method == 'sinc_interpolation':

# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])

waveform = kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)

# unpack batch
waveform = waveform.view(shape[:-1] + waveform.shape[-1:])

return waveform
return F.resample(waveform, self.orig_freq, self.new_freq)

raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))

Expand Down

0 comments on commit 14dd917

Please sign in to comment.