-
Notifications
You must be signed in to change notification settings - Fork 16
/
train.py
114 lines (92 loc) · 3.58 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import pickle
import sys
from datetime import datetime as dt
import editdistance
import numpy as np
import theano as th
import rnn_ctc.neuralnet as nn
# from parscribe import ParScribe as Scribe
from scribe import Scribe
import utils
import telugu as lang
import utils
############################################ Read Args
# Parse the command line via utils.read_args and pull out the settings that
# the rest of the script uses directly.
args = utils.read_args(sys.argv[1:])
num_samples = args['num_samples']
num_epochs = args['num_epochs']
scribe_args = args['scribe_args']
nnet_args = args['nnet_args']

# The base output name is built from the sorted command-line arguments with
# the '.ast', '/' and 'configs' fragments stripped; without arguments it is
# simply "default".
if len(sys.argv) <= 1:
    output_fname = "default"
else:
    output_fname = (
        '-'.join(sorted(sys.argv[1:]))
        .replace('.ast', '')
        .replace('/', '')
        .replace('configs', '')
    )

# The network pickle uses the bare base name (no timestamp), while the text
# report gets a timestamp suffix.
network_fname = '{}.pkl'.format(output_fname)
output_fname = output_fname + '_' + dt.now().strftime('%y%m%d_%H%M') + '.txt'

# Accumulators used by the later sections of the script.
distances = []
wts = []
print("Output will be written to: ", output_fname)
# Initialize Language
# Select the labeler named in the arguments; the alphabet size is the number
# of symbols the language module exposes.
lang.select_labeler(args['labeler'])
alphabet_size = len(lang.symbols)

# Initialize Scriber
# The scriber is told to use Theano's configured float type.
# NOTE(review): assumes scribe_args is a dict-like mapping with .update().
scribe_args.update({'dtype': th.config.floatX})
scriber = Scribe(lang, **scribe_args)
printer = utils.Printer(lang.symbols)

# Raise the interpreter recursion limit well above the default.
# NOTE(review): presumably needed for pickling the network - confirm.
sys.setrecursionlimit(10 ** 6)
# Initialize the Neural Network
# When no pickle exists under network_fname, build a new network and pickle
# it immediately; otherwise load the previously pickled network.
if not os.path.exists(network_fname):
    print('Building the Network')
    ntwk = nn.NeuralNet(scriber.height, alphabet_size, **nnet_args)
    with open(network_fname, 'wb') as fh:
        pickle.dump(ntwk, fh)
else:
    print('Loading existing network file')
    with open(network_fname, 'rb') as fh:
        ntwk = pickle.load(fh)
# Print
# Echo the parsed arguments, Theano's float type and the alphabet size.
print('\nArguments:')
utils.write_dict(args)
print(
    'FloatX: {}'.format(th.config.floatX),
    'Alphabet Size: {}'.format(alphabet_size),
    sep='\n',
)
################################ Train
# For every epoch: update the learning rate, checkpoint the network, train on
# `num_samples` samples obtained from the scriber, and record the epoch's
# accumulated (edit distance, label length) pair in `distances`.
print('Training the Network')
for epoch in range(num_epochs):
    ntwk.update_learning_rate(epoch)
    edit_dist, tot_len = 0, 0
    inf_cost = False

    print('Epoch: {} '.format(epoch))
    # Keep one backup of the previous checkpoint, since the data might get
    # lost if the script is stopped while pickling.  os.replace overwrites an
    # existing backup on every platform, whereas os.rename raises
    # FileExistsError on Windows once 'ntwk.bkp.pkl' already exists.
    os.replace(network_fname, 'ntwk.bkp.pkl')
    with open(network_fname, 'wb') as fh:
        pickle.dump(ntwk, fh)
    print('Network saved to {}'.format(network_fname))

    for samp in range(num_samples):
        x, _, y = scriber.get_text_image()
        y_blanked = utils.insert_blanks(y, alphabet_size, num_blanks_at_start=2)
        cst, pred, forward_probs = ntwk.trainer(x, y_blanked)

        if np.isinf(cst):
            # NOTE(review): the threshold here (1e-20) differs from the one
            # used in the diagnostic dump below (-6) - confirm which scale
            # forward_probs is expressed in.
            printer.show_all(y, x, pred,
                             (forward_probs > 1e-20, 'Forward probabilities:', y_blanked))
            print('Exiting on account of Inf Cost...')
            # Remember the failure so the epoch loop stops as well; otherwise
            # the next epoch would pickle the diverged network over the
            # checkpoint.
            inf_cost = True
            break

        # Detailed diagnostic dump for the first sample of the last epoch.
        if samp == 0 and epoch == num_epochs - 1:
            pred, hidden = ntwk.tester(x)

            print('Epoch:{:6d} Cost:{:.3f}'.format(epoch, float(cst)))
            printer.show_all(y, x, pred,
                             (forward_probs > -6, 'Forward probabilities:', y_blanked),
                             ((hidden + 1)/2, 'Hidden Layer:'))
            utils.pprint_probs(forward_probs)

        # Accumulate the edit distance between the decoded prediction and the
        # true labels, together with the total label length.
        edit_dist += editdistance.eval(printer.decode(pred), y)
        tot_len += len(y)

    # Record the (possibly partial) epoch so the results section always has
    # an entry for every epoch that was started.
    distances.append((edit_dist, tot_len))
    if inf_cost:
        break
################################ save
# Write the run arguments followed by one "index: edit_distance/total_length"
# line per recorded epoch, then echo the report name and the last entry.
with open(output_fname, 'w') as f:
    utils.write_dict(args, f)
    f.write("Edit Distances\n")
    f.writelines(
        "{:4d}: {:5d}/{:5d}\n".format(row, dist, length)
        for row, (dist, length) in enumerate(distances)
    )

print(output_fname, distances[-1])