Skip to content

Commit 164541a

Browse files
author
glample
committed
fix lstm tanh
1 parent 0a597fc commit 164541a

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

Diff for: network.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -133,9 +133,9 @@ def link(self, input):
133133
def recurrence(x_t, c_tm1, h_tm1):
134134
i_t = T.nnet.sigmoid(T.dot(x_t, self.w_xi) + T.dot(h_tm1, self.w_hi) + T.dot(c_tm1, self.w_ci) + self.b_i)
135135
f_t = T.nnet.sigmoid(T.dot(x_t, self.w_xf) + T.dot(h_tm1, self.w_hf) + T.dot(c_tm1, self.w_cf) + self.b_f)
136-
c_t = f_t * c_tm1 + i_t * T.nnet.sigmoid(T.dot(x_t, self.w_xc) + T.dot(h_tm1, self.w_hc) + self.b_c)
136+
c_t = f_t * c_tm1 + i_t * T.tanh(T.dot(x_t, self.w_xc) + T.dot(h_tm1, self.w_hc) + self.b_c)
137137
o_t = T.nnet.sigmoid(T.dot(x_t, self.w_xo) + T.dot(h_tm1, self.w_ho) + T.dot(c_t, self.w_co) + self.b_o)
138-
h_t = o_t * T.nnet.sigmoid(c_t)
138+
h_t = o_t * T.tanh(c_t)
139139
return [c_t, h_t]
140140

141141
# If we used batches, we have to permute the first and second dimension.

Diff for: unit_tests.py

+13-13
Original file line number | Diff line number | Diff line change
@@ -279,9 +279,9 @@ def test_lstm():
279279
x_t = input_value[i]
280280
i_t = expit(np.dot(x_t, lstm.w_xi.get_value()) + np.dot(h_t, lstm.w_hi.get_value()) + np.dot(c_t, lstm.w_ci.get_value()) + lstm.b_i.get_value())
281281
f_t = expit(np.dot(x_t, lstm.w_xf.get_value()) + np.dot(h_t, lstm.w_hf.get_value()) + np.dot(c_t, lstm.w_cf.get_value()) + lstm.b_f.get_value())
282-
c_t = f_t * c_t + i_t * expit(np.dot(x_t, lstm.w_xc.get_value()) + np.dot(h_t, lstm.w_hc.get_value()) + lstm.b_c.get_value())
282+
c_t = f_t * c_t + i_t * np.tanh(np.dot(x_t, lstm.w_xc.get_value()) + np.dot(h_t, lstm.w_hc.get_value()) + lstm.b_c.get_value())
283283
o_t = expit(np.dot(x_t, lstm.w_xo.get_value()) + np.dot(h_t, lstm.w_ho.get_value()) + np.dot(c_t, lstm.w_co.get_value()) + lstm.b_o.get_value())
284-
h_t = o_t * expit(c_t)
284+
h_t = o_t * np.tanh(c_t)
285285

286286
assert h_t.shape == (hidden_dim,)
287287
np.testing.assert_array_almost_equal(
@@ -324,9 +324,9 @@ def test_lstm():
324324
x_t = input_value_dimshuffled[i]
325325
i_t = expit(np.dot(x_t, lstm.w_xi.get_value()) + np.dot(h_t, lstm.w_hi.get_value()) + np.dot(c_t, lstm.w_ci.get_value()) + lstm.b_i.get_value())
326326
f_t = expit(np.dot(x_t, lstm.w_xf.get_value()) + np.dot(h_t, lstm.w_hf.get_value()) + np.dot(c_t, lstm.w_cf.get_value()) + lstm.b_f.get_value())
327-
c_t = f_t * c_t + i_t * expit(np.dot(x_t, lstm.w_xc.get_value()) + np.dot(h_t, lstm.w_hc.get_value()) + lstm.b_c.get_value())
327+
c_t = f_t * c_t + i_t * np.tanh(np.dot(x_t, lstm.w_xc.get_value()) + np.dot(h_t, lstm.w_hc.get_value()) + lstm.b_c.get_value())
328328
o_t = expit(np.dot(x_t, lstm.w_xo.get_value()) + np.dot(h_t, lstm.w_ho.get_value()) + np.dot(c_t, lstm.w_co.get_value()) + lstm.b_o.get_value())
329-
h_t = o_t * expit(c_t)
329+
h_t = o_t * np.tanh(c_t)
330330

331331
assert h_t.shape == (input_value.shape[0], hidden_dim)
332332
np.testing.assert_array_almost_equal(output.eval({input: input_value}), h_t, decimal=3)
@@ -727,16 +727,16 @@ def test_conv2d_layer_kmax_pooling():
727727
"""
728728
"""
729729

730-
test_hidden_layer()
731-
test_embedding_layer()
732-
test_rnn()
730+
# test_hidden_layer()
731+
# test_embedding_layer()
732+
# test_rnn()
733733
test_lstm()
734-
test_kmax_pooling_layer_1()
735-
test_kmax_pooling_layer_2()
736-
test_conv1d_layer()
737-
test_conv2d_layer()
738-
test_conv1d_layer_kmax_pooling()
739-
test_conv2d_layer_kmax_pooling()
734+
# test_kmax_pooling_layer_1()
735+
# test_kmax_pooling_layer_2()
736+
# test_conv1d_layer()
737+
# test_conv2d_layer()
738+
# test_conv1d_layer_kmax_pooling()
739+
# test_conv2d_layer_kmax_pooling()
740740

741741

742742
exit()

0 commit comments

Comments
 (0)