Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make resample kernel generation faster #2415

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 23 additions & 30 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings
from collections.abc import Sequence
from typing import Optional, Tuple, Union
from numpy import roll
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this used? NumPy is not mandatory requirement of torchaudio, so we would like to avoid the use of NumPy.


import torch
import torchaudio
Expand Down Expand Up @@ -1392,9 +1393,9 @@ def _get_sinc_resample_kernel(
lowpass_filter_width: int,
rolloff: float,
resampling_method: str,
beta: Optional[float],
beta: float = 14.769656459379492,
device: torch.device = torch.device("cpu"),
dtype: Optional[torch.dtype] = None,
dtype: Optional[torch.dtype] = torch.float32,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we keep the dtype effectiveness? The reason resample supports torch.float64 for the sake of use cases other than DL, which requires higher precision for quality.

):

if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
Expand All @@ -1414,13 +1415,11 @@ def _get_sinc_resample_kernel(
new_freq = int(new_freq) // gcd

assert lowpass_filter_width > 0
kernels = []
base_freq = min(orig_freq, new_freq)
base_freq = min(orig_freq, new_freq) * rolloff
# This will perform antialiasing filtering by removing the highest frequencies.
# At first I thought I only needed this when downsampling, but when upsampling
# you will get edge artifacts without this, as the edge is equivalent to zero padding,
# which will add high freq artifacts.
base_freq *= rolloff

# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
# using the sinc interpolation formula:
Expand All @@ -1439,37 +1438,31 @@ def _get_sinc_resample_kernel(
# = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
# so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
# This will explain the F.conv1d after, with a stride of orig_freq.
width = math.ceil(lowpass_filter_width * orig_freq / base_freq)

scale = base_freq / orig_freq
width = math.ceil(lowpass_filter_width / scale)
# If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
idx_dtype = dtype if dtype is not None else torch.float64
idx = torch.arange(-width, width + orig_freq, device=device, dtype=idx_dtype)

for i in range(new_freq):
t = (-i / new_freq + idx / orig_freq) * base_freq
t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
one = torch.tensor(1.0, dtype=dtype)
idx = torch.arange(-width, width + orig_freq, dtype=dtype)[None, None].div(orig_freq)
t = torch.arange(0, -new_freq, -1, dtype=dtype)[:, None, None].div(new_freq) + idx
t.mul_(base_freq).clamp_(-lowpass_filter_width, lowpass_filter_width)

if resampling_method == "sinc_interpolation":
t.mul_(torch.pi)
window = t.div(2*lowpass_filter_width).cos().pow_(2.)
else: # kaiser window
beta_tensor = torch.as_tensor(beta, dtype=dtype)
window = beta_tensor.mul((1 - t.div(lowpass_filter_width).pow(2.)).sqrt()).div(beta_tensor.i0())
t.mul_(torch.pi)

kernels = torch.where(t == 0, one, t.sin().div(t))
kernels.mul_(window)
kernels.mul_(scale).to(device=device)

# we do not use built in torch windows here as we need to evaluate the window
# at specific positions, not over a regular grid.
if resampling_method == "sinc_interpolation":
window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
else:
# kaiser_window
if beta is None:
beta = 14.769656459379492
beta_tensor = torch.tensor(float(beta))
window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
t *= math.pi
kernel = torch.where(t == 0, torch.tensor(1.0).to(t), torch.sin(t) / t)
kernel.mul_(window)
kernels.append(kernel)

scale = base_freq / orig_freq
kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale)
if dtype is None:
kernels = kernels.to(dtype=torch.float32)
return kernels, width


Expand Down