forked from sigsep/open-unmix-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
294 lines (249 loc) · 10 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
import argparse
import model
import data
import torch
import time
from pathlib import Path
import tqdm
import json
import utils
import sklearn.preprocessing
import numpy as np
import random
from git import Repo
import os
import copy
# Set tqdm's monitor interval to 0.
# NOTE(review): this assigns an attribute on the imported `tqdm` *module*;
# tqdm documents `monitor_interval` as a class attribute of `tqdm.tqdm`.
# Confirm that this assignment has the intended effect.
tqdm.monitor_interval = 0
def train(args, unmix, device, train_sampler, optimizer):
    """Run one optimisation pass over ``train_sampler``.

    Args:
        args: Parsed command-line namespace; only ``args.quiet`` is read
            here (it disables the progress bar).
        unmix: Model being trained. It is called on the input batch and its
            ``transform`` method is applied to the target batch.
        device: Device both batch tensors are moved to.
        train_sampler: Iterable yielding ``(x, y)`` batch pairs.
        optimizer: Optimizer whose gradients are reset and stepped once per
            batch.

    Returns:
        The running average of the per-batch MSE loss, as tracked by
        ``utils.AverageMeter``.
    """
    meter = utils.AverageMeter()
    mse = torch.nn.functional.mse_loss

    unmix.train()
    progress = tqdm.tqdm(train_sampler, disable=args.quiet)
    for inputs, targets in progress:
        progress.set_description("Training batch")
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        prediction = unmix(inputs)
        expected = unmix.transform(targets)
        batch_loss = mse(prediction, expected)
        batch_loss.backward()
        optimizer.step()

        # Each batch is weighted by the size of dimension 1 of the
        # transformed target (presumably the batch dimension -- confirm
        # against unmix.transform).
        meter.update(batch_loss.item(), expected.size(1))

    return meter.avg
def valid(args, unmix, device, valid_sampler):
    """Evaluate ``unmix`` on ``valid_sampler`` without tracking gradients.

    Args:
        args: Parsed command-line namespace (not read by this function).
        unmix: Model to evaluate. It is called on the input batch and its
            ``transform`` method is applied to the target batch.
        device: Device both batch tensors are moved to.
        valid_sampler: Iterable yielding ``(x, y)`` batch pairs.

    Returns:
        The running average of the per-batch MSE loss, as tracked by
        ``utils.AverageMeter``.
    """
    meter = utils.AverageMeter()
    mse = torch.nn.functional.mse_loss

    unmix.eval()
    with torch.no_grad():
        for inputs, targets in valid_sampler:
            inputs = inputs.to(device)
            targets = targets.to(device)

            prediction = unmix(inputs)
            expected = unmix.transform(targets)
            batch_loss = mse(prediction, expected)

            # Same weighting as in training: size of dimension 1 of the
            # transformed target.
            meter.update(batch_loss.item(), expected.size(1))

    return meter.avg
def get_statistics(args, dataset):
    """Compute mean and scale of the dataset's input spectrograms.

    A deep copy of ``dataset`` is reconfigured (one sample per track, no
    augmentations, no random chunking or mixing, no fixed sequence
    duration), and every item's first element is passed through an
    STFT + mono spectrogram pipeline and fed to a
    ``sklearn.preprocessing.StandardScaler`` incrementally.

    Args:
        args: Parsed command-line namespace; ``args.nfft``, ``args.nhop``
            and ``args.quiet`` are read.
        dataset: Dataset to gather statistics from. It is not modified;
            only a deep copy is reconfigured.

    Returns:
        A ``(mean, std)`` tuple taken from the fitted scaler, where ``std``
        is clamped from below to ``1e-4`` times the largest scale value.
    """
    scaler = sklearn.preprocessing.StandardScaler()

    stft = model.STFT(n_fft=args.nfft, n_hop=args.nhop)
    spectrogram = model.Spectrogram(mono=True)
    spec = torch.nn.Sequential(stft, spectrogram)

    # Work on a copy so the caller's dataset configuration stays untouched.
    dataset_scaler = copy.deepcopy(dataset)
    overrides = {
        'samples_per_track': 1,
        'augmentations': None,
        'random_chunks': False,
        'random_track_mix': False,
        'random_interferer_mix': False,
        'seq_duration': None,
    }
    for attribute, value in overrides.items():
        setattr(dataset_scaler, attribute, value)

    progress = tqdm.tqdm(range(len(dataset_scaler)), disable=args.quiet)
    for index in progress:
        item = dataset_scaler[index]
        x = item[0]
        progress.set_description("Compute dataset statistics")
        # Add a leading axis before the transform, then drop all singleton
        # axes for the scaler.
        X = spec(x[None, ...])
        scaler.partial_fit(np.squeeze(X))

    # Set initial input scaler values; avoid vanishingly small scales by
    # flooring them relative to the largest one.
    scale_floor = 1e-4*np.max(scaler.scale_)
    std = np.maximum(scaler.scale_, scale_floor)
    return scaler.mean_, std
def _build_parser():
    """Create the command-line parser with all trainer options."""
    parser = argparse.ArgumentParser(description='Open Unmix Trainer')

    # which target do we want to train?
    parser.add_argument('--target', type=str, default='vocals',
                        help='target source (will be passed to the dataset)')

    # Dataset parameters
    parser.add_argument('--dataset', type=str, default="musdb",
                        choices=[
                            'musdb', 'aligned', 'sourcefolder',
                            'trackfolder_var', 'trackfolder_fix'
                        ],
                        help='Name of the dataset.')
    parser.add_argument('--root', type=str, help='root path of dataset')
    parser.add_argument('--output', type=str, default="open-unmix",
                        help='provide output path base folder name')
    parser.add_argument('--model', type=str, help='Path to checkpoint folder')

    # Training parameters
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.001,
                        help='learning rate, defaults to 1e-3')
    parser.add_argument('--patience', type=int, default=140,
                        help='early stopping patience in epochs '
                             '(default: 140)')
    parser.add_argument('--lr-decay-patience', type=int, default=80,
                        help='lr decay patience for plateau scheduler')
    parser.add_argument('--lr-decay-gamma', type=float, default=0.3,
                        help='gamma of learning rate scheduler decay')
    parser.add_argument('--weight-decay', type=float, default=0.00001,
                        help='weight decay')
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed (default: 42)')

    # Model parameters
    parser.add_argument('--seq-dur', type=float, default=6.0,
                        help='Sequence duration in seconds; '
                             'a value of <=0.0 will use full/variable length')
    parser.add_argument('--unidirectional', action='store_true', default=False,
                        help='Use unidirectional LSTM instead of bidirectional')
    parser.add_argument('--nfft', type=int, default=4096,
                        help='STFT fft size and window size')
    parser.add_argument('--nhop', type=int, default=1024,
                        help='STFT hop size')
    parser.add_argument('--hidden-size', type=int, default=512,
                        help='hidden size parameter of dense bottleneck layers')
    parser.add_argument('--bandwidth', type=int, default=16000,
                        help='maximum model bandwidth in hertz')
    parser.add_argument('--nb-channels', type=int, default=2,
                        help='set number of channels for model (1, 2)')
    parser.add_argument('--nb-workers', type=int, default=0,
                        help='Number of workers for dataloader.')

    # Misc parameters
    parser.add_argument('--quiet', action='store_true', default=False,
                        help='less verbose during training')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    return parser


def main():
    """Command-line entry point: train (or resume training) one target model.

    Parses the command line, loads the datasets through
    ``data.load_datasets``, builds the ``model.OpenUnmix`` network together
    with an Adam optimizer, a ``ReduceLROnPlateau`` scheduler and an early
    stopping helper, then runs the epoch loop. After every epoch a checkpoint
    is written via ``utils.save_checkpoint`` and the training history is
    dumped to ``<output>/<target>.json``.

    When ``--model`` is given, model/optimizer/scheduler state and the
    training history are restored from that folder and training continues
    for ``--epochs`` further epochs.
    """
    parser = _build_parser()
    # Unknown options are tolerated here; the parser itself is handed to
    # data.load_datasets below, which returns the final ``args``.
    args, _ = parser.parse_known_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    print("Using GPU:", use_cuda)
    print("Using Torchaudio: ", utils._torchaudio_available())

    # Honour --nb-workers on every device; pinned memory is only requested
    # when CUDA is in use.
    dataloader_kwargs = {'num_workers': args.nb_workers}
    if use_cuda:
        dataloader_kwargs['pin_memory'] = True

    # Record the short hash of the current git commit alongside the results.
    repo_dir = os.path.abspath(os.path.dirname(__file__))
    repo = Repo(repo_dir)
    commit = repo.head.commit.hexsha[:7]

    torch.manual_seed(args.seed)
    random.seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_dataset, valid_dataset, args = data.load_datasets(parser, args)

    # create output dir if not exist
    target_path = Path(args.output)
    target_path.mkdir(parents=True, exist_ok=True)

    train_sampler = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        **dataloader_kwargs
    )
    valid_sampler = torch.utils.data.DataLoader(
        valid_dataset, batch_size=1,
        **dataloader_kwargs
    )

    # When resuming, the input statistics are not recomputed; the checkpoint's
    # state dict is loaded into the model further below.
    if args.model:
        scaler_mean = None
        scaler_std = None
    else:
        scaler_mean, scaler_std = get_statistics(args, train_dataset)

    max_bin = utils.bandwidth_to_max_bin(
        train_dataset.sample_rate, args.nfft, args.bandwidth
    )

    unmix = model.OpenUnmix(
        input_mean=scaler_mean,
        input_scale=scaler_std,
        nb_channels=args.nb_channels,
        hidden_size=args.hidden_size,
        n_fft=args.nfft,
        n_hop=args.nhop,
        max_bin=max_bin,
        sample_rate=train_dataset.sample_rate
    ).to(device)

    optimizer = torch.optim.Adam(
        unmix.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_gamma,
        patience=args.lr_decay_patience,
        cooldown=10
    )

    es = utils.EarlyStopping(patience=args.patience)

    # if a model is specified: resume training
    if args.model:
        model_path = Path(args.model).expanduser()
        with open(Path(model_path, args.target + '.json'), 'r') as stream:
            results = json.load(stream)

        target_model_path = Path(model_path, args.target + ".chkpnt")
        checkpoint = torch.load(target_model_path, map_location=device)
        unmix.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

        # Train for another ``args.epochs`` epochs. 'epochs_trained' holds
        # the number of the last finished epoch, so numbering continues with
        # the next one instead of re-using it.
        first_epoch = results['epochs_trained'] + 1
        t = tqdm.trange(
            first_epoch,
            first_epoch + args.epochs,
            disable=args.quiet
        )
        train_losses = results['train_loss_history']
        valid_losses = results['valid_loss_history']
        train_times = results['train_time_history']
        best_epoch = results['best_epoch']
        es.best = results['best_loss']
        es.num_bad_epochs = results['num_bad_epochs']
    # else start from 0
    else:
        t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
        train_losses = []
        valid_losses = []
        train_times = []
        best_epoch = 0

    for epoch in t:
        t.set_description("Training Epoch")
        end = time.time()
        train_loss = train(args, unmix, device, train_sampler, optimizer)
        valid_loss = valid(args, unmix, device, valid_sampler)
        scheduler.step(valid_loss)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        t.set_postfix(
            train_loss=train_loss, val_loss=valid_loss
        )

        stop = es.step(valid_loss)

        # The early-stopping helper tracks the best loss seen so far; an
        # epoch matching it is the new best epoch.
        is_best = valid_loss == es.best
        if is_best:
            best_epoch = epoch

        utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': unmix.state_dict(),
                'best_loss': es.best,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            },
            is_best=is_best,
            path=target_path,
            target=args.target
        )

        # Record this epoch's duration before the history is serialised so
        # the saved 'train_time_history' includes the current epoch.
        train_times.append(time.time() - end)

        # save params
        params = {
            'epochs_trained': epoch,
            'args': vars(args),
            'best_loss': es.best,
            'best_epoch': best_epoch,
            'train_loss_history': train_losses,
            'valid_loss_history': valid_losses,
            'train_time_history': train_times,
            'num_bad_epochs': es.num_bad_epochs,
            'commit': commit
        }

        with open(Path(target_path, args.target + '.json'), 'w') as outfile:
            outfile.write(json.dumps(params, indent=4, sort_keys=True))

        if stop:
            print("Apply Early Stopping")
            break
# Run the trainer only when this file is executed as a script.
if __name__ == "__main__":
    main()