-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
57 lines (38 loc) · 1.14 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
from rnn import *
from process_data import *
def test_rnn(data_path=None):
    """Smoke-test the numpy RNN end to end on the Reddit comments CSV.

    Steps performed:
      1. Tokenize the CSV into ``X_train`` / ``Y_train`` and run one forward
         pass on example 10, printing the input and the output (and its
         shape).
      2. Print ``log(voc_size)`` next to ``model.calc_loss`` on the first
         1000 examples so the two can be compared by eye.
      3. Re-create the model with the same NumPy seed, train it with
         ``train_with_sgd`` on the first 100 examples for 10 epochs, then
         print ``num_sentences`` generated sentence(s) of at least
         ``sentence_min_length`` tokens.

    Args:
        data_path: Path to the comments CSV. When ``None`` (the default),
            ``data/reddit-comments-2015-08.csv`` is used, built with
            ``os.path.join`` so the separator is correct on any OS.
    """
    if data_path is None:
        # A literal "\\" separator only works on Windows; on POSIX it becomes
        # part of the file name and the file is not found.
        data_path = os.path.join("data", "reddit-comments-2015-08.csv")

    # RnnNumpy.forward_propagation = forward_propagation
    t_data = RNNTokenizer(data_path)
    X_train, Y_train = t_data.tokenize_data()

    # Seeding before each construction presumably gives both models the same
    # initial weights (assumes RnnNumpy initializes via np.random -- confirm).
    np.random.seed(10)
    model = RnnNumpy(t_data.voc_size)
    o, _ = model.forward_propagation(X_train[10])
    print(X_train[10])
    print(o.shape)
    print(o)

    # NOTE(review): presumably calc_loss is a per-word cross-entropy, so an
    # untrained model with near-uniform predictions should score close to
    # log(voc_size) -- verify against rnn.py.
    print("Expected: %f" % np.log(t_data.voc_size))
    print("Actual loss: %f" % model.calc_loss(X_train[:1000], Y_train[:1000]))

    np.random.seed(10)
    model = RnnNumpy(t_data.voc_size)
    model.train_with_sgd(X_train[:100], Y_train[:100], nepoch=10,
                         evaluate_loss_after=1)

    num_sentences = 1
    sentence_min_length = 7
    for _ in range(num_sentences):
        sent = []
        # Keep sampling until a sentence is long enough. There is no attempt
        # cap, so this spins if the model never reaches the minimum length.
        while len(sent) < sentence_min_length:
            sent = model.generate_sentence(t_data)
        print(" ".join(sent))

    # Disabled experiments, kept for reference:
    #
    # predictions = model.predict(X_train[10])
    # print(predictions.shape)
    # print(predictions)
    #
    # Gradient check on a small vocabulary.
    # NOTE(review): "bptt_trunacate" looks like a typo for "bptt_truncate";
    # confirm against RnnNumpy.__init__ before re-enabling.
    # grad_check_vocab_size = 100
    # np.random.seed(10)
    # model = RnnNumpy(grad_check_vocab_size, 10, bptt_trunacate=1000)
    # model.gradient_check([0, 1, 2, 3], [1, 2, 3, 4])
if __name__ == "__main__":
    # Executing this file directly runs the RNN smoke test; importing it
    # (e.g. from a test runner) does not.
    test_rnn()