-
Notifications
You must be signed in to change notification settings - Fork 89
/
convolution.py
81 lines (68 loc) · 2.74 KB
/
convolution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
from torch_audiomentations.utils.fft import rfft, irfft
_NEXT_FAST_LEN = {}
def next_fast_len(size):
"""
Returns the next largest number ``n >= size`` whose prime factors are all
2, 3, or 5. These sizes are efficient for fast fourier transforms.
Equivalent to :func:`scipy.fftpack.next_fast_len`.
Note: This function was originally copied from the https://github.com/pyro-ppl/pyro
repository, where the license was Apache 2.0. Any modifications to the original code can be
found at https://github.com/asteroid-team/torch-audiomentations/commits
:param int size: A positive number.
:returns: A possibly larger number.
:rtype int:
"""
try:
return _NEXT_FAST_LEN[size]
except KeyError:
pass
assert isinstance(size, int) and size > 0
next_size = size
while True:
remaining = next_size
for n in (2, 3, 5):
while remaining % n == 0:
remaining //= n
if remaining == 1:
_NEXT_FAST_LEN[size] = next_size
return next_size
next_size += 1
def convolve(signal, kernel, mode="full"):
    """
    Computes the 1-d convolution of signal by kernel using FFTs.

    The convolution is taken along the rightmost dim. The rightmost sizes of
    ``signal`` and ``kernel`` may differ; all other (leading) dims must be
    broadcastable against each other.

    Note: This function was originally copied from the https://github.com/pyro-ppl/pyro
    repository, where the license was Apache 2.0. Any modifications to the original code can be
    found at https://github.com/asteroid-team/torch-audiomentations/commits

    :param torch.Tensor signal: A signal to convolve.
    :param torch.Tensor kernel: A convolution kernel.
    :param str mode: One of: 'full', 'valid', 'same'.
    :return: A tensor with broadcasted shape. Letting ``m = signal.size(-1)``
        and ``n = kernel.size(-1)``, the rightmost size of the result will be:
        ``m + n - 1`` if mode is 'full';
        ``max(m, n) - min(m, n) + 1`` if mode is 'valid'; or
        ``max(m, n)`` if mode is 'same'.
    :rtype torch.Tensor:
    :raises ValueError: If ``mode`` is not one of 'full', 'valid', 'same'.
    """
    m = signal.size(-1)
    n = kernel.size(-1)
    if mode == "full":
        truncate = m + n - 1
    elif mode == "valid":
        truncate = max(m, n) - min(m, n) + 1
    elif mode == "same":
        truncate = max(m, n)
    else:
        raise ValueError("Unknown mode: {}".format(mode))

    # Zero-padding both inputs to at least m + n - 1 samples makes the circular
    # convolution computed via the FFT equal to the linear convolution.
    padded_size = m + n - 1
    # Round up to a 2/3/5-smooth length, which FFTs handle cheaply.
    fast_fft_size = next_fast_len(padded_size)
    f_signal = rfft(signal, n=fast_fft_size)
    f_kernel = rfft(kernel, n=fast_fft_size)
    # Both spectra now share the same rightmost size, so this product
    # broadcasts over the leading dims only.
    f_result = f_signal * f_kernel
    result = irfft(f_result, n=fast_fft_size)

    # Keep the centred window of length ``truncate`` from the full linear
    # convolution (start_idx is 0 for 'full').
    start_idx = (padded_size - truncate) // 2
    return result[..., start_idx : start_idx + truncate]