forked from TimoBolkart/voca
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_training.py
82 lines (65 loc) · 2.97 KB
/
run_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
'''
Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights on this
computer program.
You can only use this computer program if you have closed a license agreement with MPG or you get the right to use
the computer program from someone who is authorized to grant you that right.
Any use of the computer program without a valid license is prohibited and liable to prosecution.
Copyright 2019 Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG). acting on behalf of its
Max Planck Institute for Intelligent Systems and the Max Planck Institute for Biological Cybernetics.
All rights reserved.
More information about VOCA is available at http://voca.is.tue.mpg.de.
For comments or questions, please email us at voca@tue.mpg.de
'''
import os
import stat
import glob
import shutil
import subprocess
import configparser
import tensorflow as tf
from config_parser import read_config, create_default_config
from utils.data_handler import DataHandler
from utils.batcher import Batcher
from utils.voca_model import VOCAModel as Model
def main():
    """Prepare the training config and checkpoint directory, then train a VOCA model.

    Prior to training, adapt the hyperparameters in ``config_parser.py`` and run that
    script to generate the training config file used to train your own VOCA model.
    If ``training_config.cfg`` does not exist next to this script, a default one is
    created via ``create_default_config``.

    If the configured checkpoint directory already exists, the user is prompted:
    "q" quits without training, "x" erases the directory and starts fresh, and any
    other key continues training from the existing directory.

    The effective config is written to ``<checkpoint_dir>/config.pkl`` unless that
    file already exists, in which case the existing file is reused unchanged (so the
    ``processed_audio_path`` set below is only persisted for a fresh checkpoint dir).
    Note that despite the ``.pkl`` extension, the file is written by
    ``ConfigParser.write`` and is therefore INI-format text. The file is then loaded
    with ``read_config`` and used to build the data pipeline and the model inside a
    TensorFlow session.
    """
    pkg_path, _ = os.path.split(os.path.realpath(__file__))
    init_config_fname = os.path.join(pkg_path, 'training_config.cfg')
    if not os.path.exists(init_config_fname):
        print('Config not found %s' % init_config_fname)
        create_default_config(init_config_fname)

    config = configparser.ConfigParser()
    config.read(init_config_fname)

    # Path to cache the processed audio; the cache file name encodes the audio feature type.
    config.set('Input Output', 'processed_audio_path',
               './training_data/processed_audio_%s.pkl' % config.get('Audio Parameters', 'audio_feature_type'))

    checkpoint_dir = config.get('Input Output', 'checkpoint_dir')
    if os.path.exists(checkpoint_dir):
        print('Checkpoint dir already exists %s' % checkpoint_dir)
        key = input('Press "q" to quit, "x" to erase existing folder, and any other key to continue training: ')
        choice = key.lower()
        if choice == 'q':
            return
        elif choice == 'x':
            # Previously rmtree was called with ignore_errors=True, which suppresses every
            # error, so the except branch was unreachable and failures went unreported.
            # Let rmtree raise and report the cause instead.
            try:
                shutil.rmtree(checkpoint_dir)
            except OSError as err:
                print('Failed deleting checkpoint directory: %s' % err)

    # Re-create the directory if it was erased above (or never existed).
    os.makedirs(checkpoint_dir, exist_ok=True)

    config_fname = os.path.join(checkpoint_dir, 'config.pkl')
    if os.path.exists(config_fname):
        print('Use existing config %s' % config_fname)
    else:
        # The with-statement closes the file; no explicit close() is needed.
        with open(config_fname, 'w') as fp:
            config.write(fp)

    config = read_config(config_fname)

    data_handler = DataHandler(config)
    batcher = Batcher(data_handler)
    with tf.Session() as session:
        model = Model(session=session, config=config, batcher=batcher)
        model.build_graph()
        model.load()
        model.train()
if __name__ == '__main__':
    # Start training only when this file is executed as a script, not when it is imported.
    main()