🐛 Describe the bug
In `torchaudio/functional/functional.py`, `_get_sinc_resample_kernel()` is slower than it should be because it builds the kernel in a Python loop instead of taking advantage of torch broadcasting. A simple rewrite makes it ~9x faster.
There are other issues with the resampling kernels that still cause severe slowness when the gcd of the sample rates does not reduce them (e.g. resampling 44100 → 44101 Hz has gcd 1, so the kernel keeps all 44101 rows); this report does not address those.
Below are sample code for the fix and a paraphrase of the existing code; first, a toy sketch of the broadcasting idea.
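As a toy illustration only (the sizes and names here are my own, not from torchaudio), the per-phase loop and the broadcast build produce the same grid:

```python
import torch

new_freq, num_taps = 5, 7  # toy sizes for illustration
idx = torch.arange(num_taps, dtype=torch.float32)

# loop version: build one row per output phase, then stack
looped = torch.stack([-i / new_freq + idx / num_taps for i in range(new_freq)])

# broadcast version: a (new_freq, 1) phase column plus a (1, num_taps) index row
# yields the whole (new_freq, num_taps) grid in one vectorized op
phases = torch.arange(0, -new_freq, -1, dtype=torch.float32)[:, None] / new_freq
broadcast = phases + idx[None, :] / num_taps

assert torch.allclose(looped, broadcast)
```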
Other than removing the loops, a few observations:
a. using torch.float64 makes no measurable difference
b. CUDA is slower than CPU for both versions, so if a CUDA kernel is required it should be cast to the device at the end of kernel building
c. in-place operations are faster when no grad is required, and there is no reason grad should ever be applied to these kernels (a micro-benchmark follows this list)
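A rough micro-benchmark of observation (c); the shape and iteration count are arbitrary choices of mine:

```python
import time

import torch

x = torch.randn(41, 1, 126)  # roughly kernel-sized; requires_grad is False by default

t0 = time.time()
for _ in range(1000):
    y = x * 1.0  # out-of-place: allocates a fresh tensor every iteration
print(f"out-of-place {1e3 * (time.time() - t0):.1f} ms")

t0 = time.time()
with torch.no_grad():  # in-place updates of a leaf tensor that requires grad would raise
    for _ in range(1000):
        x.mul_(1.0)  # in-place: reuses the same storage
print(f"in-place     {1e3 * (time.time() - t0):.1f} ms")
```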
- faster

```python
import math
import time

import torch


def r_kernel(orig_freq=96000, new_freq=41000, lowpass_filter_width=6,
             rolloff=0.99, dtype=torch.float32, device=None,
             resampling_method="sinc_interpolation", beta=14.769656459379492):
    """Loop-free kernel builder; ~9x faster.

    Builds on CPU and casts to `device` only at the end: the build is
    many small ops and does not benefit from CUDA.
    """
    _time = time.time()
    with torch.no_grad():  # grad is never needed on these kernels, so in-place ops are safe
        gcd = math.gcd(orig_freq, new_freq)
        orig_freq //= gcd
        new_freq //= gcd
        _base_freq = min(orig_freq, new_freq) * rolloff
        _scale = _base_freq / orig_freq
        _one = torch.tensor(1.0, dtype=dtype)
        width = math.ceil(lowpass_filter_width / _scale)
        # broadcasting a (new_freq, 1, 1) phase column against a (1, 1, taps) index row
        # builds the whole time grid at once, replacing the per-phase Python loop
        _idx = torch.arange(-width, width + orig_freq, dtype=dtype)[None, None].div(orig_freq)
        _t = torch.arange(0, -new_freq, -1, dtype=dtype)[:, None, None].div(new_freq) + _idx
        _t.mul_(_base_freq).clamp_(-lowpass_filter_width, lowpass_filter_width)
        if resampling_method == "sinc_interpolation":
            _t.mul_(torch.pi)
            _window = _t.div(2 * lowpass_filter_width).cos().pow_(2.0)  # Hann window
        else:  # kaiser window: i0(beta * sqrt(1 - (t / width)^2)) / i0(beta)
            _beta = torch.as_tensor(beta, dtype=dtype)
            _window = _beta.mul((1 - _t.div(lowpass_filter_width).pow(2.0)).sqrt()).i0().div(_beta.i0())
            _t.mul_(torch.pi)
        kernel = torch.where(_t == 0, _one, _t.sin().div(_t))
        kernel.mul_(_window)
        kernel = kernel.mul_(_scale).to(device=device)  # cast to the target device last
    print(f"elapsed {1e3 * (time.time() - _time):.3f} ms")
    return kernel, width
```
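For context, a rough sketch of how such a kernel is consumed at resample time, paraphrasing torchaudio's `_apply_sinc_resample_kernel` from memory (the padding and reshape details here are an assumption, not verbatim library code):

```python
import math

import torch
import torch.nn.functional as F

orig_freq, new_freq = 96000, 41000
waveform = torch.randn(1, orig_freq)  # hypothetical input: 1 channel, 1 s of audio
kernel, width = r_kernel(orig_freq, new_freq)

o = orig_freq // math.gcd(orig_freq, new_freq)  # 96 after reduction

num_out = int(math.ceil(new_freq / orig_freq * waveform.shape[-1]))  # 41000
padded = F.pad(waveform, (width, width + o))
out = F.conv1d(padded[:, None], kernel, stride=o)  # one output phase per kernel row
out = out.transpose(1, 2).reshape(1, -1)[:, :num_out]  # interleave phases, trim
```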
- paraphrase of existing code

```python
import math
import time

import torch


def old_kernel(orig_freq=96000, new_freq=41000, lowpass_filter_width=6,
               rolloff=0.99, dtype=torch.float32, device=None):
    """Paraphrase of the current _get_sinc_resample_kernel(): one Python-loop
    pass per output phase."""
    _time = time.time()
    gcd = math.gcd(orig_freq, new_freq)
    orig_freq //= gcd
    new_freq //= gcd
    base_freq = min(orig_freq, new_freq) * rolloff
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)
    kernels = []
    for i in range(new_freq):  # one kernel row per output phase
        t = ((-i / new_freq + idx / orig_freq) * base_freq).clamp(
            -lowpass_filter_width, lowpass_filter_width) * torch.pi
        kernel = torch.where(t == 0, torch.tensor(1.0).to(t), torch.sin(t) / t)
        kernel = kernel * torch.cos(t / lowpass_filter_width / 2) ** 2
        kernels.append(kernel)
    scale = base_freq / orig_freq
    kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale)
    if dtype is None:
        kernels = kernels.to(dtype=torch.float32)
    if device == "cuda":
        torch.cuda.synchronize()  # make the elapsed-time print meaningful on GPU
    print(f"elapsed {1e3 * (time.time() - _time):.3f} ms")
    return kernels, width
```
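A minimal sanity check that the two builders agree, with each one printing its own timing (the tolerance here is my choice):

```python
import torch

# assumes r_kernel and old_kernel from the snippets above are in scope
fast, w_fast = r_kernel(96000, 41000)
slow, w_slow = old_kernel(96000, 41000)

assert w_fast == w_slow
print(fast.shape)  # torch.Size([41, 1, 126])
print(torch.allclose(fast, slow, atol=1e-6))  # True up to float32 rounding
```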
Versions

```
Collecting environment information...
PyTorch version: 1.13.0.dev20220526
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.22.4
Libc version: glibc-2.27
Python version: 3.10.4 | packaged by conda-forge | (main, Mar 24 2022, 17:39:04) [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-113-generic-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.6.55
GPU models and configuration: GPU 0: NVIDIA TITAN RTX
Nvidia driver version: 510.39.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.22.3
[pip3] torch==1.13.0.dev20220526
[pip3] torchaudio==0.12.0a0+b7624c6
[pip3] torchvision==0.14.0.dev20220526
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py310ha2c4b55_0 conda-forge
[conda] mkl_fft 1.3.1 py310h2b4bcf5_1 conda-forge
[conda] mkl_random 1.2.2 py310h00e6091_0
[conda] numpy 1.22.3 py310hfa59a62_0
[conda] numpy-base 1.22.3 py310h9585f30_0
[conda] pytorch 1.13.0.dev20220526 py3.10_cuda11.3_cudnn8.3.2_0 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torchaudio 0.12.0a0+b7624c6 dev_0
[conda] torchvision 0.14.0.dev20220526 py310_cu113 pytorch-nightly
```