-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrunAE.py
executable file
·66 lines (47 loc) · 1.95 KB
/
runAE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
runner function for autoencoder
Timo Flesch, 2017
"""
# external
import numpy as np
import tensorflow as tf
# custom
from ae.model import myModel
from nntools.trainer import trainModel
from nntools.evaluation import evalModel
from nntools.io import loadMyModel
from datetime import datetime
FLAGS = tf.app.flags.FLAGS
def runAE(x_train, x_test):
    """Train or evaluate the autoencoder, controlled by ``FLAGS.do_training``.

    Parameters
    ----------
    x_train : array-like
        Training data fed to ``trainModel`` and to the post-training evaluation.
        # assumes shape matches what myModel's input placeholder expects — TODO confirm
    x_test : array-like
        Held-out data passed to ``trainModel`` for monitoring.
        # NOTE(review): unused in the evaluation-only branch.

    Returns
    -------
    results
        Whatever ``trainModel`` (training branch) or ``evalModel`` (eval branch)
        returns; the type is defined in the nntools package, not here.
    """
    # Per-run checkpoint and log directories, namespaced by model name.
    # NOTE(review): FLAGS.ckpt_dir / FLAGS.log_dir are presumably expected to
    # end with a path separator, since 'model_' is appended by plain
    # string concatenation — verify against the flag definitions.
    ckpt_dir_run = FLAGS.ckpt_dir + 'model_' + FLAGS.model
    log_dir_run = FLAGS.log_dir + 'model_' + FLAGS.model
    if not tf.gfile.Exists(ckpt_dir_run):
        tf.gfile.MakeDirs(ckpt_dir_run)
    if not tf.gfile.Exists(log_dir_run):
        tf.gfile.MakeDirs(log_dir_run)

    with tf.Session() as sess:
        if FLAGS.do_training:
            # Fresh model, trained from scratch with the flag-supplied hyperparameters.
            nnet = myModel(lr=FLAGS.learning_rate,
                           optimizer=FLAGS.optimizer,
                           nonlinearity=FLAGS.nonlinearity,
                           )
            print("{} Now training Autoencoder, LR: {} , EPs: {}, BS: {}"
                  .format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), FLAGS.learning_rate, FLAGS.n_training_episodes, FLAGS.batch_size))
            # initialize all variables (and the summary writer for log_dir_run)
            nnet.init_graph_vars(sess, log_dir=log_dir_run)
            # train model; checkpoints go to ckpt_dir_run
            results = trainModel(sess, nnet, x_train, x_test,
                                 n_episodes=FLAGS.n_training_episodes,
                                 n_batches=FLAGS.n_training_batches,
                                 batch_size=FLAGS.batch_size,
                                 model_dir=ckpt_dir_run)
            # Post-training sanity evaluation on the training set;
            # its return value is intentionally discarded.
            evalModel(sess, nnet, x_train)
        else:
            # Evaluation-only path: restore a previously trained model.
            nnet = myModel(is_trained=True)
            print("Now evaluating Autoencoder")
            ops = loadMyModel(sess, ['nnet'], ckpt_dir_run)
            print(ops)
            # Wire the restored output op into the model wrapper so
            # evalModel can run it.
            nnet.y_hat = ops[0]
            results = evalModel(sess, nnet)
    return results