Skip to content

Commit

Permalink
Add option to use any window function available in scipy.signal
Browse files Browse the repository at this point in the history
This does not change the default behavior or other existing code, since the torch windows can still be accessed using the original arguments.
  • Loading branch information
simonschwaer committed Dec 19, 2023
1 parent 1576b0c commit ffe2db7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
9 changes: 5 additions & 4 deletions auraloss/freq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from typing import List, Any

from .utils import apply_reduction
from .utils import apply_reduction, get_window
from .perceptual import SumAndDifference, FIRFilter


Expand Down Expand Up @@ -58,8 +58,9 @@ class STFTLoss(torch.nn.Module):
fft_size (int, optional): FFT size in samples. Default: 1024
hop_size (int, optional): Hop size of the FFT in samples. Default: 256
win_length (int, optional): Length of the FFT analysis window. Default: 1024
window (str, optional): Window to apply before FFT, options include:
['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
window (str, optional): Window to apply before FFT, can either be one of the window functions provided in PyTorch
['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
Default: 'hann_window'
w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
Expand Down Expand Up @@ -117,7 +118,7 @@ def __init__(
self.fft_size = fft_size
self.hop_size = hop_size
self.win_length = win_length
self.window = getattr(torch, window)(win_length)
self.window = get_window(window, win_length)
self.w_sc = w_sc
self.w_log_mag = w_log_mag
self.w_lin_mag = w_lin_mag
Expand Down
21 changes: 21 additions & 0 deletions auraloss/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import scipy.signal


def apply_reduction(losses, reduction="none"):
Expand All @@ -8,3 +9,23 @@ def apply_reduction(losses, reduction="none"):
elif reduction == "sum":
losses = losses.sum()
return losses

def get_window(win_type: str, win_length: int) -> torch.Tensor:
    """Return an analysis window of the requested type as a torch tensor.

    The name is first looked up as a window factory on the ``torch`` module.
    If ``torch`` has no such attribute, or the attribute cannot be called
    with a plain integer length, the request is forwarded to SciPy instead.

    Args:
        win_type (str): Window type. Can either be one of the window functions
            provided in PyTorch
            ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
            or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
            Anything SciPy accepts is passed through unchanged, e.g. a
            ``(name, param)`` tuple for parameterised windows.
        win_length (int): Window length in samples.

    Returns:
        torch.Tensor: The window as a 1D tensor of length ``win_length``.
            Windows obtained from SciPy are cast to the current torch default
            dtype so that both code paths yield the same dtype.

    Raises:
        ValueError: Raised by SciPy if ``win_type`` is neither a torch window
            factory nor a window known to SciPy.
    """
    try:
        # AttributeError: torch has no attribute with this name.
        # TypeError: the name is not a string (e.g. a SciPy parameter tuple),
        # or it resolves to a torch function that does not accept an int
        # length (e.g. the SciPy alias 'cos' collides with torch.cos).
        # Any other failure of a genuine torch window call is a real error
        # and is allowed to propagate instead of being masked.
        return getattr(torch, win_type)(win_length)
    except (AttributeError, TypeError):
        pass

    scipy_window = scipy.signal.windows.get_window(win_type, win_length)
    # torch.from_numpy keeps NumPy's dtype (typically float64 here), whereas
    # torch window factories use the default dtype; align the two.
    return torch.from_numpy(scipy_window).to(torch.get_default_dtype())

0 comments on commit ffe2db7

Please sign in to comment.