-
Notifications
You must be signed in to change notification settings - Fork 111
/
loss.py
368 lines (316 loc) · 11.8 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
import typing
from typing import List
import torch
import torch.nn.functional as F
from audiotools import AudioSignal
from audiotools import STFTParams
from torch import nn
class L1Loss(nn.L1Loss):
    """L1 Loss between AudioSignals. Defaults
    to comparing ``audio_data``, but any
    attribute of an AudioSignal can be used.

    Parameters
    ----------
    attribute : str, optional
        Attribute of signal to compare, defaults to ``audio_data``.
    weight : float, optional
        Weight of this loss, defaults to 1.0. It is only stored on the
        instance; ``forward`` does not apply it.
    **kwargs
        Forwarded unchanged to ``torch.nn.L1Loss`` (e.g. ``reduction``).

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
    """

    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
        # nn.Module requires the parent constructor to run before any
        # attribute is assigned on the subclass.
        super().__init__(**kwargs)
        self.attribute = attribute
        self.weight = weight

    def forward(
        self,
        x: typing.Union["AudioSignal", torch.Tensor],
        y: typing.Union["AudioSignal", torch.Tensor],
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        x : AudioSignal or torch.Tensor
            Estimate. A tensor is compared as-is; anything else is
            replaced by its ``self.attribute`` attribute.
        y : AudioSignal or torch.Tensor
            Reference, resolved the same way as ``x``.

        Returns
        -------
        torch.Tensor
            L1 loss between AudioSignal attributes.
        """
        # Resolve each operand on its own so that a signal can be compared
        # against a raw tensor (and vice versa) instead of assuming that
        # ``y`` always has the same type as ``x``.
        if not torch.is_tensor(x):
            x = getattr(x, self.attribute)
        if not torch.is_tensor(y):
            y = getattr(y, self.attribute)
        return super().forward(x, y)
class SISDRLoss(nn.Module):
    """
    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
    of estimated and reference audio signals or aligned features.

    Parameters
    ----------
    scaling : int, optional
        Whether to use scale-invariant (True) or
        signal-to-noise ratio (False), by default True
    reduction : str, optional
        How to reduce across the batch: ``'mean'`` or ``'sum'``. Any other
        value (e.g. ``'none'``) returns the unreduced values, by default
        ``'mean'``
    zero_mean : int, optional
        Zero mean the references and estimates before
        computing the loss, by default True
    clip_min : int, optional
        The minimum possible loss value. Helps network
        to not focus on making already good examples better, by default None
    weight : float, optional
        Weight of this loss, defaults to 1.0. It is only stored on the
        instance; ``forward`` does not apply it.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
    """

    def __init__(
        self,
        scaling: int = True,
        reduction: str = "mean",
        zero_mean: int = True,
        clip_min: int = None,
        weight: float = 1.0,
    ):
        # nn.Module requires the parent constructor to run before any
        # attribute is assigned on the subclass.
        super().__init__()
        self.scaling = scaling
        self.reduction = reduction
        self.zero_mean = zero_mean
        self.clip_min = clip_min
        self.weight = weight

    def forward(
        self,
        x: typing.Union["AudioSignal", torch.Tensor],
        y: typing.Union["AudioSignal", torch.Tensor],
    ) -> torch.Tensor:
        """Return the negated SDR (lower is better) of ``y`` against ``x``.

        Parameters
        ----------
        x : AudioSignal or torch.Tensor
            Used as the *reference*. A tensor is used as-is; anything else
            is replaced by its ``audio_data`` attribute.
        y : AudioSignal or torch.Tensor
            Used as the *estimate*, resolved the same way as ``x``.

        Returns
        -------
        torch.Tensor
            ``-10 * log10(signal / noise)``, reduced according to
            ``self.reduction``; shape ``(batch, 1)`` when not reduced.
        """
        eps = 1e-8
        # Resolve each operand on its own instead of assuming ``y`` has the
        # same type as ``x``.
        references = x if torch.is_tensor(x) else x.audio_data
        estimates = y if torch.is_tensor(y) else y.audio_data

        # Flatten everything after the batch axis: (nb, n_samples, 1).
        nb = references.shape[0]
        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)

        # samples now on axis 1
        if self.zero_mean:
            references = references - references.mean(dim=1, keepdim=True)
            estimates = estimates - estimates.mean(dim=1, keepdim=True)

        if self.scaling:
            # Projection coefficient of the estimate onto the reference.
            references_projection = (references**2).sum(dim=-2) + eps
            references_on_estimates = (estimates * references).sum(dim=-2) + eps
            scale = (references_on_estimates / references_projection).unsqueeze(1)
        else:
            scale = 1

        e_true = scale * references
        e_res = estimates - e_true

        signal = (e_true**2).sum(dim=1)
        noise = (e_res**2).sum(dim=1)
        # Guard the denominator: a perfect estimate (noise == 0) used to give
        # -inf, and a silent reference/estimate pair (0 / 0) used to give NaN.
        sdr = -10 * torch.log10(signal / (noise + eps) + eps)

        if self.clip_min is not None:
            sdr = torch.clamp(sdr, min=self.clip_min)

        if self.reduction == "mean":
            sdr = sdr.mean()
        elif self.reduction == "sum":
            sdr = sdr.sum()
        return sdr
class MultiScaleSTFTLoss(nn.Module):
    """Multi-scale STFT loss, following [1].

    For every configured window length the magnitude spectrograms of the
    two signals are compared twice with ``loss_fn``: once on the log of the
    (clamped, powered) magnitudes and once on the raw magnitudes.

    Parameters
    ----------
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512].
        The hop length of each STFT is a quarter of its window length.
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. Only stored on the instance;
        ``forward`` does not apply it.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default
        False. Passed to ``STFTParams``.
    window_type : str, optional
        Window type passed to ``STFTParams``, by default None

    References
    ----------
    1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
       "DDSP: Differentiable Digital Signal Processing."
       International Conference on Learning Representations. 2019.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        window_type: str = None,
    ):
        super().__init__()

        # One STFT configuration per requested resolution.
        per_scale_params = []
        for length in window_lengths:
            per_scale_params.append(
                STFTParams(
                    window_length=length,
                    hop_length=length // 4,
                    match_stride=match_stride,
                    window_type=window_type,
                )
            )
        self.stft_params = per_scale_params

        self.loss_fn = loss_fn
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.clamp_eps = clamp_eps
        self.weight = weight
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Accumulate the log- and raw-magnitude distances over all scales.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Multi-scale STFT loss.
        """

        def log_magnitude(magnitude):
            # Clamp from below before the log so it stays finite.
            return magnitude.clamp(self.clamp_eps).pow(self.pow).log10()

        total = 0.0
        for params in self.stft_params:
            stft_args = (params.window_length, params.hop_length, params.window_type)
            # The magnitudes read below come from these STFT calls.
            x.stft(*stft_args)
            y.stft(*stft_args)

            log_term = self.loss_fn(
                log_magnitude(x.magnitude),
                log_magnitude(y.magnitude),
            )
            total += self.log_weight * log_term

            mag_term = self.loss_fn(x.magnitude, y.magnitude)
            total += self.mag_weight * mag_term
        return total
class MelSpectrogramLoss(nn.Module):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    ``n_mels``, ``window_lengths``, ``mel_fmin`` and ``mel_fmax`` are
    per-scale lists: entry ``i`` of each describes scale ``i``. ``forward``
    zips them together, so they should all have the same length; a warning
    is issued at construction time when they do not, because the surplus
    entries are ignored.

    Parameters
    ----------
    n_mels : List[int]
        Number of mels per STFT, by default [150, 80],
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512].
        The hop length of each STFT is a quarter of its window length.
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0. Only stored on the instance;
        ``forward`` does not apply it.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default
        False. Passed to ``STFTParams``.
    mel_fmin : List[float], optional
        Per-scale ``mel_fmin`` passed to ``mel_spectrogram``, by default
        [0.0, 0.0]
    mel_fmax : List[float], optional
        Per-scale ``mel_fmax`` passed to ``mel_spectrogram``, by default
        [None, None]
    window_type : str, optional
        Window type passed to ``STFTParams``, by default None

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        n_mels: List[int] = [150, 80],
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: List[float] = [0.0, 0.0],
        mel_fmax: List[float] = [None, None],
        window_type: str = None,
    ):
        super().__init__()

        # Copy the per-scale lists: the defaults are module-level list
        # objects shared by every call, so storing them directly would let a
        # mutation on one instance leak into all later instances.
        n_mels = list(n_mels)
        window_lengths = list(window_lengths)
        mel_fmin = list(mel_fmin)
        mel_fmax = list(mel_fmax)

        # forward() zips these lists, which silently stops at the shortest
        # one. Surface that instead of quietly dropping scales.
        lengths = {
            "n_mels": len(n_mels),
            "window_lengths": len(window_lengths),
            "mel_fmin": len(mel_fmin),
            "mel_fmax": len(mel_fmax),
        }
        if len(set(lengths.values())) > 1:
            import warnings

            warnings.warn(
                "MelSpectrogramLoss received per-scale lists of different "
                f"lengths ({lengths}); only the first "
                f"{min(lengths.values())} scale(s) will be used.",
                stacklevel=2,
            )

        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.n_mels = n_mels
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.pow = pow

    def forward(self, x: "AudioSignal", y: "AudioSignal"):
        """Computes mel loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Mel loss.
        """
        loss = 0.0
        scales = zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params)
        for n_mels, fmin, fmax, s in scales:
            mel_kwargs = {
                "mel_fmin": fmin,
                "mel_fmax": fmax,
                "window_length": s.window_length,
                "hop_length": s.hop_length,
                "window_type": s.window_type,
            }
            x_mels = x.mel_spectrogram(n_mels, **mel_kwargs)
            y_mels = y.mel_spectrogram(n_mels, **mel_kwargs)

            # Log-magnitude term; clamp from below so the log stays finite.
            loss += self.log_weight * self.loss_fn(
                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            # Raw-magnitude term.
            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
        return loss
class GANLoss(nn.Module):
    """Adversarial losses built on top of a discriminator.

    The discriminator is run on the ``audio_data`` of a generated ("fake")
    and a ground-truth ("real") signal. Separate methods produce the
    discriminator objective and the generator objective; both use a
    squared-error (least-squares) form on the last output of each
    discriminator branch.
    """

    def __init__(self, discriminator):
        super().__init__()
        self.discriminator = discriminator

    def forward(self, fake, real):
        """Run the discriminator on both signals; returns ``(d_fake, d_real)``."""
        return self.discriminator(fake.audio_data), self.discriminator(real.audio_data)

    def discriminator_loss(self, fake, real):
        """Push fake outputs towards 0 and real outputs towards 1.

        The fake signal is cloned and detached first, so this loss does not
        back-propagate into whatever produced it.
        """
        detached_fake = fake.clone().detach()
        fake_outs, real_outs = self.forward(detached_fake, real)

        total = 0
        for fake_branch, real_branch in zip(fake_outs, real_outs):
            # The last entry of each branch is its final output.
            fake_term = torch.mean(fake_branch[-1] ** 2)
            real_term = torch.mean((1 - real_branch[-1]) ** 2)
            total += fake_term
            total += real_term
        return total

    def generator_loss(self, fake, real):
        """Return ``(adversarial_loss, feature_matching_loss)`` for the generator.

        The adversarial part pushes the final fake outputs towards 1. The
        feature part is the L1 distance between every non-final fake output
        and the matching (detached) real output.
        """
        fake_outs, real_outs = self.forward(fake, real)

        adversarial = 0
        for fake_branch in fake_outs:
            adversarial += torch.mean((1 - fake_branch[-1]) ** 2)

        feature_matching = 0
        for branch_idx, fake_branch in enumerate(fake_outs):
            # Every entry except the last one is an intermediate feature map.
            for layer_idx in range(len(fake_branch) - 1):
                feature_matching += F.l1_loss(
                    fake_branch[layer_idx],
                    real_outs[branch_idx][layer_idx].detach(),
                )
        return adversarial, feature_matching