-
Notifications
You must be signed in to change notification settings - Fork 48
/
eval.py
83 lines (61 loc) · 3.08 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import tensorflow as tf
import numpy as np
import os
import data_helpers
# Parameters
# ==================================================

# Input data: one file of positive examples and one of negative examples.
tf.flags.DEFINE_string("pos_dir", "data/rt-polaritydata/rt-polarity.pos", "Path of positive data")
tf.flags.DEFINE_string("neg_dir", "data/rt-polaritydata/rt-polarity.neg", "Path of negative data")

# Evaluation settings: batch size and the directory holding the trained checkpoints.
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (Default: 64)")
tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run")

# Session options forwarded to tf.ConfigProto when the evaluation session is built.
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

FLAGS = tf.flags.FLAGS
# NOTE(review): `_parse_flags` and `__flags` are private members of the flags
# object -- confirm they exist in the TensorFlow version this script targets.
FLAGS._parse_flags()

# Echo every flag as "NAME = value", sorted by flag name, framed by blank lines.
print("\nParameters:")
print("\n".join(
    "{} = {}".format(flag_name.upper(), flag_setting)
    for flag_name, flag_setting in sorted(FLAGS.__flags.items())
))
print("")
def eval():
    """Evaluate the latest checkpoint on the configured data and print accuracy.

    Steps:
      1. Load texts and labels from ``FLAGS.pos_dir`` and ``FLAGS.neg_dir``.
      2. Map the texts to id sequences with the vocabulary processor restored
         from ``<FLAGS.checkpoint_dir>/../text_vocab``.
      3. Restore the most recent checkpoint found in ``FLAGS.checkpoint_dir``.
      4. Run the ``output/predictions`` tensor over the data in batches of
         ``FLAGS.batch_size`` with ``dropout_keep_prob`` fed as 1.0.
      5. Print the number of examples and the fraction predicted correctly.

    NOTE: this function shadows the builtin ``eval``; the name is kept so
    existing callers keep working.

    Raises:
        IOError: If no checkpoint can be found in ``FLAGS.checkpoint_dir``.
    """
    with tf.device('/cpu:0'):
        x_text, y = data_helpers.load_data_and_labels(FLAGS.pos_dir, FLAGS.neg_dir)

    # Map data into vocabulary
    text_path = os.path.join(FLAGS.checkpoint_dir, "..", "text_vocab")
    text_vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor.restore(text_path)

    x_eval = np.array(list(text_vocab_processor.transform(x_text)))
    # Class index per example (assumes `y` holds one score/one-hot row per
    # example -- TODO confirm against data_helpers.load_data_and_labels).
    y_eval = np.argmax(y, axis=1)

    checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    if checkpoint_file is None:
        # Fail early with a clear message instead of trying to open "None.meta".
        raise IOError(
            "No checkpoint found in checkpoint_dir: {!r}".format(FLAGS.checkpoint_dir))

    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement)
        # Entering the session as a context manager makes it the default
        # session and guarantees it is closed when the block exits.
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders from the graph by name
            input_text = graph.get_operation_by_name("input_text").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]

            # Tensors we want to evaluate
            predictions = graph.get_operation_by_name("output/predictions").outputs[0]

            # Generate batches for one epoch
            batches = data_helpers.batch_iter(list(x_eval), FLAGS.batch_size, 1, shuffle=False)

            # Collect per-batch outputs and join them once after the loop,
            # rather than re-copying the growing array on every iteration.
            batch_outputs = [
                sess.run(predictions, {input_text: x_batch, dropout_keep_prob: 1.0})
                for x_batch in batches
            ]

    all_predictions = np.concatenate(batch_outputs) if batch_outputs else np.array([])

    correct_predictions = float(np.sum(all_predictions == y_eval))
    print("Total number of test examples: {}".format(len(y_eval)))
    print("Accuracy: {:g}".format(correct_predictions / float(len(y_eval))))
def main(_):
    """Script entry point handed to ``tf.app.run``; the argument is unused."""
    eval()


if __name__ == "__main__":
    # Name the entry point explicitly rather than relying on the default
    # lookup of `main` in the `__main__` module.
    tf.app.run(main=main)