forked from carpedm20/pixel-rnn-tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 1
/
statistic.py
75 lines (57 loc) · 2.43 KB
/
statistic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import numpy as np
import tensorflow as tf
from logging import getLogger
# Module-level logger named after this module; Statistic uses it to report
# checkpoint saving and loading.
logger = getLogger(__name__)
class Statistic(object):
    """Track a step counter, scalar summaries and checkpoints for a model.

    The step counter is a non-trainable graph variable that is handed to the
    saver together with the model variables, so it is stored in and restored
    from the same checkpoints.
    """

    def __init__(self, sess, data, model_dir, variables, test_step, max_to_keep=20):
        """Build the step counter, the saver, the summary writer and summary ops.

        Args:
            sess: TensorFlow session used to evaluate every op owned by this
                object.
            data: Prefix of each summary name (``'<data>/<tag>'``).
            model_dir: Directory that holds the checkpoints; also the
                sub-directory of ``./logs`` that receives the summary events.
            variables: List of variables to checkpoint; the step counter is
                appended to it.
            test_step: Stored as ``self.test_step``; not otherwise used by
                this class.
            max_to_keep: Forwarded to ``tf.train.Saver``.
        """
        self.sess = sess
        self.test_step = test_step
        self.reset()

        with tf.variable_scope('t'):
            self.t_op = tf.Variable(0, trainable=False, name='t')
            self.t_add_op = self.t_op.assign_add(1)

        self.model_dir = model_dir
        self.saver = tf.train.Saver(variables + [self.t_op], max_to_keep=max_to_keep)
        self.writer = tf.summary.FileWriter('./logs/%s' % self.model_dir, self.sess.graph)

        with tf.variable_scope('summary'):
            scalar_summary_tags = ['train_l', 'test_l']

            # One placeholder and one scalar summary op per tag, so plain
            # Python numbers can be written as summaries via inject_summary().
            self.summary_placeholders = {}
            self.summary_ops = {}

            for tag in scalar_summary_tags:
                self.summary_placeholders[tag] = tf.placeholder(
                    'float32', None, name=tag.replace(' ', '_'))
                self.summary_ops[tag] = tf.summary.scalar(
                    '%s/%s' % (data, tag), self.summary_placeholders[tag])

    def reset(self):
        """Do nothing; called from ``__init__`` and at the end of ``on_step``."""
        pass

    def on_step(self, train_l, test_l):
        """Advance the step counter, then log both losses and checkpoint.

        Args:
            train_l: Value written under the ``train_l`` summary tag.
            test_l: Value written under the ``test_l`` summary tag.
        """
        # Evaluating the assign_add op increments the counter and returns the
        # new value in one session call.
        self.t = self.t_add_op.eval(session=self.sess)
        self.inject_summary({'train_l': train_l, 'test_l': test_l}, self.t)
        self.save_model(self.t)
        self.reset()

    def get_t(self):
        """Return the current value of the step counter without changing it."""
        return self.t_op.eval(session=self.sess)

    def inject_summary(self, tag_dict, t):
        """Write one scalar summary per entry of ``tag_dict`` at step ``t``.

        Args:
            tag_dict: Mapping from summary tag to the value to record. Every
                key must be one of the tags registered in ``__init__``.
            t: Step passed to the summary writer for each summary.
        """
        tags = list(tag_dict)
        summary_strs = self.sess.run(
            [self.summary_ops[tag] for tag in tags],
            {self.summary_placeholders[tag]: tag_dict[tag] for tag in tags},
        )
        for summary_str in summary_strs:
            self.writer.add_summary(summary_str, t)

    def save_model(self, t):
        """Write a checkpoint for step ``t`` inside ``model_dir``.

        Args:
            t: Step number forwarded to the saver as ``global_step``.
        """
        logger.info("Saving checkpoints...")

        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

        # The saver treats its path argument as a filename prefix and writes
        # the checkpoint state file into that prefix's parent directory.
        # Passing the bare directory only works when model_dir ends with a
        # path separator; joining a file name keeps everything inside
        # model_dir, which is where load_model() looks.
        save_prefix = os.path.join(self.model_dir, type(self).__name__)
        self.saver.save(self.sess, save_prefix, global_step=t)

    def load_model(self):
        """Initialize all variables, then restore the latest checkpoint if any.

        A missing checkpoint is only logged; the freshly initialized values
        are kept in that case.
        """
        logger.info("Initializing all variables")
        # Run in the session this object owns instead of relying on a default
        # session being registered, matching every other op evaluation here.
        self.sess.run(tf.global_variables_initializer())

        logger.info("Loading checkpoints...")
        ckpt = tf.train.get_checkpoint_state(self.model_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Rebuild the path from model_dir and the recorded file name
            # rather than using the recorded path as-is.
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            fname = os.path.join(self.model_dir, ckpt_name)
            self.saver.restore(self.sess, fname)
            logger.info("Load SUCCESS: %s", fname)
        else:
            logger.info("Load FAILED: %s", self.model_dir)

        # NOTE(review): this evaluates the increment op, so every load bumps
        # the step counter by one. Kept as-is because callers are not visible
        # here; confirm whether a plain read (see get_t) was intended.
        self.t = self.t_add_op.eval(session=self.sess)