Skip to content

Commit

Permalink
Allow for more flexible log compression
Browse files Browse the repository at this point in the history
By using `log(a*x+e)`, it is possible to control the amount of compression and the output value range. This does not change default behaviour.
  • Loading branch information
simonschwaer committed Dec 19, 2023
1 parent 0d1aa27 commit bba3748
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions auraloss/freq.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,29 @@ class STFTMagnitudeLoss(torch.nn.Module):
See [Arik et al., 2018](https://arxiv.org/abs/1808.06719)
and [Engel et al., 2020](https://arxiv.org/abs/2001.04643v1)
Log-magnitudes are calculated with `log(log_fac*x + log_eps)`, where `log_fac` controls the
compression strength (larger value results in more compression), and `log_eps` can be used
to control the range of the compressed output values (e.g., `log_eps>=1` ensures non-negative
output values). The default values `log_fac=1` and `log_eps=0` correspond to plain log-compression `log(x)`, which maps zero magnitudes to `-inf`.
Args:
log (bool, optional): Log-scale the STFT magnitudes,
or use linear scale. Default: True
log_eps (float, optional): Constant value added to the magnitudes before evaluating the logarithm.
Default: 0.0
log_fac (float, optional): Constant multiplication factor for the magnitudes before evaluating the logarithm.
Default: 1.0
distance (str, optional): Distance function ["L1", "L2"]. Default: "L1"
reduction (str, optional): Reduction of the loss elements. Default: "mean"
"""

def __init__(self, log=True, distance="L1", reduction="mean"):
def __init__(self, log=True, log_eps=0.0, log_fac=1.0, distance="L1", reduction="mean"):
super(STFTMagnitudeLoss, self).__init__()

self.log = log
self.log_eps = log_eps
self.log_fac = log_fac

if distance == "L1":
self.distance = torch.nn.L1Loss(reduction=reduction)
elif distance == "L2":
Expand All @@ -44,8 +57,8 @@ def __init__(self, log=True, distance="L1", reduction="mean"):

def forward(self, x_mag, y_mag):
    """Return the distance between two sets of STFT magnitudes.

    When ``self.log`` is true, both inputs are first compressed with
    ``log(self.log_fac * mag + self.log_eps)``; otherwise they are compared
    on a linear scale. The comparison itself is delegated to
    ``self.distance``.

    Args:
        x_mag: Magnitudes of the first signal (anything ``torch.log`` and
            ``self.distance`` accept; presumably a non-negative tensor).
        y_mag: Magnitudes of the second signal, same conventions as
            ``x_mag``.

    Returns:
        Whatever ``self.distance(x_mag, y_mag)`` returns.
    """
    if self.log:
        # Compress exactly once. Applying a plain torch.log first and then
        # the parameterised log would yield log(log(x)), which is NaN/-inf
        # for magnitudes <= 1.
        # NOTE: with log_eps == 0 a zero magnitude still maps to -inf.
        x_mag = torch.log(self.log_fac * x_mag + self.log_eps)
        y_mag = torch.log(self.log_fac * y_mag + self.log_eps)
    return self.distance(x_mag, y_mag)


Expand Down Expand Up @@ -113,6 +126,7 @@ def __init__(
reduction: str = "mean",
mag_distance: str = "L1",
device: Any = None,
**kwargs
):
super().__init__()
self.fft_size = fft_size
Expand All @@ -139,11 +153,13 @@ def __init__(
log=True,
reduction=reduction,
distance=mag_distance,
**kwargs
)
self.linstft = STFTMagnitudeLoss(
log=False,
reduction=reduction,
distance=mag_distance,
**kwargs
)

# setup mel filterbank
Expand Down

0 comments on commit bba3748

Please sign in to comment.