28
28
tf .app .flags .DEFINE_integer ('rnn_size' , 15 , 'RNN feature size.' )
29
29
tf .app .flags .DEFINE_integer ('embedding_size' , 25 , 'Word embedding size.' )
30
30
tf .app .flags .DEFINE_integer ('min_word_frequency' , 20 , 'Word frequency cutoff.' )
31
+ tf .app .flags .DEFINE_boolean ('run_unit_tests' , False , 'If true, run tests.' )
32
+
31
33
FLAGS = tf .app .flags .FLAGS
32
34
33
35
# Define how to get data
@@ -75,6 +77,16 @@ def clean_text(text_string):
75
77
return (text_string )
76
78
77
79
80
+ # Test clean_text function
81
+ class clean_test (tf .test .TestCase ):
82
+ # Make sure cleaning function behaves correctly
83
+ def clean_string_test (self ):
84
+ with self .test_session ():
85
+ test_input = '--TensorFlow\' s so Great! Don\t you think so? '
86
+ test_expected = 'tensorflows so great don you think so'
87
+ test_out = clean_text (test_input )
88
+ self .assertEqual (test_expected , test_out )
89
+
78
90
# Define RNN Model
79
91
def rnn_model (x_data_ph , max_sequence_length , vocab_size , embedding_size ,
80
92
rnn_size , dropout_keep_prob ):
@@ -83,7 +95,7 @@ def rnn_model(x_data_ph, max_sequence_length, vocab_size, embedding_size,
83
95
embedding_output = tf .nn .embedding_lookup (embedding_mat , x_data_ph )
84
96
85
97
# Define the RNN cell
86
- cell = tf .nn . rnn_cell .BasicRNNCell (num_units = rnn_size )
98
+ cell = tf .contrib . rnn .BasicRNNCell (num_units = rnn_size )
87
99
output , state = tf .nn .dynamic_rnn (cell , embedding_output , dtype = tf .float32 )
88
100
output = tf .nn .dropout (output , dropout_keep_prob )
89
101
@@ -121,7 +133,6 @@ def main(args):
121
133
# Set model parameters
122
134
storage_folder = FLAGS .storage_folder
123
135
learning_rate = FLAGS .learning_rate
124
- epochs = FLAGS .epochs
125
136
run_unit_tests = FLAGS .run_unit_tests
126
137
epochs = FLAGS .epochs
127
138
batch_size = FLAGS .batch_size
@@ -226,15 +237,22 @@ def main(args):
226
237
227
238
# Run loss and accuracy for training
228
239
temp_train_loss , temp_train_acc = sess .run ([loss , accuracy ], feed_dict = train_dict )
240
+
229
241
test_dict = {x_data_ph : x_test , y_output_ph : y_test , dropout_keep_prob :1.0 }
230
242
temp_test_loss , temp_test_acc = sess .run ([loss , accuracy ], feed_dict = test_dict )
231
243
232
244
# Print Epoch Summary
245
+ print ('Epoch: {}, Train Loss:{:.2}, Train Acc: {:.2}' .format (epoch + 1 , temp_train_loss , temp_train_acc ))
233
246
print ('Epoch: {}, Test Loss: {:.2}, Test Acc: {:.2}' .format (epoch + 1 , temp_test_loss , temp_test_acc ))
234
247
235
248
# Save model every epoch
236
249
saver .save (sess , os .path .join (storage_folder , "model.ckpt" ))
237
250
238
251
# Run main module/tf App
239
252
if __name__ == "__main__" :
240
- tf .app .run ()
253
+ if FLAGS .run_unit_tests :
254
+ # Perform unit tests
255
+ tf .test .main ()
256
+ else :
257
+ # Run evaluation
258
+ tf .app .run ()
0 commit comments