# training.py: loss functions and training utilities
# (forked from MaxPayne86/CoreAudioML)
import torch
import torch.nn as nn
import torch.nn.functional as F


# ESR loss: the error-to-signal ratio between the output and the target,
# i.e. the mean squared error of the residual normalised by the target energy.
class ESRLoss(nn.Module):
    def __init__(self):
        super(ESRLoss, self).__init__()
        self.epsilon = 0.00001

    def forward(self, output, target):
        loss = torch.add(target, -output)
        loss = torch.pow(loss, 2)
        loss = torch.mean(loss)
        energy = torch.mean(torch.pow(target, 2)) + self.epsilon
        loss = torch.div(loss, energy)
        return loss
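
# For a target y and prediction y_hat, ESRLoss computes
#     ESR = mean((y - y_hat)^2) / (mean(y^2) + epsilon),
# so 0 means a perfect match and a value near 1 is roughly what predicting
# silence would score.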


# DC loss: penalises a difference in DC offset (mean level) between the
# output and the target.
class DCLoss(nn.Module):
    def __init__(self):
        super(DCLoss, self).__init__()
        self.epsilon = 0.00001

    def forward(self, output, target):
        loss = torch.pow(torch.add(torch.mean(target, 0), -torch.mean(output, 0)), 2)
        loss = torch.mean(loss)
        energy = torch.mean(torch.pow(target, 2)) + self.epsilon
        loss = torch.div(loss, energy)
        return loss
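
# Here the means are taken over dim 0 (the time dimension), so the term is
#     DC = mean((mean_t(y) - mean_t(y_hat))^2) / (mean(y^2) + epsilon),
# the squared per-sequence offset error normalised by the target energy.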


# Multi-resolution spectral loss: averages SpecLoss over several STFT sizes,
# so errors are penalised at several time/frequency resolutions at once.
class MultiSpecLoss(nn.Module):
    def __init__(self, fft_sizes=(2048, 1024, 512, 256, 128)):
        super(MultiSpecLoss, self).__init__()
        self.epsilon = 0.00001
        self.fft_sizes = fft_sizes
        # nn.ModuleList (rather than a plain list) registers the SpecLoss
        # children, so they follow the parent module across .to(device) calls
        self.spec_loss = nn.ModuleList()
        for size in self.fft_sizes:
            hop = size // 4
            self.spec_loss.append(SpecLoss(size, hop))

    def forward(self, output, target):
        # torch.stft expects 1-D or 2-D input, so drop singleton channel dims
        output = output.squeeze()
        target = target.squeeze()
        total_loss = 0
        for item in self.spec_loss:
            total_loss += item(output, target)
        return total_loss / len(self.fft_sizes)


# Spectral loss at a single STFT resolution: L1 distance between the linear
# magnitude spectrograms plus L1 distance between the log magnitude spectrograms.
class SpecLoss(nn.Module):
    def __init__(self, fft_size=512, hop_size=128):
        super(SpecLoss, self).__init__()
        self.epsilon = 0.00001
        self.fft_size = fft_size
        self.hop_size = hop_size

    def forward(self, output, target):
        magx = torch.abs(torch.stft(output, n_fft=self.fft_size, hop_length=self.hop_size, return_complex=True))
        magy = torch.abs(torch.stft(target, n_fft=self.fft_size, hop_length=self.hop_size, return_complex=True))
        # Floor the magnitudes at epsilon before taking the log to avoid log(0)
        logx = torch.log(torch.clamp(magx, min=self.epsilon))
        logy = torch.log(torch.clamp(magy, min=self.epsilon))
        return F.l1_loss(magx, magy) + F.l1_loss(logx, logy)
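
# Note: torch.stft is called without an explicit window here, which means a
# rectangular window; recent PyTorch versions warn about this, and passing
# window=torch.hann_window(fft_size, device=output.device) would be the usual
# alternative.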


# PreEmph applies an FIR pre-emphasis filter to both signals before a loss is
# computed; the filter coefficients are given by the filter_cfs argument, and
# low_pass is a flag that additionally applies a fixed low-pass filter.
# Only single-channel signals are supported.
class PreEmph(nn.Module):
    def __init__(self, filter_cfs, low_pass=0):
        super(PreEmph, self).__init__()
        self.epsilon = 0.00001
        self.zPad = len(filter_cfs) - 1
        # The kernel size must match the number of filter coefficients
        self.conv_filter = nn.Conv1d(1, 1, len(filter_cfs), bias=False)
        self.conv_filter.weight.data = torch.tensor([[filter_cfs]], requires_grad=False)
        self.low_pass = low_pass
        if self.low_pass:
            self.lp_filter = nn.Conv1d(1, 1, 2, bias=False)
            self.lp_filter.weight.data = torch.tensor([[[0.85, 1]]], requires_grad=False)

    def forward(self, output, target):
        # Zero-pad the input/target so the filtered signals keep the same length;
        # the padding is created on the same device as the signals
        pad = torch.zeros(self.zPad, output.shape[1], 1, device=output.device)
        output = torch.cat((pad, output))
        target = torch.cat((pad, target))
        # Apply the pre-emphasis filter; permute because RNNs use [seq, batch, ch]
        # while Conv1d expects [batch, ch, seq]
        output = self.conv_filter(output.permute(1, 2, 0))
        target = self.conv_filter(target.permute(1, 2, 0))
        if self.low_pass:
            output = self.lp_filter(output)
            target = self.lp_filter(target)
        return output.permute(2, 0, 1), target.permute(2, 0, 1)
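
# A minimal usage sketch; the coefficient values are the common first-order
# pre-emphasis choice and are an assumption, not something this file mandates:
#     pre_filt = PreEmph([-0.85, 1])
#     filt_out, filt_tgt = pre_filt(output, target)  # inputs: [seq, batch, 1]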


# LossWrapper combines the losses above into a single weighted sum.
# losses is a dict mapping a loss name ('ESR', 'DC' or, with pre_filt set,
# 'ESRPre') to its weight; pre_filt supplies pre-emphasis filter coefficients.
class LossWrapper(nn.Module):
    def __init__(self, losses, pre_filt=None):
        super(LossWrapper, self).__init__()
        loss_dict = {'ESR': ESRLoss(), 'DC': DCLoss()}
        if pre_filt:
            # 'ESRPre' is the ESR loss computed on pre-emphasis-filtered signals
            pre_filt = PreEmph(pre_filt)
            loss_dict['ESRPre'] = lambda output, target: loss_dict['ESR'].forward(*pre_filt(output, target))
        loss_functions = [[loss_dict[key], value] for key, value in losses.items()]
        self.loss_functions = tuple([items[0] for items in loss_functions])
        try:
            self.loss_factors = tuple(torch.Tensor([items[1] for items in loss_functions]))
        except IndexError:
            # No weights supplied: weight every loss equally
            self.loss_factors = torch.ones(len(self.loss_functions))

    def forward(self, output, target):
        loss = 0
        for i, loss_fn in enumerate(self.loss_functions):
            loss += torch.mul(loss_fn(output, target), self.loss_factors[i])
        return loss
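
# For example (the weights here are illustrative, not prescribed by this file):
#     criterion = LossWrapper({'ESR': 0.75, 'DC': 0.25})
#     loss = criterion(model_output, target)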


# TrainTrack is a dict subclass that records training statistics (losses,
# timings, best validation loss) so a run can be logged and resumed.
class TrainTrack(dict):
    def __init__(self):
        self.update({'current_epoch': 0, 'training_losses': [], 'validation_losses': [], 'train_av_time': 0.0,
                     'val_av_time': 0.0, 'total_time': 0.0, 'best_val_loss': 1e12, 'test_loss': 0})

    def restore_data(self, training_info):
        self.update(training_info)

    def train_epoch_update(self, loss, ep_st_time, ep_end_time, init_time, current_ep):
        # Running estimate of the epoch duration (average of old mean and new value)
        if self['train_av_time']:
            self['train_av_time'] = (self['train_av_time'] + ep_end_time - ep_st_time) / 2
        else:
            self['train_av_time'] = ep_end_time - ep_st_time
        self['training_losses'].append(loss)
        self['current_epoch'] = current_ep
        # total_time is accumulated in hours
        self['total_time'] += (init_time + ep_end_time - ep_st_time) / 3600

    def val_epoch_update(self, loss, ep_st_time, ep_end_time):
        if self['val_av_time']:
            self['val_av_time'] = (self['val_av_time'] + ep_end_time - ep_st_time) / 2
        else:
            self['val_av_time'] = ep_end_time - ep_st_time
        self['validation_losses'].append(loss)
        if loss < self['best_val_loss']:
            self['best_val_loss'] = loss
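

# A minimal smoke test, runnable as `python training.py`. The tensor shapes
# ([seq_len, batch, channels] with channels == 1) are an assumption based on
# how the RNN-style permutes above treat their inputs.
if __name__ == "__main__":
    seq_len, batch = 2048, 4
    target = torch.randn(seq_len, batch, 1)
    output = target + 0.1 * torch.randn(seq_len, batch, 1)

    # Weighted combination of time-domain losses (weights are illustrative)
    criterion = LossWrapper({'ESR': 0.75, 'DC': 0.25})
    print('combined ESR/DC loss:', criterion(output, target).item())

    # MultiSpecLoss feeds torch.stft, which wants [batch, time] after the
    # internal squeeze, so transpose the channel-squeezed signals first
    spec_criterion = MultiSpecLoss()
    print('multi-resolution spectral loss:',
          spec_criterion(output.squeeze(-1).t(), target.squeeze(-1).t()).item())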