# train_Tasnet.py
# Training entry point for Conv-TasNet (91 lines / 75 loc).
# NOTE(review): the lines above this header in the original paste were web-page
# scrape residue (site chrome and a column of line numbers), not source code;
# they have been replaced by this comment so the module is valid Python.
import sys
sys.path.append('./')
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader as Loader
from data_loader.Dataset import Datasets
from model import model
from logger import set_logger
import logging
from config import option
import argparse
import torch
from trainer import trainer_Tasnet
def make_dataloader(opt):
    """Build the train and validation DataLoaders described by *opt*.

    Args:
        opt (dict): parsed option YAML. Reads ``opt['datasets'][split]``
            for the mix/target data roots of each split, plus the shared
            ``audio_setting`` and ``dataloader_setting`` sections.

    Returns:
        tuple: ``(train_dataloader, val_dataloader)``.
    """
    def _build(split):
        # Both splits share audio_setting and dataloader_setting;
        # only the data roots differ, so build them with one helper.
        roots = opt['datasets'][split]
        dataset = Datasets(
            roots['dataroot_mix'],
            [roots['dataroot_targets'][0],
             roots['dataroot_targets'][1]],
            **opt['datasets']['audio_setting'])
        loader_cfg = opt['datasets']['dataloader_setting']
        return Loader(dataset,
                      batch_size=loader_cfg['batch_size'],
                      num_workers=loader_cfg['num_workers'],
                      shuffle=loader_cfg['shuffle'])

    # NOTE(review): validation reuses the training `shuffle` flag; validation
    # data is usually not shuffled — confirm this is intended in the config.
    return _build('train'), _build('val')
def make_optimizer(params, opt):
    """Instantiate the optimizer named in ``opt['optim']`` over *params*.

    Args:
        params: iterable of model parameters (e.g. ``model.parameters()``).
        opt (dict): parsed options; ``opt['optim']`` supplies ``name``,
            ``lr``, ``weight_decay`` and — for non-Adam optimizers —
            ``momentum``.

    Returns:
        torch.optim.Optimizer: the configured optimizer instance.
    """
    cfg = opt['optim']
    optim_cls = getattr(torch.optim, cfg['name'])
    if cfg['name'] == 'Adam':
        # Adam has no `momentum` keyword, so it gets its own branch.
        return optim_cls(params,
                         lr=cfg['lr'],
                         weight_decay=cfg['weight_decay'])
    # SGD/RMSprop-style optimizers additionally take momentum.
    return optim_cls(params,
                     lr=cfg['lr'],
                     weight_decay=cfg['weight_decay'],
                     momentum=cfg['momentum'])
def train():
    """Entry point: parse the option YAML, wire up model/optimizer/data,
    and hand everything to the Trainer.

    Command line:
        --opt  Path to the option YAML file consumed by ``option.parse``.
    """
    arg_parser = argparse.ArgumentParser(
        description='Parameters for training Conv-TasNet')
    arg_parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    cli = arg_parser.parse_args()
    opt = option.parse(cli.opt)

    # Logging goes to screen and/or file exactly as the config requests.
    set_logger.setup_logger(opt['logger']['name'], opt['logger']['path'],
                            screen=opt['logger']['screen'],
                            tofile=opt['logger']['tofile'])
    logger = logging.getLogger(opt['logger']['name'])

    # Model.
    logger.info("Building the model of Conv-Tasnet")
    net = model.Conv_TasNet(**opt['Conv_Tasnet'])

    # Optimizer.
    logger.info("Building the optimizer of Conv-Tasnet")
    optimizer = make_optimizer(net.parameters(), opt)

    # Data loaders.
    logger.info('Building the dataloader of Conv-Tasnet')
    train_dataloader, val_dataloader = make_dataloader(opt)
    logger.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(train_dataloader), len(val_dataloader)))

    # Drop the LR when the monitored (validation) metric stops improving.
    scheduler = ReduceLROnPlateau(
        optimizer, mode='min',
        factor=opt['scheduler']['factor'],
        patience=opt['scheduler']['patience'],
        verbose=True, min_lr=opt['scheduler']['min_lr'])

    # Trainer owns the whole fit loop from here on.
    logger.info('Building the Trainer of Conv-Tasnet')
    runner = trainer_Tasnet.Trainer(
        train_dataloader, val_dataloader, net, optimizer, scheduler, opt)
    runner.run()


if __name__ == "__main__":
    train()