From b2034e692919baf6918e0b9083ff5ebe22674699 Mon Sep 17 00:00:00 2001
From: xvdp
Date: Thu, 26 May 2022 18:45:53 -0700
Subject: [PATCH 1/5] fix issue #2414: kernel creation uses loops

Changes:
1. loops are replaced with broadcasting
2. the kernel is moved to the target device at the end of creation
3. dtype defaults to float32 (float64 adds nothing of value here)
4. defaults for beta and dtype are moved to the function declaration
---
 torchaudio/functional/functional.py | 50 ++++++++++++-----------------
 1 file changed, 21 insertions(+), 29 deletions(-)

diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py
index 323783b42b..ba0f478def 100644
--- a/torchaudio/functional/functional.py
+++ b/torchaudio/functional/functional.py
@@ -1394,7 +1394,7 @@ def _get_sinc_resample_kernel(
     resampling_method: str = "sinc_interpolation",
     beta: Optional[float] = None,
     device: torch.device = torch.device("cpu"),
-    dtype: Optional[torch.dtype] = None,
+    dtype: Optional[torch.dtype] = torch.float32,
 ):

     if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
@@ -1414,13 +1414,11 @@ def _get_sinc_resample_kernel(
     new_freq = int(new_freq) // gcd

     assert lowpass_filter_width > 0
-    kernels = []
-    base_freq = min(orig_freq, new_freq)
+    base_freq = min(orig_freq, new_freq) * rolloff
     # This will perform antialiasing filtering by removing the highest frequencies.
     # At first I thought I only needed this when downsampling, but when upsampling
     # you will get edge artifacts without this, as the edge is equivalent to zero padding,
     # which will add high freq artifacts.
-    base_freq *= rolloff

     # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
     # using the sinc interpolation formula:
@@ -1439,37 +1437,31 @@ def _get_sinc_resample_kernel(
     # = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
     # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
     # This will explain the F.conv1d after, with a stride of orig_freq.
-    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
+
+    scale = base_freq / orig_freq
+    width = math.ceil(lowpass_filter_width / scale)
     # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
     # they will have a lot of almost zero values to the left or to the right...
     # There is probably a way to evaluate those filters more efficiently, but this is kept for
     # future work.
-    idx_dtype = dtype if dtype is not None else torch.float64
-    idx = torch.arange(-width, width + orig_freq, device=device, dtype=idx_dtype)
-    for i in range(new_freq):
-        t = (-i / new_freq + idx / orig_freq) * base_freq
-        t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
+    one = torch.tensor(1.0, dtype=dtype)
+    idx = torch.arange(-width, width + orig_freq, dtype=dtype)[None, None].div(orig_freq)
+    t = torch.arange(0, -new_freq, -1, dtype=dtype)[:, None, None].div(new_freq) + idx
+    t.mul_(base_freq).clamp_(-lowpass_filter_width, lowpass_filter_width)
+
+    if resampling_method == "sinc_interpolation":
+        t.mul_(torch.pi)
+        window = t.div(2*lowpass_filter_width).cos().pow_(2.)
+    else: # kaiser window
+        beta_tensor = torch.as_tensor(beta, dtype=dtype)
+        window = beta_tensor.mul((1 - t.div(lowpass_filter_width).pow(2.)).sqrt()).div(beta_tensor.i0())
+        t.mul_(torch.pi)
+
+    kernels = torch.where(t == 0, one, t.sin().div(t))
+    kernels.mul_(window)
+    kernels.mul_(scale).to(device=device)
-        # we do not use built in torch windows here as we need to evaluate the window
-        # at specific positions, not over a regular grid.
-        if resampling_method == "sinc_interpolation":
-            window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
-        else:
-            # kaiser_window
-            if beta is None:
-                beta = 14.769656459379492
-            beta_tensor = torch.tensor(float(beta))
-            window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
-        t *= math.pi
-        kernel = torch.where(t == 0, torch.tensor(1.0).to(t), torch.sin(t) / t)
-        kernel.mul_(window)
-        kernels.append(kernel)
-
-    scale = base_freq / orig_freq
-    kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale)
-    if dtype is None:
-        kernels = kernels.to(dtype=torch.float32)
     return kernels, width

From 1db6103c7a33959ff7bba969e8a52688e3e74077 Mon Sep 17 00:00:00 2001
From: xvdp
Date: Thu, 26 May 2022 18:45:53 -0700
Subject: [PATCH 2/5] fix issue #2414: kernel creation uses loops

Changes:
1. loops are replaced with broadcasting
2. the kernel is moved to the target device at the end of creation
3. dtype defaults to float32 (float64 adds nothing of value here)
4. defaults for beta and dtype are moved to the function declaration

From f0f1b62b0753c1ec074bd61a4e2cf0ee5d0f6794 Mon Sep 17 00:00:00 2001
From: Sean Kim
Date: Tue, 19 Jul 2022 13:27:29 -0400
Subject: [PATCH 3/5] Changes to previous commit
---
 torchaudio/functional/functional.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py
index ba0f478def..69b37b3549 100644
--- a/torchaudio/functional/functional.py
+++ b/torchaudio/functional/functional.py
@@ -1454,6 +1454,8 @@ def _get_sinc_resample_kernel(
         t.mul_(torch.pi)
         window = t.div(2*lowpass_filter_width).cos().pow_(2.)
     else: # kaiser window
+        if beta is None:
+            beta = 14.769656459379492
         beta_tensor = torch.as_tensor(beta, dtype=dtype)
         window = beta_tensor.mul((1 - t.div(lowpass_filter_width).pow(2.)).sqrt()).div(beta_tensor.i0())
         t.mul_(torch.pi)

From 2b48d4ca68df616048bd0dc12e80471de464d17b Mon Sep 17 00:00:00 2001
From: Sean Kim
Date: Tue, 19 Jul 2022 17:00:11 -0400
Subject: [PATCH 4/5] modify PR #2415 to improve resample kernel generation
---
 torchaudio/functional/functional.py | 45 +++++++++++++++++------------
 1 file changed, 26 insertions(+), 19 deletions(-)

diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py
index 69b37b3549..f8369eefe4 100644
--- a/torchaudio/functional/functional.py
+++ b/torchaudio/functional/functional.py
@@ -1394,7 +1394,7 @@ def _get_sinc_resample_kernel(
     resampling_method: str = "sinc_interpolation",
     beta: Optional[float] = None,
     device: torch.device = torch.device("cpu"),
-    dtype: Optional[torch.dtype] = torch.float32,
+    dtype: Optional[torch.dtype] = None,
 ):

     if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
@@ -1414,11 +1414,12 @@ def _get_sinc_resample_kernel(
     new_freq = int(new_freq) // gcd

     assert lowpass_filter_width > 0
-    base_freq = min(orig_freq, new_freq) * rolloff
+    base_freq = min(orig_freq, new_freq)
     # This will perform antialiasing filtering by removing the highest frequencies.
     # At first I thought I only needed this when downsampling, but when upsampling
     # you will get edge artifacts without this, as the edge is equivalent to zero padding,
     # which will add high freq artifacts.
+    base_freq *= rolloff

     # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
     # using the sinc interpolation formula:
@@ -1437,32 +1438,38 @@ def _get_sinc_resample_kernel(
     # = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
     # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
     # This will explain the F.conv1d after, with a stride of orig_freq.
-
-    scale = base_freq / orig_freq
-    width = math.ceil(lowpass_filter_width / scale)
+    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
     # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
     # they will have a lot of almost zero values to the left or to the right...
     # There is probably a way to evaluate those filters more efficiently, but this is kept for
     # future work.
+    idx_dtype = dtype if dtype is not None else torch.float64
+
+    idx = torch.arange(-width, width + orig_freq, dtype=idx_dtype)[None, None] / orig_freq
-    one = torch.tensor(1.0, dtype=dtype)
-    idx = torch.arange(-width, width + orig_freq, dtype=dtype)[None, None].div(orig_freq)
-    t = torch.arange(0, -new_freq, -1, dtype=dtype)[:, None, None].div(new_freq) + idx
-    t.mul_(base_freq).clamp_(-lowpass_filter_width, lowpass_filter_width)
+    t = torch.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
+    t *= base_freq
+    t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
+
+    scale = base_freq / orig_freq

     if resampling_method == "sinc_interpolation":
-        t.mul_(torch.pi)
-        window = t.div(2*lowpass_filter_width).cos().pow_(2.)
-    else: # kaiser window
+        window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
+    else:
+        # kaiser_window
         if beta is None:
             beta = 14.769656459379492
-        beta_tensor = torch.as_tensor(beta, dtype=dtype)
-        window = beta_tensor.mul((1 - t.div(lowpass_filter_width).pow(2.)).sqrt()).div(beta_tensor.i0())
-        t.mul_(torch.pi)
-
-    kernels = torch.where(t == 0, one, t.sin().div(t))
-    kernels.mul_(window)
-    kernels.mul_(scale).to(device=device)
+        beta_tensor = torch.tensor(float(beta))
+        window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
+
+    t *= torch.pi
+
+    kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
+    kernels *= window * scale
+    kernels.to(device)
+
+    if dtype is None:
+        kernels = kernels.to(dtype=torch.float32)
     return kernels, width

From 715c4b0ba87b6f286c042e4f5e1bbfd7ac9c7410 Mon Sep 17 00:00:00 2001
From: Sean Kim
Date: Tue, 19 Jul 2022 23:15:10 -0400
Subject: [PATCH 5/5] modifications addressing comments, will benchmark and possibly change after merge
---
 torchaudio/functional/functional.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py
index f8369eefe4..007fc2ba8e 100644
--- a/torchaudio/functional/functional.py
+++ b/torchaudio/functional/functional.py
@@ -1445,14 +1445,14 @@ def _get_sinc_resample_kernel(
     # future work.
     idx_dtype = dtype if dtype is not None else torch.float64

-    idx = torch.arange(-width, width + orig_freq, dtype=idx_dtype)[None, None] / orig_freq
+    idx = torch.arange(-width, width + orig_freq, dtype=idx_dtype, device=device)[None, None] / orig_freq
     t = torch.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
     t *= base_freq
     t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)

-    scale = base_freq / orig_freq
-
+    # we do not use built in torch windows here as we need to evaluate the window
+    # at specific positions, not over a regular grid.
     if resampling_method == "sinc_interpolation":
         window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
     else:
@@ -1462,11 +1462,11 @@ def _get_sinc_resample_kernel(
         beta_tensor = torch.tensor(float(beta))
         window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)

-    t *= torch.pi
+    t *= math.pi

+    scale = base_freq / orig_freq
     kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
     kernels *= window * scale
-    kernels.to(device)

     if dtype is None:
         kernels = kernels.to(dtype=torch.float32)
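
The following is a minimal, self-contained sketch of the broadcasted kernel construction these patches implement, and of how the resulting (new_freq, 1, kernel_width) filter bank is applied with F.conv1d at stride orig_freq, as the in-code comments describe. The helper name make_sinc_kernel, the example rates (orig_freq=4, new_freq=3), and the signal length are illustrative only and are not part of torchaudio's API; only the sinc_interpolation window is shown.

    import math
    import torch
    import torch.nn.functional as F

    def make_sinc_kernel(orig_freq, new_freq, lowpass_filter_width=6, rolloff=0.99,
                         dtype=torch.float64):
        # One broadcasted grid replaces the per-phase Python loop: row j holds the
        # lowpass/sinc filter for output phase j / new_freq.
        base_freq = min(orig_freq, new_freq) * rolloff
        width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
        idx = torch.arange(-width, width + orig_freq, dtype=dtype)[None, None] / orig_freq
        t = torch.arange(0, -new_freq, -1, dtype=dtype)[:, None, None] / new_freq + idx
        t = (t * base_freq).clamp_(-lowpass_filter_width, lowpass_filter_width)
        window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
        t = t * math.pi
        kernels = torch.where(t == 0, torch.tensor(1.0, dtype=dtype), t.sin() / t)
        # Shape: (new_freq, 1, 2 * width + orig_freq)
        return kernels * window * (base_freq / orig_freq), width

    orig_freq, new_freq = 4, 3                        # e.g. 32 kHz -> 24 kHz after GCD reduction
    kernels, width = make_sinc_kernel(orig_freq, new_freq)
    x = torch.randn(1, 1, 400, dtype=torch.float64)   # (batch, channel, time)
    x = F.pad(x, (width, width + orig_freq))
    y = F.conv1d(x, kernels, stride=orig_freq)        # (1, new_freq, num_frames)
    y = y.transpose(1, 2).reshape(1, 1, -1)           # interleave the new_freq phases
    y = y[..., : math.ceil(new_freq * 400 / orig_freq)]  # trim to the target length

Because every output phase is a row of the same weight tensor, the whole filter bank is built in one pass and a single strided convolution produces all phases at once, which is the point of replacing the per-phase loop.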