Skip to content

Commit

Permalink
Merge pull request #25 from anshuman23/dev
Browse files Browse the repository at this point in the history
Added RNN (LSTM) example
  • Loading branch information
anshuman23 authored Jul 22, 2018
2 parents 6cbdec0 + 23c077a commit 9dd757b
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 0 deletions.
39 changes: 39 additions & 0 deletions examples/rnn-lstm-example/create_input_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
maxSeqLength = 250
batchSize = 24

import numpy as np
import tensorflow as tf
import re

wordsList = np.load('other_data/wordsList.npy').tolist()
wordsList = [word.decode('UTF-8') for word in wordsList]
wordVectors = np.load('other_data/wordVectors.npy')
strip_special_chars = re.compile("[^A-Za-z0-9 ]+")

def cleanSentences(string):
string = string.lower().replace("<br />", " ")
return re.sub(strip_special_chars, "", string.lower())

def getSentenceMatrix(sentence):
arr = np.zeros([batchSize, maxSeqLength])
sentenceMatrix = np.zeros([batchSize,maxSeqLength], dtype='int32')
cleanedSentence = cleanSentences(sentence)
split = cleanedSentence.split()
for indexCounter,word in enumerate(split):
try:
sentenceMatrix[0,indexCounter] = wordsList.index(word)
except ValueError:
sentenceMatrix[0,indexCounter] = 399999
return sentenceMatrix

inputText = "That movie was terrible."
inputMatrix = getSentenceMatrix(inputText)
print inputMatrix
print inputMatrix.shape
np.savetxt("inputMatrixNegative.csv", inputMatrix, delimiter=',', fmt="%i")

secondInputText = "That movie was the best one I have ever seen."
secondInputMatrix = getSentenceMatrix(secondInputText)
print secondInputMatrix
print secondInputMatrix.shape
np.savetxt("inputMatrixPositive.csv", secondInputMatrix, delimiter=',', fmt="%i")
24 changes: 24 additions & 0 deletions examples/rnn-lstm-example/freeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import tensorflow as tf
import os

model_dir = './model/'
output_node_names = 'add'

checkpoint = tf.train.get_checkpoint_state(model_dir)
input_checkpoint = checkpoint.model_checkpoint_path

absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])
output_graph = absolute_model_dir + "/frozen_model_lstm.pb"

clear_devices = True

with tf.Session(graph=tf.Graph()) as sess:
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)
saver.restore(sess, input_checkpoint)
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), output_node_names.split(","))

with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())

print("%d ops in the final graph." % len(output_graph_def.node))

Empty file.
Empty file.

0 comments on commit 9dd757b

Please sign in to comment.