Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small changes for flexibility and performance #68

Merged
merged 3 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions auraloss/freq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from typing import List, Any

from .utils import apply_reduction
from .utils import apply_reduction, get_window
from .perceptual import SumAndDifference, FIRFilter


Expand All @@ -25,16 +25,29 @@ class STFTMagnitudeLoss(torch.nn.Module):
See [Arik et al., 2018](https://arxiv.org/abs/1808.06719)
and [Engel et al., 2020](https://arxiv.org/abs/2001.04643v1)

Log-magnitudes are calculated with `log(log_fac*x + log_eps)`, where `log_fac` controls the
compression strength (larger value results in more compression), and `log_eps` can be used
to control the range of the compressed output values (e.g., `log_eps>=1` ensures positive
output values). The default values `log_fac=1` and `log_eps=0` correspond to plain log-compression.

Args:
log (bool, optional): Log-scale the STFT magnitudes,
or use linear scale. Default: True
log_eps (float, optional): Constant value added to the magnitudes before evaluating the logarithm.
Default: 0.0
log_fac (float, optional): Constant multiplication factor for the magnitudes before evaluating the logarithm.
Default: 1.0
distance (str, optional): Distance function ["L1", "L2"]. Default: "L1"
reduction (str, optional): Reduction of the loss elements. Default: "mean"
"""

def __init__(self, log=True, distance="L1", reduction="mean"):
def __init__(self, log=True, log_eps=0.0, log_fac=1.0, distance="L1", reduction="mean"):
super(STFTMagnitudeLoss, self).__init__()

self.log = log
self.log_eps = log_eps
self.log_fac = log_fac

if distance == "L1":
self.distance = torch.nn.L1Loss(reduction=reduction)
elif distance == "L2":
Expand All @@ -44,8 +57,8 @@ def __init__(self, log=True, distance="L1", reduction="mean"):

def forward(self, x_mag, y_mag):
if self.log:
x_mag = torch.log(x_mag)
y_mag = torch.log(y_mag)
x_mag = torch.log(self.log_fac * x_mag + self.log_eps)
y_mag = torch.log(self.log_fac * y_mag + self.log_eps)
return self.distance(x_mag, y_mag)


Expand All @@ -58,8 +71,9 @@ class STFTLoss(torch.nn.Module):
fft_size (int, optional): FFT size in samples. Default: 1024
hop_size (int, optional): Hop size of the FFT in samples. Default: 256
win_length (int, optional): Length of the FFT analysis window. Default: 1024
window (str, optional): Window to apply before FFT, options include:
['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
window (str, optional): Window to apply before FFT, can either be one of the window function provided in PyTorch
['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
Default: 'hann_window'
w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
Expand Down Expand Up @@ -112,12 +126,13 @@ def __init__(
reduction: str = "mean",
mag_distance: str = "L1",
device: Any = None,
**kwargs
):
super().__init__()
self.fft_size = fft_size
self.hop_size = hop_size
self.win_length = win_length
self.window = getattr(torch, window)(win_length)
self.window = get_window(window, win_length)
self.w_sc = w_sc
self.w_log_mag = w_log_mag
self.w_lin_mag = w_lin_mag
Expand All @@ -133,16 +148,20 @@ def __init__(
self.mag_distance = mag_distance
self.device = device

self.phs_used = bool(self.w_phs)

self.spectralconv = SpectralConvergenceLoss()
self.logstft = STFTMagnitudeLoss(
log=True,
reduction=reduction,
distance=mag_distance,
**kwargs
)
self.linstft = STFTMagnitudeLoss(
log=False,
reduction=reduction,
distance=mag_distance,
**kwargs
)

# setup mel filterbank
Expand Down Expand Up @@ -203,7 +222,13 @@ def stft(self, x):
x_mag = torch.sqrt(
torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps)
)
x_phs = torch.angle(x_stft)

# torch.angle is expensive, so it is only evaluated if the values are used in the loss
if self.phs_used:
x_phs = torch.angle(x_stft)
else:
x_phs = None

return x_mag, x_phs

def forward(self, input: torch.Tensor, target: torch.Tensor):
Expand All @@ -224,6 +249,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):

# compute the magnitude and phase spectra of input and target
self.window = self.window.to(input.device)

x_mag, x_phs = self.stft(input.view(-1, input.size(-1)))
y_mag, y_phs = self.stft(target.view(-1, target.size(-1)))

Expand All @@ -242,7 +268,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
sc_mag_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0
log_mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0
lin_mag_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0
phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.w_phs else 0.0
phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.phs_used else 0.0

# combine loss terms
loss = (
Expand Down
21 changes: 21 additions & 0 deletions auraloss/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import scipy.signal


def apply_reduction(losses, reduction="none"):
Expand All @@ -8,3 +9,23 @@ def apply_reduction(losses, reduction="none"):
elif reduction == "sum":
losses = losses.sum()
return losses

def get_window(win_type: str, win_length: int):
"""Return a window function.

Args:
win_type (str): Window type. Can either be one of the window function provided in PyTorch
['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
win_length (int): Window length

Returns:
win: The window as a 1D torch tensor
"""

try:
win = getattr(torch, win_type)(win_length)
except:
win = torch.from_numpy(scipy.signal.windows.get_window(win_type, win_length))

return win