Commit 1f1a2ed

1. Fix the BasicRNNCell import error in the production training script.
2. Add a unit test for the production training script.
3. Rename 06_siamese_similarity_model to siamese_similarity_model so that it can be imported by 06_siamese_similarity_driver.py (see the note below).
1 parent 4e9015d commit 1f1a2ed
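A note on item 3: a Python module whose filename starts with a digit cannot be loaded with a plain import statement, because module names must be valid identifiers, so renaming the file to siamese_similarity_model.py is what lets the driver import it. Below is a minimal sketch of driver-side usage after the rename, assuming both files sit in the same directory; the alias and the commented call are illustrative, since the driver's actual import line is not part of this diff:

    # Hypothetical driver-side usage after the rename (not shown in this commit's hunks).
    # "06_siamese_similarity_model" cannot be imported because identifiers may not
    # start with a digit; "siamese_similarity_model" can.
    import siamese_similarity_model as model

    # The driver can then reuse the model's helpers, e.g. the loss function whose
    # signature appears in the hunks below:
    # batch_loss = model.loss(scores, y_target, margin)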

File tree

3 files changed: +24 -7 lines changed

09_Recurrent_Neural_Networks/06_Training_A_Siamese_Similarity_Measure/06_siamese_similarity_driver.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,6 @@
 # Here, we show how to perform address matching
 # with a Siamese RNN model
 
-import os
 import random
 import string
 import numpy as np
@@ -178,3 +177,4 @@ def address2onehot(address,
 plt.title('Accuracy and Loss of Siamese RNN')
 plt.grid()
 plt.legend(loc='lower right')
+plt.show()
09_Recurrent_Neural_Networks/06_Training_A_Siamese_Similarity_Measure/{06_siamese_similarity_model.py → siamese_similarity_model.py}

Lines changed: 2 additions & 3 deletions
@@ -5,7 +5,6 @@
 # Here, we show how to perform address matching
 # with a Siamese RNN model
 
-import numpy as np
 import tensorflow as tf
 
 
@@ -81,7 +80,7 @@ def loss(scores, y_target, margin):
 
     # If y-target is -1 to 1, then do the following
     #pos_mult = tf.add(tf.multiply(0.5, y_target), 0.5)
-    # Else if y-target is 0 to 1, then do the folloing
+    # Else if y-target is 0 to 1, then do the following
     pos_mult = tf.cast(y_target, tf.float32)
 
     # Make sure positive losses are on similar strings
@@ -99,7 +98,7 @@ def loss(scores, y_target, margin):
     # Combine similar and dissimilar losses
    loss = tf.add(positive_loss, negative_loss)
 
-    # Create the margin term. This is when the targets are 0.,
+    # Create the margin term. This is when the targets are 0.,
     # and the scores are less than m, return 0.
 
     # Check if target is zero (dissimilar strings)
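The comments in the hunks above describe the intent of the loss: similar pairs (target 1) are pushed toward a high similarity score, and dissimilar pairs (target 0) are only penalized while their score is still above the margin m. Below is a generic, self-contained TF 1.x sketch of a margin-based pair loss in that spirit; it is an illustration of the idea under those assumptions, not the repository's exact formula, and the name pair_loss is made up for this note:

    import tensorflow as tf

    def pair_loss(scores, y_target, margin):
        # Illustrative only: a margin-based similarity loss consistent with the
        # comments above, not the exact implementation in this file.
        y = tf.cast(y_target, tf.float32)             # 1 = similar, 0 = dissimilar
        positive_loss = y * tf.square(1. - scores)    # pull similar pairs toward a score of 1
        # Dissimilar pairs contribute zero loss once their score falls below the margin
        negative_loss = (1. - y) * tf.square(tf.maximum(scores - margin, 0.))
        return tf.reduce_mean(positive_loss + negative_loss)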

10_Taking_TensorFlow_to_Production/05_Production_Example/05_production_ex_train.py

Lines changed: 21 additions & 3 deletions
@@ -28,6 +28,8 @@
 tf.app.flags.DEFINE_integer('rnn_size', 15, 'RNN feature size.')
 tf.app.flags.DEFINE_integer('embedding_size', 25, 'Word embedding size.')
 tf.app.flags.DEFINE_integer('min_word_frequency', 20, 'Word frequency cutoff.')
+tf.app.flags.DEFINE_boolean('run_unit_tests', False, 'If true, run tests.')
+
 FLAGS = tf.app.flags.FLAGS
 
 # Define how to get data
@@ -75,6 +77,16 @@ def clean_text(text_string):
     return(text_string)
 
 
+# Test clean_text function
+class clean_test(tf.test.TestCase):
+    # Make sure cleaning function behaves correctly
+    def clean_string_test(self):
+        with self.test_session():
+            test_input = '--TensorFlow\'s so Great! Don\t you think so? '
+            test_expected = 'tensorflows so great don you think so'
+            test_out = clean_text(test_input)
+            self.assertEqual(test_expected, test_out)
+
 # Define RNN Model
 def rnn_model(x_data_ph, max_sequence_length, vocab_size, embedding_size,
               rnn_size, dropout_keep_prob):
@@ -83,7 +95,7 @@ def rnn_model(x_data_ph, max_sequence_length, vocab_size, embedding_size,
     embedding_output = tf.nn.embedding_lookup(embedding_mat, x_data_ph)
 
     # Define the RNN cell
-    cell = tf.nn.rnn_cell.BasicRNNCell(num_units = rnn_size)
+    cell = tf.contrib.rnn.BasicRNNCell(num_units = rnn_size)
     output, state = tf.nn.dynamic_rnn(cell, embedding_output, dtype=tf.float32)
     output = tf.nn.dropout(output, dropout_keep_prob)
 
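The cell-construction change above is the "BasicRNN import error" from the commit message: on the TensorFlow 1.x release this script targets, BasicRNNCell is exposed as tf.contrib.rnn.BasicRNNCell and the old tf.nn.rnn_cell path is not available. If the script had to run across several 1.x releases, a small compatibility shim is one option; this is a hedged sketch, not something this commit adds:

    import tensorflow as tf

    # Sketch: pick up BasicRNNCell from wherever the installed TF 1.x release exposes it.
    try:
        BasicRNNCell = tf.contrib.rnn.BasicRNNCell    # location used by this commit
    except AttributeError:
        BasicRNNCell = tf.nn.rnn_cell.BasicRNNCell    # releases that keep the older path

    cell = BasicRNNCell(num_units=15)                 # 15 matches the rnn_size flag above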
@@ -121,7 +133,6 @@ def main(args):
     # Set model parameters
     storage_folder = FLAGS.storage_folder
     learning_rate = FLAGS.learning_rate
-    epochs = FLAGS.epochs
     run_unit_tests = FLAGS.run_unit_tests
     epochs = FLAGS.epochs
     batch_size = FLAGS.batch_size
@@ -226,15 +237,22 @@ def main(args):
 
             # Run loss and accuracy for training
             temp_train_loss, temp_train_acc = sess.run([loss, accuracy], feed_dict=train_dict)
+
             test_dict = {x_data_ph: x_test, y_output_ph: y_test, dropout_keep_prob:1.0}
             temp_test_loss, temp_test_acc = sess.run([loss, accuracy], feed_dict=test_dict)
 
             # Print Epoch Summary
+            print('Epoch: {}, Train Loss:{:.2}, Train Acc: {:.2}'.format(epoch+1, temp_train_loss, temp_train_acc))
             print('Epoch: {}, Test Loss: {:.2}, Test Acc: {:.2}'.format(epoch+1, temp_test_loss, temp_test_acc))
 
             # Save model every epoch
             saver.save(sess, os.path.join(storage_folder, "model.ckpt"))
 
 # Run main module/tf App
 if __name__ == "__main__":
-    tf.app.run()
+    if FLAGS.run_unit_tests:
+        # Perform unit tests
+        tf.test.main()
+    else:
+        # Run evaluation
+        tf.app.run()
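With the run_unit_tests flag and the dispatch above, the same script either runs its tf.test suite or falls through to normal training. Invocation could look like the following (the flag name comes from this diff; everything else about the command line is assumed). One caveat worth knowing: unittest's default discovery, which tf.test.main() relies on, only auto-runs methods whose names start with test, so a method named clean_string_test would need that prefix to be picked up automatically.

    # Run the unit tests instead of training (the flag is added in this commit
    # and defaults to False):
    python 05_production_ex_train.py --run_unit_tests=True

    # Run normal training (the default path):
    python 05_production_ex_train.py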
