@@ -2,10 +2,13 @@
 from torch.autograd import Variable
 import torch.nn.functional as F
 import pickle
+from sumeval.metrics.rouge import RougeCalculator
+from sumeval.metrics.bleu import BLEUCalculator
+from hyperdash import Experiment
 
 import util
 
-def train(data_loader, dev_iter, encoder, decoder, mlp, args):
+def train_classification(data_loader, dev_iter, encoder, decoder, mlp, args):
     lr = args.lr
     encoder_opt = torch.optim.Adam(encoder.parameters(), lr=lr)
     decoder_opt = torch.optim.Adam(decoder.parameters(), lr=lr)
@@ -53,14 +56,13 @@ def train(data_loader, dev_iter, encoder, decoder, mlp, args):
             input_label = target[0]
             single_data = prob[0]
             _, predict_index = torch.max(single_data, 1)
-            input_sentence = util.transform_id2word(input_data, data_loader.dataset.index2word)
-            predict_sentence = util.transform_id2word(predict_index, data_loader.dataset.index2word)
+            input_sentence = util.transform_id2word(input_data.data, data_loader.dataset.index2word, lang="ja")
+            predict_sentence = util.transform_id2word(predict_index.data, data_loader.dataset.index2word, lang="ja")
             print("Input Sentence:")
             print(input_sentence)
             print("Output Sentence:")
             print(predict_sentence)
-            eval_model(encoder, mlp, input_data, input_label)
-
+            eval_classification(encoder, mlp, input_data, input_label)
 
         if epoch % args.lr_decay_interval == 0:
             # decrease learning rate
@@ -91,13 +93,87 @@ def train(data_loader, dev_iter, encoder, decoder, mlp, args):
     print("Finish!!!")
 
 
+def train_reconstruction(train_loader, test_loader, encoder, decoder, args):
+    lr = args.lr
+    encoder_opt = torch.optim.Adam(encoder.parameters(), lr=lr)
+    decoder_opt = torch.optim.Adam(decoder.parameters(), lr=lr)
+
+    encoder.train()
+    decoder.train()
+    steps = 0
+    for epoch in range(1, args.epochs + 1):
+        print("=======Epoch========")
+        print(epoch)
+        for batch in train_loader:
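+            # wrap the batch of word-id tensors in a Variable (pre-0.4 PyTorch idiom)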
+            feature = Variable(batch)
+            if args.use_cuda:
+                encoder.cuda()
+                decoder.cuda()
+                feature = feature.cuda()
+
+            encoder_opt.zero_grad()
+            decoder_opt.zero_grad()
+
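+            # encode the batch into latent codes h, then decode back to per-word log-probabilities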
+            h = encoder(feature)
+            prob = decoder(h)
+            reconstruction_loss = compute_cross_entropy(prob, feature)
+            reconstruction_loss.backward()
+            encoder_opt.step()
+            decoder_opt.step()
+
+            steps += 1
+            print("Epoch: {}".format(epoch))
+            print("Steps: {}".format(steps))
+            print("Loss: {}".format(reconstruction_loss.data[0]))
+            # check reconstructed sentence
+            if steps % args.log_interval == 0:
+                print("Test!!")
+                input_data = feature[0]
+                single_data = prob[0]
+                _, predict_index = torch.max(single_data, 1)
+                input_sentence = util.transform_id2word(input_data.data, train_loader.dataset.index2word, lang="en")
+                predict_sentence = util.transform_id2word(predict_index.data, train_loader.dataset.index2word, lang="en")
+                print("Input Sentence:")
+                print(input_sentence)
+                print("Output Sentence:")
+                print(predict_sentence)
+
+        if epoch % args.test_interval == 0:
+            eval_reconstruction(encoder, decoder, test_loader, args)
+
+
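+        # note: recreating the Adam optimizers below also resets their running moment estimates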
+        if epoch % args.lr_decay_interval == 0:
+            # decrease learning rate
+            lr = lr / 5
+            encoder_opt = torch.optim.Adam(encoder.parameters(), lr=lr)
+            decoder_opt = torch.optim.Adam(decoder.parameters(), lr=lr)
+            encoder.train()
+            decoder.train()
+
+        if epoch % args.save_interval == 0:
+            util.save_models(encoder, args.save_dir, "encoder", steps)
+            util.save_models(decoder, args.save_dir, "decoder", steps)
+
+    # finalization
+    # save vocabulary
+    with open("word2index", "wb") as w2i, open("index2word", "wb") as i2w:
+        pickle.dump(train_loader.dataset.word2index, w2i)
+        pickle.dump(train_loader.dataset.index2word, i2w)
+
+    # save models
+    util.save_models(encoder, args.save_dir, "encoder", "final")
+    util.save_models(decoder, args.save_dir, "decoder", "final")
+
+    print("Finish!!!")
+
+
 def compute_cross_entropy(log_prob, target):
     # compute reconstruction loss using cross entropy
     loss = [F.nll_loss(sentence_emb_matrix, word_ids, size_average=False) for sentence_emb_matrix, word_ids in zip(log_prob, target)]
     average_loss = sum([torch.sum(l) for l in loss]) / log_prob.size()[0]
     return average_loss
 
-def eval_model(encoder, mlp, feature, label):
+def eval_classification(encoder, mlp, feature, label):
     encoder.eval()
     mlp.eval()
     h = encoder(feature)
@@ -110,3 +186,45 @@ def eval_model(encoder, mlp, feature, label):
     encoder.train()
     mlp.train()
 
+
+def eval_reconstruction(encoder, decoder, data_iter, args):
+    print("Eval")
+    encoder.eval()
+    decoder.eval()
+    avg_loss = 0
+    rouge_1 = 0.0
+    rouge_2 = 0.0
+    index2word = data_iter.dataset.index2word
+    for batch in data_iter:
+        feature = Variable(batch)
+        if args.use_cuda:
+            feature = feature.cuda()
+        h = encoder(feature)
+        prob = decoder(h)
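+        # take the argmax over dim 2, the vocabulary axis of the batch x seq_len x vocab output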
+        _, predict_index = torch.max(prob, 2)
+        original_sentences = [util.transform_id2word(sentence, index2word, "en") for sentence in batch]
+        predict_sentences = [util.transform_id2word(sentence, index2word, "en") for sentence in predict_index.data]
+        r1, r2 = calc_rouge(original_sentences, predict_sentences)
+        rouge_1 += r1
+        rouge_2 += r2
+        reconstruction_loss = compute_cross_entropy(prob, feature)
+        avg_loss += reconstruction_loss.data[0]
+    avg_loss = avg_loss / len(data_iter.dataset)
+    rouge_1 = rouge_1 / len(data_iter.dataset)
+    rouge_2 = rouge_2 / len(data_iter.dataset)
+    print("Evaluation - loss: {} Rouge1: {} Rouge2: {}".format(avg_loss, rouge_1, rouge_2))
+    encoder.train()
+    decoder.train()
+
+def calc_rouge(original_sentences, predict_sentences):
+    rouge_1 = 0.0
+    rouge_2 = 0.0
+    for original, predict in zip(original_sentences, predict_sentences):
+        # Remove padding
+        original, predict = original.replace("<PAD>", "").strip(), predict.replace("<PAD>", "").strip()
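+        # sumeval's RougeCalculator scores one summary/reference pair; the per-sentence sums are averaged by the caller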
+        rouge = RougeCalculator(stopwords=True, lang="en")
+        r1 = rouge.rouge_1(summary=predict, references=original)
+        r2 = rouge.rouge_2(summary=predict, references=original)
+        rouge_1 += r1
+        rouge_2 += r2
+    return rouge_1, rouge_2
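
As a quick, self-contained check of the new ROUGE path (not part of this commit), the snippet below replays the per-sentence scoring that calc_rouge performs, using the same sumeval calls the diff introduces; the toy sentences and their <PAD> placement are made up for illustration:

    from sumeval.metrics.rouge import RougeCalculator

    # toy reference/prediction pair, padded the way the data loader pads sentences
    original = "the cat sat on the mat <PAD> <PAD>"
    predict = "the cat sat on a mat <PAD> <PAD>"

    # strip padding tokens before scoring, exactly as calc_rouge does
    original = original.replace("<PAD>", "").strip()
    predict = predict.replace("<PAD>", "").strip()

    rouge = RougeCalculator(stopwords=True, lang="en")
    print(rouge.rouge_1(summary=predict, references=original))  # unigram overlap
    print(rouge.rouge_2(summary=predict, references=original))  # bigram overlap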