-
Notifications
You must be signed in to change notification settings - Fork 1
/
dimi-trainer.py
116 lines (90 loc) · 3.28 KB
/
dimi-trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import configparser
import itertools
import multiprocessing
import os
import pickle
import subprocess
import sys
import time
from random import randint, random

import scripts.dimi_io as io
import scripts.dimi as dimi
def _allocate_output_dir(base_dir):
    """Create and return the first free directory ``<base_dir>_<i>`` (i = 0, 1, ...).

    ``os.makedirs`` itself is the existence test: a candidate that already
    exists raises FileExistsError and the next index is tried. This closes the
    window between a separate ``os.path.exists`` check and ``os.makedirs`` in
    which two runs started together could pick the same directory (the loser
    would crash with FileExistsError).
    """
    for i in itertools.count():
        candidate = '{}_{}'.format(base_dir, i)
        try:
            os.makedirs(candidate)
        except FileExistsError:
            continue
        return candidate


def _record_git_revision(out_dir):
    """Best-effort: save the output of ``git rev-parse HEAD`` to <out_dir>/git-rev.txt.

    git runs in the current working directory and without a shell, so an
    ``out_dir`` containing spaces or shell metacharacters cannot break the
    command. Failures never abort the run: git's own errors go to stderr, and
    an OSError (git not installed, directory not writable) is reported and
    ignored.
    """
    rev_path = os.path.join(out_dir, 'git-rev.txt')
    try:
        with open(rev_path, 'w') as rev_file:
            subprocess.call(['git', 'rev-parse', 'HEAD'], stdout=rev_file)
    except OSError as err:
        sys.stderr.write("Could not record git revision in {}: {}\n".format(rev_path, err))


def _prepare_new_run(config, config_path, D, K, init_alpha):
    """Set up a fresh run from the config file *config_path*.

    *config* is filled from *config_path* and updated in place:
    ``io.init_seqs`` is removed, ``params.d``/``params.k``/``params.init_alpha``
    are set to the effective values, and ``io.output_dir`` is set to the newly
    created run directory. The effective values are the command-line overrides
    when given, otherwise the file's own values. The updated config is written
    there as config.ini, and the current git revision is recorded next to it.

    Returns:
        The newly created output directory.
    """
    config.read(config_path)
    # Remove init_seqs so it is not written into the saved config.
    if config.has_option('io', 'init_seqs'):
        del config['io']['init_seqs']
    out_dir = config.get('io', 'output_dir')
    if not D and not K:
        D = config.get('params', 'd')
        K = config.get('params', 'k')
    if not init_alpha:
        init_alpha = config.get('params', 'init_alpha')
    # Round-trip through float: rejects non-numeric values and normalises the
    # text used in the directory name (e.g. "1" -> "1.0").
    init_alpha = str(float(init_alpha))
    config['params']['d'] = D
    config['params']['k'] = K
    config['params']['init_alpha'] = init_alpha
    out_dir = _allocate_output_dir(out_dir + '_D' + D + 'K' + K + 'A' + init_alpha)
    config['io']['output_dir'] = out_dir
    sys.stderr.write("The output directory for this run is {}.\n".format(out_dir))
    with open(os.path.join(out_dir, 'config.ini'), 'w') as configfile:
        config.write(configfile)
    _record_git_revision(out_dir)
    return out_dir


def main(argv):
    """Start or resume a DIMI training run from command-line arguments.

    Usage: ``<Config file|Resume directory> [D K [init_alpha]]``

    * If the first argument is a directory, its config.ini is read and the run
      stored there is resumed (``resume=True``).
    * Otherwise it is a config file for a new run. D and K (and optionally
      init_alpha) given on the command line override the [params] section.
      A new output directory ``<output_dir>_D<D>K<K>A<init_alpha>_<i>`` is
      created, and the final config and git revision are saved there.

    The input file is then read with ``io.read_input_file``, and
    ``dimi.wrapped_sample_beam`` is called with the merged [io]/[params]
    settings.

    Calls ``sys.exit(-1)`` if no argument is given or the path does not exist.
    """
    if len(argv) < 1:
        sys.stderr.write("One required argument: <Config file|Resume directory>\n")
        sys.exit(-1)
    path = argv[0]
    # Optional overrides; falsy placeholders mean "use the config file's value".
    D, K, init_alpha = 0, 0, 0
    if len(argv) == 3:
        D, K = argv[1], argv[2]
    elif len(argv) == 4:
        D, K, init_alpha = argv[1], argv[2], argv[3]
    elif len(argv) != 1:
        sys.stderr.write(
            "Warning: expected <Config file|Resume directory> [D K [init_alpha]]; "
            "ignoring extra arguments {}\n".format(argv[1:]))
    if not os.path.exists(path):
        sys.stderr.write("Input file/dir does not exist!\n")
        sys.exit(-1)
    config = configparser.ConfigParser()
    # Random 0-10 s start-up delay; presumably staggers jobs launched at the
    # same moment — TODO confirm it is still needed.
    time.sleep(random() * 10)
    if os.path.isdir(path):
        # Resume mode: continue the run whose output directory is *path*.
        config.read(os.path.join(path, 'config.ini'))
        out_dir = config.get('io', 'output_dir')
        resume = True
    else:
        out_dir = _prepare_new_run(config, path, D, K, init_alpha)
        resume = False
    input_file = config.get('io', 'input_file')
    working_dir = config.get('io', 'working_dir', fallback=out_dir)
    dict_file = config.get('io', 'dict_file')
    punct_dict_file = config.get('io', 'punct_dict_file', fallback=None)
    # Read the input file to get the word sequence for X; the first element
    # of the returned pair (pos_seq) is not used here.
    _pos_seq, word_seq = io.read_input_file(input_file)
    params = read_params(config)
    params['output_dir'] = out_dir
    dimi.wrapped_sample_beam(word_seq, params, working_dir,
                             word_dict_file=dict_file, resume=resume,
                             punct_dict_file=punct_dict_file)
def read_params(config):
    """Merge the [io] and [params] sections of *config* into a single dict.

    Values are exactly what ``config.items(section)`` yields. When a key
    appears in both sections, the [params] value wins.
    """
    merged = dict(config.items('io'))
    merged.update(config.items('params'))
    return merged
if __name__ == "__main__":
    # Check the interpreter first: everything below relies on Python 3 APIs.
    # sys.exit(msg) prints to stderr and exits with status 1 (the old bare
    # exit() reported success).
    if sys.version_info[0] < 3:
        sys.exit("This script requires Python 3")
    # Prefer the "fork" start method for worker processes. set_start_method
    # raises RuntimeError if a start method was already fixed and ValueError
    # where "fork" is unavailable; in either case keep the method in effect
    # and report it instead of silently swallowing every exception.
    try:
        multiprocessing.set_start_method("fork")
    except (RuntimeError, ValueError):
        sys.stderr.write(
            "Could not set multiprocessing start method to 'fork'; "
            "using '{}'.\n".format(multiprocessing.get_start_method()))
    main(sys.argv[1:])