Skip to content

Commit

Permalink
add the feed_dict argument for initialize_rnn_state
Browse files Browse the repository at this point in the history
  • Loading branch information
Tbabm committed Dec 3, 2017
1 parent c13d8da commit 3e791b5
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tensorlayer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,25 +115,26 @@ def set_name_reuse(enable=True):
"""
set_keep['name_reuse'] = enable

def initialize_rnn_state(state):
def initialize_rnn_state(state, feed_dict=None):
"""Return the initialized RNN state.
The input is LSTMStateTuple or State of RNNCells.
The inputs are LSTMStateTuple or State of RNNCells and an optional feed_dict.
Parameters
-----------
state : a RNN state.
feed_dict : a dictionary.
"""
try: # TF1.0
LSTMStateTuple = tf.contrib.rnn.LSTMStateTuple
except:
LSTMStateTuple = tf.nn.rnn_cell.LSTMStateTuple

if isinstance(state, LSTMStateTuple):
c = state.c.eval()
h = state.h.eval()
c = state.c.eval(feed_dict=feed_dict)
h = state.h.eval(feed_dict=feed_dict)
return (c, h)
else:
new_state = state.eval()
new_state = state.eval(feed_dict=feed_dict)
return new_state

def print_all_variables(train_only=False):
Expand Down

0 comments on commit 3e791b5

Please sign in to comment.