# sst_main.py — SST training/evaluation entry point (226 lines, 8.74 KB)
# NOTE: GitHub page chrome and the extracted line-number gutter were removed here.
import math
import tensorflow as tf
from configs import cfg
from src.dataset import Dataset, RawDataProcessor
from src.evaluator import Evaluator
# choose model
from src.graph_handler import GraphHandler
from src.perform_recorder import PerformRecoder
from src.utils.file import load_file, save_file
from src.utils.record_log import _logger
from sst_log_analysis import do_analyse_sst
# choose model: the concrete Model class is selected by cfg.network_type;
# only network types listed in model_set are buildable by this script.
network_type = cfg.network_type
if network_type == 'mtsa':
    from src.model.m_mtsa import ModelMTSA as Model
model_set = ['mtsa']  # recognized network types (extend when adding models)
def train():
    """Build the SST datasets, construct the model graph and run training.

    Loads preprocessed data from ``cfg.processed_path`` when available,
    otherwise builds it from ``cfg.data_dir`` and caches it. During the
    loop it logs loss every 100 steps, periodically evaluates on the dev
    and test sets, keeps the top-3 checkpoints ranked by dev accuracy via
    ``PerformRecoder``, and finally runs the SST log analysis over the
    logger output. Uses module-level ``network_type``/``model_set``/``Model``.
    """
    output_model_params()
    loadFile = True
    ifLoad, data = False, None
    if loadFile:
        ifLoad, data = load_file(cfg.processed_path, 'processed data', 'pickle')
    if not ifLoad or not loadFile:
        # Build datasets from the raw files; dev/test reuse the train dictionaries
        # so token/char ids are consistent across splits.
        raw_data = RawDataProcessor(cfg.data_dir)
        train_data_list = raw_data.get_data_list('train')
        dev_data_list = raw_data.get_data_list('dev')
        test_data_list = raw_data.get_data_list('test')
        train_data_obj = Dataset(train_data_list, 'train')
        dev_data_obj = Dataset(dev_data_list, 'dev', train_data_obj.dicts)
        test_data_obj = Dataset(test_data_list, 'test', train_data_obj.dicts)
        save_file({'train_data_obj': train_data_obj, 'dev_data_obj': dev_data_obj, 'test_data_obj': test_data_obj},
                  cfg.processed_path)
        train_data_obj.save_dict(cfg.dict_path)
    else:
        train_data_obj = data['train_data_obj']
        dev_data_obj = data['dev_data_obj']
        test_data_obj = data['test_data_obj']

    # Dev/test are always restricted to whole sentences; train follows cfg.
    train_data_obj.filter_data(cfg.only_sentence, cfg.fine_grained)
    dev_data_obj.filter_data(True, cfg.fine_grained)
    test_data_obj.filter_data(True, cfg.fine_grained)

    emb_mat_token, emb_mat_glove = train_data_obj.emb_mat_token, train_data_obj.emb_mat_glove

    with tf.variable_scope(network_type) as scope:
        if network_type in model_set:
            model = Model(emb_mat_token, emb_mat_glove, len(train_data_obj.dicts['token']),
                          len(train_data_obj.dicts['char']), train_data_obj.max_lens['token'], scope.name)

    graphHandler = GraphHandler(model)
    evaluator = Evaluator(model)
    performRecoder = PerformRecoder(3)  # keep the 3 best dev-accuracy checkpoints

    # BUG FIX: the original passed per_process_gpu_memory_fraction=cfg.gpu_mem
    # even in the branch where cfg.gpu_mem is None (an invalid fraction).
    # With no configured fraction we only enable on-demand memory growth.
    if cfg.gpu_mem is None:
        gpu_options = tf.GPUOptions(allow_growth=True)
        graph_config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    else:
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=cfg.gpu_mem)
        graph_config = tf.ConfigProto(gpu_options=gpu_options)
    sess = tf.Session(config=graph_config)
    graphHandler.initialize(sess)

    # ---- training loop ----
    steps_per_epoch = int(math.ceil(1.0 * train_data_obj.sample_num / cfg.train_batch_size))
    num_steps = cfg.num_steps or steps_per_epoch * cfg.max_epoch
    global_step = 0
    for sample_batch, batch_num, data_round, idx_b in train_data_obj.generate_batch_sample_iter(num_steps):
        global_step = sess.run(model.global_step) + 1
        if_get_summary = global_step % (cfg.log_period or steps_per_epoch) == 0
        loss, summary, train_op = model.step(sess, sample_batch, get_summary=if_get_summary)
        if global_step % 100 == 0:
            _logger.add('data round: %d: %d/%d, global step:%d -- loss: %.4f' %
                        (data_round, idx_b, batch_num, global_step, loss))
        if if_get_summary:
            graphHandler.add_summary(summary, global_step)

        # Occasional evaluation on dev and test.
        if global_step % (cfg.eval_period or steps_per_epoch) == 0:
            # ---- dev ----
            dev_loss, dev_accu, dev_sent_accu = evaluator.get_evaluation(
                sess, dev_data_obj, global_step
            )
            _logger.add('==> for dev, loss: %.4f, accuracy: %.4f, sentence accuracy: %.4f' %
                        (dev_loss, dev_accu, dev_sent_accu))
            # ---- test ----
            test_loss, test_accu, test_sent_accu = evaluator.get_evaluation(
                sess, test_data_obj, global_step
            )
            _logger.add('~~> for test, loss: %.4f, accuracy: %.4f, sentence accuracy: %.4f' %
                        (test_loss, test_accu, test_sent_accu))
            # Rank this checkpoint by dev accuracy; only dump test-set
            # predictions after a 30k-step warm-up (original had a
            # "todo-ed: delete me to run normally" note on this gate).
            is_in_top, deleted_step = performRecoder.update_top_list(global_step, dev_accu, sess)
            if is_in_top and global_step > 30000:
                evaluator.get_evaluation_file_output(sess, test_data_obj, global_step, deleted_step)

        this_epoch_time, mean_epoch_time = cfg.time_counter.update_data_round(data_round)
        if this_epoch_time is not None and mean_epoch_time is not None:
            _logger.add('##> this epoch time: %f, mean epoch time: %f' % (this_epoch_time, mean_epoch_time))
    do_analyse_sst(_logger.path)
def test():
    """Restore the model and report loss/accuracy on dev, test and train sets.

    Mirrors train()'s data-loading path, then builds the graph, initializes
    the session from a checkpoint via GraphHandler, and logs one evaluation
    per split (step argument fixed at 1).
    """
    output_model_params()
    loadFile = True
    ifLoad, data = False, None
    if loadFile:
        ifLoad, data = load_file(cfg.processed_path, 'processed data', 'pickle')
    if not ifLoad or not loadFile:
        # Build datasets from the raw files; dev/test reuse the train dictionaries.
        raw_data = RawDataProcessor(cfg.data_dir)
        train_data_list = raw_data.get_data_list('train')
        dev_data_list = raw_data.get_data_list('dev')
        test_data_list = raw_data.get_data_list('test')
        train_data_obj = Dataset(train_data_list, 'train')
        dev_data_obj = Dataset(dev_data_list, 'dev', train_data_obj.dicts)
        test_data_obj = Dataset(test_data_list, 'test', train_data_obj.dicts)
        save_file({'train_data_obj': train_data_obj, 'dev_data_obj': dev_data_obj, 'test_data_obj': test_data_obj},
                  cfg.processed_path)
        train_data_obj.save_dict(cfg.dict_path)
    else:
        train_data_obj = data['train_data_obj']
        dev_data_obj = data['dev_data_obj']
        test_data_obj = data['test_data_obj']

    # NOTE(review): train() calls filter_data with a second fine_grained
    # argument; here it is omitted — confirm filter_data's default matches.
    train_data_obj.filter_data(True)
    dev_data_obj.filter_data(True)
    test_data_obj.filter_data(True)

    emb_mat_token, emb_mat_glove = train_data_obj.emb_mat_token, train_data_obj.emb_mat_glove

    with tf.variable_scope(network_type) as scope:
        if network_type in model_set:
            model = Model(emb_mat_token, emb_mat_glove, len(train_data_obj.dicts['token']),
                          len(train_data_obj.dicts['char']), train_data_obj.max_lens['token'], scope.name)

    graphHandler = GraphHandler(model)
    evaluator = Evaluator(model)

    # BUG FIX: the original passed per_process_gpu_memory_fraction=cfg.gpu_mem
    # in the branch where cfg.gpu_mem is None (an invalid fraction); with no
    # configured fraction we only enable on-demand memory growth.
    if cfg.gpu_mem is None:
        gpu_options = tf.GPUOptions(allow_growth=True)
        graph_config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
    else:
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=cfg.gpu_mem)
        graph_config = tf.ConfigProto(gpu_options=gpu_options)
    sess = tf.Session(config=graph_config)
    graphHandler.initialize(sess)

    # ---- dev ----
    dev_loss, dev_accu, dev_sent_accu = evaluator.get_evaluation(
        sess, dev_data_obj, 1
    )
    _logger.add('==> for dev, loss: %.4f, accuracy: %.4f, sentence accuracy: %.4f' %
                (dev_loss, dev_accu, dev_sent_accu))
    # ---- test ----
    test_loss, test_accu, test_sent_accu = evaluator.get_evaluation(
        sess, test_data_obj, 1
    )
    _logger.add('~~> for test, loss: %.4f, accuracy: %.4f, sentence accuracy: %.4f' %
                (test_loss, test_accu, test_sent_accu))
    # ---- train ----
    train_loss, train_accu, train_sent_accu = evaluator.get_evaluation(
        sess, train_data_obj, 1
    )
    # BUG FIX: this line previously logged '--> for test' while reporting
    # the train-split evaluation results.
    _logger.add('--> for train, loss: %.4f, accuracy: %.4f, sentence accuracy: %.4f' %
                (train_loss, train_accu, train_sent_accu))
def main(_):
    """Entry point invoked by tf.app.run(): dispatch on cfg.mode.

    Supported modes are 'train' and 'test'; any other value raises
    RuntimeError.
    """
    dispatch = {'train': train, 'test': test}
    runner = dispatch.get(cfg.mode)
    if runner is None:
        raise RuntimeError('no running mode named as %s' % cfg.mode)
    runner()
def output_model_params():
    """Log the model title followed by every configured argument.

    The 'test' and 'shuffle' keys of cfg.args are deliberately skipped.
    """
    _logger.add()
    _logger.add('==>model_title: ' + cfg.model_name[1:])
    _logger.add()
    skipped = ('test', 'shuffle')
    for key, value in cfg.args.__dict__.items():
        if key in skipped:
            continue
        _logger.add('%s: %s' % (key, value))
if __name__ == '__main__':
    # tf.app.run() parses command-line flags and then calls main(_).
    tf.app.run()