
Resample kernel creation uses loops... #2414

@xvdp

🐛 Describe the bug

torchaudio/functional/functional.py: def _get_sinc_resample_kernel() is slower than it should be because it uses loops instead of taking advantage of torch broadcasting. A simple rewrite makes it ~9x faster.

There are other issues with the resampling kernels that still cause extreme slowness when the gcd of the sample rates does not reduce them; this issue does not address those.
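(A rough illustration of the gcd point, not a benchmark from this issue: after dividing by the gcd, the kernel has one row per output phase, i.e. new_freq // gcd rows, so a pair of rates that shares no large factor produces an enormous kernel. The sample-rate pairs below are just examples.)

import math
# kernel rows = new_freq // gcd; a gcd of 1 means one row per output sample rate unit
for orig, new in [(44100, 48000), (44100, 48001)]:
    g = math.gcd(orig, new)
    print(f"{orig} -> {new}: gcd={g}, kernel rows={new // g}")
# 44100 -> 48000: gcd=300, kernel rows=160
# 44100 -> 48001: gcd=1,   kernel rows=48001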

Below is sample code for the fix, followed by a paraphrase of the existing code (a shape sketch of the broadcast comes right after these notes).
Other than removing the loops, a few observations:
a. using torch.float64 makes no difference
b. CUDA is slower than CPU in both cases, so if a CUDA kernel is required it should be cast to the device at the end of kernel building
c. in-place operations are faster when no grad is required, and there is no reason grad should be applied to these kernels
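As a shape sketch (illustrative only, not the actual torchaudio code): the current implementation builds one kernel row per output phase in a Python loop; broadcasting a (new_freq, 1, 1) phase column against a (1, 1, kernel_len) index row produces the whole (new_freq, 1, kernel_len) tensor in one shot. The sizes below correspond to 96000 -> 41000 after the gcd reduction.

import torch
# Shape-only sketch of the broadcast that replaces the per-phase loop.
new_freq, orig_freq, width = 41, 96, 15
idx = torch.arange(-width, width + orig_freq)[None, None]   # (1, 1, 2*width + orig_freq)
phases = torch.arange(0, -new_freq, -1)[:, None, None]      # (new_freq, 1, 1)
t = phases / new_freq + idx / orig_freq                      # broadcasts to (new_freq, 1, 2*width + orig_freq)
print(t.shape)                                               # torch.Size([41, 1, 126])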

  1. faster
import math
import torch
import time
def r_kernel(orig_freq=96000, new_freq=41000, lowpass_filter_width=6, rolloff=0.99, dtype=torch.float32, device=None, resampling_method = "sinc_interpolation", beta = 14.769656459379492):
    """ remove loops, if cuda, set it at end, too many assignments to benefit from cuda
    ~ 9x faster
    """
    _time = time.time()

    with torch.no_grad():
        gcd = math.gcd(orig_freq, new_freq)
        orig_freq //= gcd
        new_freq //= gcd

        _base_freq = min(orig_freq, new_freq) * rolloff
        _scale = _base_freq / orig_freq
        _one = torch.tensor(1.0, dtype=dtype)
        width = math.ceil(lowpass_filter_width/_scale)
        _idx = torch.arange(-width, width + orig_freq, dtype=dtype)[None, None].div(orig_freq)
        _t = torch.arange(0, -new_freq, -1, dtype=dtype)[:, None, None].div(new_freq) + _idx
        _t.mul_(_base_freq).clamp_(-lowpass_filter_width, lowpass_filter_width)

        if resampling_method == "sinc_interpolation":
            _t.mul_(torch.pi)
            _window = _t.div(2*lowpass_filter_width).cos().pow_(2.)
        else: # kaiser window
            _beta = torch.as_tensor(beta, dtype=dtype)
            # kaiser window: i0(beta * sqrt(1 - (t / width)^2)) / i0(beta)
            _window = _beta.mul((1 - _t.div(lowpass_filter_width).pow(2.)).sqrt()).i0().div(_beta.i0())
            _t.mul_(torch.pi)
        kernel = torch.where(_t == 0, _one, _t.sin().div(_t))
        kernel.mul_(_window)
        kernel = kernel.mul_(_scale).to(device=device)

        print(f"elapsed {1e3*(time.time() - _time):.3f} ms")
        return kernel, width
  2. paraphrase of existing code
import math
import torch
import time
def old_kernel(orig_freq=96000, new_freq=41000, lowpass_filter_width=6, rolloff=0.99, dtype=torch.float32, device=None):
    _time = time.time()
    gcd = math.gcd(orig_freq, new_freq)
    orig_freq //= gcd
    new_freq //= gcd
    base_freq = min(orig_freq, new_freq) * rolloff
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)
    kernels = []
    for i in range(new_freq):
        t = ((-i / new_freq + idx / orig_freq) * base_freq).clamp(-lowpass_filter_width, lowpass_filter_width) * torch.pi
        kernel = torch.where(t == 0, torch.tensor(1.0).to(t), torch.sin(t) / t) * torch.cos(t / lowpass_filter_width / 2) ** 2
        kernels.append(kernel)
    scale = base_freq / orig_freq
    kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale)
    if dtype is None:
        kernels = kernels.to(dtype=torch.float32)
    if device == 'cuda':
        torch.cuda.synchronize()
    print(f"elapsed {1e3*(time.time() - _time):.3f} ms")
    return kernels, width
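
For completeness, a minimal comparison harness for the two functions above (the timing and the ~9x figure will vary by machine; the equivalence check assumes the default sinc_interpolation path):

import torch
# Sanity check: the loop-free kernel should match the loop-based one.
# Each function prints its own elapsed time.
fast, fast_width = r_kernel(96000, 41000)
slow, slow_width = old_kernel(96000, 41000)
assert fast_width == slow_width
print("max abs diff:", (fast - slow).abs().max().item())
print("allclose:", torch.allclose(fast, slow, atol=1e-5))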

Versions

Collecting environment information...
PyTorch version: 1.13.0.dev20220526
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.22.4
Libc version: glibc-2.27

Python version: 3.10.4 | packaged by conda-forge | (main, Mar 24 2022, 17:39:04) [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-113-generic-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.6.55
GPU models and configuration: GPU 0: NVIDIA TITAN RTX
Nvidia driver version: 510.39.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.22.3
[pip3] torch==1.13.0.dev20220526
[pip3] torchaudio==0.12.0a0+b7624c6
[pip3] torchvision==0.14.0.dev20220526
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py310ha2c4b55_0 conda-forge
[conda] mkl_fft 1.3.1 py310h2b4bcf5_1 conda-forge
[conda] mkl_random 1.2.2 py310h00e6091_0
[conda] numpy 1.22.3 py310hfa59a62_0
[conda] numpy-base 1.22.3 py310h9585f30_0
[conda] pytorch 1.13.0.dev20220526 py3.10_cuda11.3_cudnn8.3.2_0 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torchaudio 0.12.0a0+b7624c6 dev_0
[conda] torchvision 0.14.0.dev20220526 py310_cu113 pytorch-nightly
