-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrunTraining.py
139 lines (110 loc) · 4.89 KB
/
runTraining.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#
# Copyright 2017 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import os
# Suppress TensorFlow's C++ backend INFO/WARNING messages; this must be
# set BEFORE tensorflow is imported to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import neuralnetwork as nn  # project-local model: inputs/inference/loss/training
# Verbose Python-side TF logging (graph construction, queue runners, etc.).
tf.logging.set_verbosity(tf.logging.DEBUG)
# TFRecord file names, resolved relative to FLAGS.data_dir.
TRAIN_FILE = 'train_images.tfrecords'
VALIDATION_FILE = 'val_images.tfrecords'
# Command-line flags: training hyperparameters and data/checkpoint paths.
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_float('decay_rate', 1.0, 'Learning rate decay.')
flags.DEFINE_integer('decay_steps', 1000, 'Steps at each learning rate.')
flags.DEFINE_integer('num_epochs', 1, 'Number of epochs to run trainer.')
flags.DEFINE_integer('batch_size', 1, 'Batch size.')
flags.DEFINE_string('data_dir', '/tmp/sunny_data',
                    'Directory with the training data.')
flags.DEFINE_string('checkpoint_dir', '/tmp/sunny_train',
                    """Directory where to write model checkpoints.""")
def run_training():
    """Build the training graph and run the TF1 training loop.

    Reads batched examples from the training TFRecord file, optimizes the
    model defined in the ``neuralnetwork`` module, periodically prints the
    loss and writes TensorBoard summaries, and saves checkpoints under
    ``FLAGS.checkpoint_dir``.  The loop ends when the input pipeline raises
    ``OutOfRangeError``, i.e. after ``FLAGS.num_epochs`` passes over the data.
    """
    # Construct the graph in its own (default) Graph object.
    with tf.Graph().as_default():
        # Location of the training data file.
        trainfile = os.path.join(FLAGS.data_dir, TRAIN_FILE)
        # Input pipeline: batched images and labels read from the TFRecord.
        images, labels = nn.inputs(batch_size=FLAGS.batch_size,
                                   num_epochs=FLAGS.num_epochs,
                                   filename=trainfile)
        # Forward pass, loss, and training op (with learning-rate decay).
        results = nn.inference(images)
        loss = nn.loss(results, labels)
        train_op = nn.training(loss, FLAGS.learning_rate, FLAGS.decay_steps,
                               FLAGS.decay_rate)
        # Merged summary op for TensorBoard.
        summary_op = tf.summary.merge_all()
        # Initializer for global and local (e.g. epoch-counter) variables.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        # Saver for checkpoints; build the checkpoint path once, up front,
        # instead of re-joining it at every save site.
        saver = tf.train.Saver()
        checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'model.ckpt')

        sess = tf.Session()
        # Event-log writer for TensorBoard; logs go next to the checkpoints.
        summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_dir,
                                               sess.graph)
        # Initialize all variables before training.
        sess.run(init_op)

        # Set up the coordinator and reader threads.  Not strictly required
        # for a small dataset, but multi-threaded input reading is the
        # standard TF1 idiom and typically improves performance.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Loop until we run out of input training cases.
        try:
            step = 0
            while not coord.should_stop():
                # Time one training iteration.
                start_time = time.time()
                _, loss_value = sess.run([train_op, loss])
                duration = time.time() - start_time

                # Periodically report progress and write TensorBoard data.
                if step % 100 == 0:
                    print('OUTPUT: Step %d: loss = %.3f (%.3f sec)' % (step,
                                                                       loss_value,
                                                                       duration))
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                # Less frequently, write checkpoint files for model evaluation.
                if step % 1000 == 0:
                    saver.save(sess, checkpoint_path, global_step=step)
                step += 1
        except tf.errors.OutOfRangeError:
            # Input queue exhausted: we have read num_epochs worth of data.
            print('OUTPUT: Done training for %d epochs, %d steps.' % (FLAGS.num_epochs,
                                                                      step))
            # Save a final checkpoint at the last step.
            saver.save(sess, checkpoint_path, global_step=step)
        finally:
            # Shut down the reader threads gracefully and release the session.
            coord.request_stop()
            coord.join(threads)
            sess.close()
def main(_):
    """Entry point invoked by tf.app.run(); the ignored arg is argv."""
    run_training()
if __name__ == '__main__':
    # Parse command-line flags, then dispatch to main() via the TF1 app wrapper.
    tf.app.run()