-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathtrain_aperoidic.py
24 lines (21 loc) · 979 Bytes
/
train_aperoidic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import hparams
from model.wavenet_model import *
from data.dataset import TimbreDataset
from model.timbre_training import *
# Training entry script for the aperiodic WaveNet model: build the model,
# print a short summary of it, then hand it to ModelTrainer and train.
# NOTE(review): `torch` is not imported explicitly in this file; it is
# presumably re-exported by one of the star imports above -- confirm.

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(backend)

# Build the network from the aperiodic hyper-parameter set and move it to
# the selected device.
aperiodic_hparams = hparams.create_aperiodic_hparams()
model = WaveNetModel(aperiodic_hparams, device).to(device)

# Print the model together with its receptive field and parameter count.
print('model: ', model)
print('receptive field: ', model.receptive_field)
print('parameter count: ', model.parameter_count())

# Keyword arguments forwarded unchanged to ModelTrainer.
trainer_options = dict(
    model=model,
    data_folder='data/timbre_model',
    lr=0.0005,
    weight_decay=0.0,
    snapshot_path='./snapshots/aperiodic',
    snapshot_name='aper',
    snapshot_interval=50000,
    device=device,
)
trainer = ModelTrainer(**trainer_options)

# Disabled: presumably resumes from a previously saved snapshot -- the path is
# machine-specific, so adjust it before re-enabling.
#epoch = trainer.load_checkpoint('/Users/zhaowenxiao/pythonProj/torch_npss/snapshots/aperiodic/chaconne_model_1021_2019-03-30_09-32-23')

print('start training...')
trainer.train(batch_size=32, epochs=1650)