-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #25 from anshuman23/dev
Added RNN (LSTM) example
- Loading branch information
Showing
4 changed files
with
63 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
maxSeqLength = 250 | ||
batchSize = 24 | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
import re | ||
|
||
wordsList = np.load('other_data/wordsList.npy').tolist() | ||
wordsList = [word.decode('UTF-8') for word in wordsList] | ||
wordVectors = np.load('other_data/wordVectors.npy') | ||
strip_special_chars = re.compile("[^A-Za-z0-9 ]+") | ||
|
||
def cleanSentences(string): | ||
string = string.lower().replace("<br />", " ") | ||
return re.sub(strip_special_chars, "", string.lower()) | ||
|
||
def getSentenceMatrix(sentence): | ||
arr = np.zeros([batchSize, maxSeqLength]) | ||
sentenceMatrix = np.zeros([batchSize,maxSeqLength], dtype='int32') | ||
cleanedSentence = cleanSentences(sentence) | ||
split = cleanedSentence.split() | ||
for indexCounter,word in enumerate(split): | ||
try: | ||
sentenceMatrix[0,indexCounter] = wordsList.index(word) | ||
except ValueError: | ||
sentenceMatrix[0,indexCounter] = 399999 | ||
return sentenceMatrix | ||
|
||
inputText = "That movie was terrible." | ||
inputMatrix = getSentenceMatrix(inputText) | ||
print inputMatrix | ||
print inputMatrix.shape | ||
np.savetxt("inputMatrixNegative.csv", inputMatrix, delimiter=',', fmt="%i") | ||
|
||
secondInputText = "That movie was the best one I have ever seen." | ||
secondInputMatrix = getSentenceMatrix(secondInputText) | ||
print secondInputMatrix | ||
print secondInputMatrix.shape | ||
np.savetxt("inputMatrixPositive.csv", secondInputMatrix, delimiter=',', fmt="%i") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import tensorflow as tf | ||
import os | ||
|
||
model_dir = './model/' | ||
output_node_names = 'add' | ||
|
||
checkpoint = tf.train.get_checkpoint_state(model_dir) | ||
input_checkpoint = checkpoint.model_checkpoint_path | ||
|
||
absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1]) | ||
output_graph = absolute_model_dir + "/frozen_model_lstm.pb" | ||
|
||
clear_devices = True | ||
|
||
with tf.Session(graph=tf.Graph()) as sess: | ||
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices) | ||
saver.restore(sess, input_checkpoint) | ||
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), output_node_names.split(",")) | ||
|
||
with tf.gfile.GFile(output_graph, "wb") as f: | ||
f.write(output_graph_def.SerializeToString()) | ||
|
||
print("%d ops in the final graph." % len(output_graph_def.node)) | ||
|
Empty file.
Empty file.