diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py
index f7109c6531..b9eea1211f 100644
--- a/torchaudio/functional/functional.py
+++ b/torchaudio/functional/functional.py
@@ -1414,13 +1414,11 @@ def _get_sinc_resample_kernel(
     new_freq = int(new_freq) // gcd
 
     assert lowpass_filter_width > 0
-    kernels = []
-    base_freq = min(orig_freq, new_freq)
+    base_freq = min(orig_freq, new_freq) * rolloff
     # This will perform antialiasing filtering by removing the highest frequencies.
     # At first I thought I only needed this when downsampling, but when upsampling
     # you will get edge artifacts without this, as the edge is equivalent to zero padding,
     # which will add high freq artifacts.
-    base_freq *= rolloff
 
     # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
     # using the sinc interpolation formula:
@@ -1439,34 +1437,35 @@ def _get_sinc_resample_kernel(
     #     = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
     # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
     # This will explain the F.conv1d after, with a stride of orig_freq.
-    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
+
+    scale = base_freq / orig_freq
+    width = math.ceil(lowpass_filter_width / scale)
     # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
     # they will have a lot of almost zero values to the left or to the right...
     # There is probably a way to evaluate those filters more efficiently, but this is kept for
     # future work.
     idx_dtype = dtype if dtype is not None else torch.float64
-    idx = torch.arange(-width, width + orig_freq, device=device, dtype=idx_dtype)
-    for i in range(new_freq):
-        t = (-i / new_freq + idx / orig_freq) * base_freq
-        t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
-
-        # we do not use built in torch windows here as we need to evaluate the window
-        # at specific positions, not over a regular grid.
-        if resampling_method == "sinc_interpolation":
-            window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
-        else:
-            # kaiser_window
-            if beta is None:
-                beta = 14.769656459379492
-            beta_tensor = torch.tensor(float(beta))
-            window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
-        t *= math.pi
-        kernel = torch.where(t == 0, torch.tensor(1.0).to(t), torch.sin(t) / t)
-        kernel.mul_(window)
-        kernels.append(kernel)
-
-    scale = base_freq / orig_freq
-    kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale)
+
+    # Evaluate all new_freq phases of the filter at once by broadcasting the phase
+    # offsets (shape [new_freq, 1, 1]) against the sample indices (shape [1, 1, L]).
+    idx = torch.arange(-width, width + orig_freq, device=device, dtype=idx_dtype)[None, None].div(orig_freq)
+    t = torch.arange(0, -new_freq, -1, device=device, dtype=idx_dtype)[:, None, None].div(new_freq) + idx
+    t.mul_(base_freq).clamp_(-lowpass_filter_width, lowpass_filter_width)
+
+    # we do not use built in torch windows here as we need to evaluate the window
+    # at specific positions, not over a regular grid.
+    if resampling_method == "sinc_interpolation":
+        window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
+    else:
+        # kaiser_window
+        if beta is None:
+            beta = 14.769656459379492
+        beta_tensor = torch.as_tensor(float(beta), device=device, dtype=idx_dtype)
+        window = beta_tensor.mul((1 - t.div(lowpass_filter_width).pow(2)).sqrt()).i0().div(beta_tensor.i0())
+
+    t.mul_(math.pi)
+    kernels = torch.where(t == 0, torch.tensor(1.0, device=device, dtype=idx_dtype), t.sin().div(t))
+    kernels.mul_(window).mul_(scale)
     if dtype is None:
         kernels = kernels.to(dtype=torch.float32)
     return kernels, width
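
For reference, a minimal standalone sketch (illustrative parameter values and variable names, not torchaudio's internal API) checking that the broadcast formulation above builds the same (new_freq, 1, kernel_length) kernel bank as the per-phase loop it replaces, here with the sinc_interpolation (Hann-squared) window:

    import math

    import torch

    # hypothetical parameters: 48 kHz -> 40 kHz reduces to 6 -> 5 after GCD reduction
    orig_freq, new_freq = 6, 5
    lowpass_filter_width, rolloff = 6, 0.99
    base_freq = min(orig_freq, new_freq) * rolloff
    scale = base_freq / orig_freq
    width = math.ceil(lowpass_filter_width / scale)

    # (a) per-phase loop: one filter for each output phase j in [0, new_freq)
    idx = torch.arange(-width, width + orig_freq, dtype=torch.float64)
    loop_kernels = []
    for j in range(new_freq):
        t = ((-j / new_freq + idx / orig_freq) * base_freq).clamp(-lowpass_filter_width, lowpass_filter_width)
        window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
        t = t * math.pi
        kernel = torch.where(t == 0, torch.tensor(1.0, dtype=t.dtype), t.sin() / t)
        loop_kernels.append(kernel * window)
    loop_kernels = torch.stack(loop_kernels).view(new_freq, 1, -1) * scale

    # (b) vectorized: broadcast phase offsets [new_freq, 1, 1] against indices [1, 1, L]
    idx_v = torch.arange(-width, width + orig_freq, dtype=torch.float64)[None, None] / orig_freq
    t_v = torch.arange(0, -new_freq, -1, dtype=torch.float64)[:, None, None] / new_freq + idx_v
    t_v = (t_v * base_freq).clamp(-lowpass_filter_width, lowpass_filter_width)
    window_v = torch.cos(t_v * math.pi / lowpass_filter_width / 2) ** 2
    t_v = t_v * math.pi
    vec_kernels = torch.where(t_v == 0, torch.tensor(1.0, dtype=t_v.dtype), t_v.sin() / t_v) * window_v * scale

    print(torch.allclose(loop_kernels, vec_kernels))  # expected: True

The broadcast removes the Python loop over output phases and the final torch.stack, which is the point of the change; the kernel values themselves are unchanged.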
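
The in-code comment mentions an F.conv1d with a stride of orig_freq. A hedged sketch of that application step follows; the padding and final length trimming here are schematic guesses for illustration, not torchaudio's exact bookkeeping:

    import math

    import torch

    # same hypothetical parameters as in the sketch above
    orig_freq, new_freq, lowpass_filter_width, rolloff = 6, 5, 6, 0.99
    base_freq = min(orig_freq, new_freq) * rolloff
    scale = base_freq / orig_freq
    width = math.ceil(lowpass_filter_width / scale)

    # vectorized kernel bank, shape (new_freq, 1, kernel_length)
    idx = torch.arange(-width, width + orig_freq, dtype=torch.float64)[None, None] / orig_freq
    t = torch.arange(0, -new_freq, -1, dtype=torch.float64)[:, None, None] / new_freq + idx
    t = (t * base_freq).clamp(-lowpass_filter_width, lowpass_filter_width)
    window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
    t = t * math.pi
    kernels = torch.where(t == 0, torch.tensor(1.0, dtype=t.dtype), t.sin() / t) * window * scale

    # strided conv1d: output channel j holds the samples of output phase j,
    # so interleaving the channels reconstructs the resampled signal
    waveform = torch.randn(1, orig_freq * 100, dtype=torch.float64)             # (batch, time)
    padded = torch.nn.functional.pad(waveform[:, None], (width, width + orig_freq))
    phases = torch.nn.functional.conv1d(padded, kernels, stride=orig_freq)      # (batch, new_freq, frames)
    resampled = phases.transpose(1, 2).reshape(1, -1)
    resampled = resampled[:, : math.ceil(new_freq * waveform.shape[1] / orig_freq)]
    print(resampled.shape)  # torch.Size([1, 500])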