
Commit 3efac19

Add paragraph reconstruction experiment code
1 parent 8568904 commit 3efac19

6 files changed: 260 additions, 18 deletions

datasets.py

Lines changed: 65 additions & 6 deletions
@@ -4,6 +4,67 @@
 from tqdm import tqdm

 from collections import Counter
+from copy import deepcopy
+
+
+def load_hotel_review_data(path, sentence_len):
+    """
+    Load Hotel Reviews data from the pickle distributed at https://drive.google.com/file/d/0B52eYWrYWqIpQzhBNkVxaV9mMjQ/view
+    This file is published in https://github.com/dreasysnail/textCNN_public
+
+    :param path: pickle path
+    :return: train and test HotelReviewsDataset instances
+    """
+    import _pickle as cPickle
+    with open(path, "rb") as f:
+        data = cPickle.load(f, encoding="latin1")
+
+    train_data, test_data = HotelReviewsDataset(data[0], deepcopy(data[2]), deepcopy(data[3]), sentence_len, transform=ToTensor()), \
+                            HotelReviewsDataset(data[1], deepcopy(data[2]), deepcopy(data[3]), sentence_len, transform=ToTensor())
+    return train_data, test_data
+
+
+class HotelReviewsDataset(Dataset):
+    """
+    Hotel Reviews Dataset
+    """
+    def __init__(self, data_list, word2index, index2word, sentence_len, transform=None):
+        self.word2index = word2index
+        self.index2word = index2word
+        self.n_words = len(self.word2index)
+        self.data = data_list
+        self.sentence_len = sentence_len
+        self.transform = transform
+        self.word2index["<PAD>"] = self.n_words
+        self.index2word[self.n_words] = "<PAD>"
+        self.n_words += 1
+        # print(self.index2word)  # debug: dumps the whole vocabulary
+        temp_list = []
+        for sentence in tqdm(self.data):
+            if len(sentence) > self.sentence_len:
+                # truncate sentence if it is longer than `sentence_len`
+                temp_list.append(np.array(sentence[:self.sentence_len]))
+            else:
+                # pad sentence with the '<PAD>' token if it is shorter than `sentence_len`
+                sent_array = np.lib.pad(np.array(sentence),
+                                        (0, self.sentence_len - len(sentence)),
+                                        "constant",
+                                        constant_values=(self.n_words - 1, self.n_words - 1))
+                temp_list.append(sent_array)
+        self.data = np.array(temp_list, dtype=np.int32)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        data = self.data[idx]
+        if self.transform:
+            data = self.transform(data)
+        return data
+
+    def vocab_length(self):
+        return len(self.word2index)


 class TextClassificationDataset(Dataset):

@@ -81,7 +142,8 @@ def __getitem__(self, idx):
         sample = {"sentence": sentence, "label": label}

         if self.transform:
-            sample = self.transform(sample)
+            sample = {"sentence": self.transform(sample["sentence"]),
+                      "label": self.transform(sample["label"])}

         return sample

@@ -91,8 +153,5 @@ def vocab_length(self):

 class ToTensor(object):
     """Convert ndarrays in sample to Tensors."""
-
-    def __call__(self, sample):
-        sentence, label = sample["sentence"], sample['label']
-        return {'sentence': torch.from_numpy(sentence).type(torch.LongTensor),
-                'label': torch.from_numpy(label).type(torch.LongTensor)}
+    def __call__(self, data):
+        return torch.from_numpy(data).type(torch.LongTensor)
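The pad-or-truncate branch in HotelReviewsDataset.__init__ is easy to sanity-check in isolation. A minimal sketch with a toy setup (the pad id and sentence_len below are illustrative, not from the Hotel Reviews vocabulary):

    import numpy as np

    pad_id = 4        # stands in for n_words - 1, the <PAD> id the dataset appends
    sentence_len = 6

    def pad_or_truncate(sentence):
        # mirrors the branch in HotelReviewsDataset.__init__
        if len(sentence) > sentence_len:
            return np.array(sentence[:sentence_len])
        return np.lib.pad(np.array(sentence),
                          (0, sentence_len - len(sentence)),
                          "constant",
                          constant_values=(pad_id, pad_id))

    print(pad_or_truncate([1, 2, 3]))              # [1 2 3 4 4 4]
    print(pad_or_truncate([3, 2, 1, 0, 1, 2, 3]))  # [3 2 1 0 1 2]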

main.py renamed to main_classification.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@

 import model
 from datasets import TextClassificationDataset, ToTensor
-from train import train
+from train import train_classification

 import argparse

@@ -17,7 +17,7 @@ def main():
     parser.add_argument('-batch_size', type=int, default=64, help='batch size for training')
     parser.add_argument('-lr_decay_interval', type=int, default=20,
                         help='how many epochs to wait before decreasing learning rate')
-    parser.add_argument('-log_interval', type=int, default=256,
+    parser.add_argument('-log_interval', type=int, default=16,
                         help='how many steps to wait before logging training status')
     parser.add_argument('-test_interval', type=int, default=100,
                         help='how many steps to wait before testing')

@@ -68,7 +68,7 @@ def main():
         decoder = torch.load(args.dec_snapshot)
         mlp = torch.load(args.mlp_snapshot)

-    train(data_loader, data_loader, encoder, decoder, mlp, args)
+    train_classification(data_loader, data_loader, encoder, decoder, mlp, args)

 if __name__ == '__main__':
     main()

main_reconstruction.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+import model
+from datasets import TextClassificationDataset, ToTensor, load_hotel_review_data
+from train import train_reconstruction
+
+import argparse
+
+
+def main():
+    parser = argparse.ArgumentParser(description='text convolution-deconvolution auto-encoder model')
+    # learning
+    parser.add_argument('-lr', type=float, default=0.001, help='initial learning rate')
+    parser.add_argument('-epochs', type=int, default=10, help='number of epochs for train')
+    parser.add_argument('-batch_size', type=int, default=32, help='batch size for training')
+    parser.add_argument('-lr_decay_interval', type=int, default=4,
+                        help='how many epochs to wait before decreasing learning rate')
+    parser.add_argument('-log_interval', type=int, default=256,
+                        help='how many steps to wait before logging training status')
+    parser.add_argument('-test_interval', type=int, default=10,
+                        help='how many epochs to wait before testing')
+    parser.add_argument('-save_interval', type=int, default=2,
+                        help='how many epochs to wait before saving')
+    parser.add_argument('-save_dir', type=str, default='rec_snapshot', help='where to save the snapshot')
+    # data
+    parser.add_argument('-data_path', type=str, help='data path')
+    parser.add_argument('-shuffle', default=False, help='shuffle data every epoch')
+    parser.add_argument('-sentence_len', type=int, default=210, help='how many tokens in a sentence')
+    # model
+    parser.add_argument('-embed_dim', type=int, default=300, help='number of embedding dimension')
+    parser.add_argument('-kernel_sizes', type=int, default=2,
+                        help='kernel size to use for convolution')
+    parser.add_argument('-tau', type=float, default=0.01, help='temperature parameter')
+    parser.add_argument('-use_cuda', action='store_true', default=True, help='whether using cuda')
+    # option
+    parser.add_argument('-enc_snapshot', type=str, default=None, help='filename of encoder snapshot')
+    parser.add_argument('-dec_snapshot', type=str, default=None, help='filename of decoder snapshot')
+    args = parser.parse_args()
+
+    train_data, test_data = load_hotel_review_data(args.data_path, args.sentence_len)
+    train_loader, test_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=args.shuffle), \
+                                DataLoader(test_data, batch_size=args.batch_size, shuffle=args.shuffle)
+
+    k = args.embed_dim
+    v = train_data.vocab_length()
+    if args.enc_snapshot is None or args.dec_snapshot is None:
+        print("Start from initial")
+        embedding = nn.Embedding(v, k, max_norm=1.0, norm_type=2.0)
+
+        encoder = model.ConvolutionEncoder(embedding)
+        decoder = model.DeconvolutionDecoder(embedding, args.tau)
+    else:
+        print("Restart from snapshot")
+        encoder = torch.load(args.enc_snapshot)
+        decoder = torch.load(args.dec_snapshot)
+
+    train_reconstruction(train_loader, test_loader, encoder, decoder, args)
+
+
+if __name__ == '__main__':
+    main()
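Assuming the Hotel Reviews pickle has been downloaded from the link in datasets.py, a typical run is: python main_reconstruction.py -data_path <pickle path>. A minimal sketch of the loader contract this script sets up (the file name below is a placeholder, not the distributed one):

    from torch.utils.data import DataLoader
    from datasets import load_hotel_review_data

    # placeholder path; point this at wherever the downloaded pickle lives
    train_data, test_data = load_hotel_review_data("hotel_reviews.p", sentence_len=210)
    loader = DataLoader(train_data, batch_size=32)
    batch = next(iter(loader))
    print(batch.size())  # torch.Size([32, 210]); each row is one padded review as word ids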

model.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ def __init__(self, embedding, tau):
         self.bn1 = nn.BatchNorm2d(600)
         self.deconvs2 = nn.ConvTranspose2d(600, 300, (2, 1), stride=2)
         self.bn2 = nn.BatchNorm2d(300)
-        self.deconvs3 = nn.ConvTranspose2d(300, 1, (2+1, self.embed.weight.size()[1]), stride=2)
+        self.deconvs3 = nn.ConvTranspose2d(300, 1, (2+2, self.embed.weight.size()[1]), stride=2)

         # weight initialize for conv_transpose layer
         for m in self.modules():
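One plausible reading of the kernel-height change (2+1 to 2+2) in deconvs3: for ConvTranspose2d with stride 2, no padding, and no dilation, the output length is (L_in - 1) * 2 + kernel_height, so the output's parity always matches the kernel's. A height-3 kernel can therefore only produce odd sentence lengths, while main_reconstruction.py defaults to sentence_len=210. A quick shape check (the intermediate length 104 is illustrative, not derived from the encoder):

    import torch
    import torch.nn as nn

    # L_out = (L_in - 1) * stride + kernel_height  (padding 0, dilation 1)
    for k in (3, 4):
        deconv = nn.ConvTranspose2d(300, 1, (k, 300), stride=2)
        x = torch.randn(1, 300, 104, 1)    # (batch, channels, length, 1)
        print(k, tuple(deconv(x).size()))  # (1, 1, 209, 300) then (1, 1, 210, 300)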

train.py

Lines changed: 124 additions & 6 deletions
@@ -2,10 +2,13 @@
 from torch.autograd import Variable
 import torch.nn.functional as F
 import pickle
+from sumeval.metrics.rouge import RougeCalculator
+from sumeval.metrics.bleu import BLEUCalculator
+from hyperdash import Experiment

 import util

-def train(data_loader, dev_iter, encoder, decoder, mlp, args):
+def train_classification(data_loader, dev_iter, encoder, decoder, mlp, args):
     lr = args.lr
     encoder_opt = torch.optim.Adam(encoder.parameters(), lr=lr)
     decoder_opt = torch.optim.Adam(decoder.parameters(), lr=lr)

@@ -53,14 +56,13 @@ def train(data_loader, dev_iter, encoder, decoder, mlp, args):
                 input_label = target[0]
                 single_data = prob[0]
                 _, predict_index = torch.max(single_data, 1)
-                input_sentence = util.transform_id2word(input_data, data_loader.dataset.index2word)
-                predict_sentence = util.transform_id2word(predict_index, data_loader.dataset.index2word)
+                input_sentence = util.transform_id2word(input_data.data, data_loader.dataset.index2word, lang="ja")
+                predict_sentence = util.transform_id2word(predict_index.data, data_loader.dataset.index2word, lang="ja")
                 print("Input Sentence:")
                 print(input_sentence)
                 print("Output Sentence:")
                 print(predict_sentence)
-                eval_model(encoder, mlp, input_data, input_label)
-
+                eval_classification(encoder, mlp, input_data, input_label)

         if epoch % args.lr_decay_interval == 0:
             # decrease learning rate

@@ -91,13 +93,87 @@ def train(data_loader, dev_iter, encoder, decoder, mlp, args):
     print("Finish!!!")


+def train_reconstruction(train_loader, test_loader, encoder, decoder, args):
+    lr = args.lr
+    encoder_opt = torch.optim.Adam(encoder.parameters(), lr=lr)
+    decoder_opt = torch.optim.Adam(decoder.parameters(), lr=lr)
+
+    encoder.train()
+    decoder.train()
+    steps = 0
+    for epoch in range(1, args.epochs+1):
+        print("=======Epoch========")
+        print(epoch)
+        for batch in train_loader:
+            feature = Variable(batch)
+            if args.use_cuda:
+                encoder.cuda()
+                decoder.cuda()
+                feature = feature.cuda()
+
+            encoder_opt.zero_grad()
+            decoder_opt.zero_grad()
+
+            h = encoder(feature)
+            prob = decoder(h)
+            reconstruction_loss = compute_cross_entropy(prob, feature)
+            reconstruction_loss.backward()
+            encoder_opt.step()
+            decoder_opt.step()
+
+            steps += 1
+            print("Epoch: {}".format(epoch))
+            print("Steps: {}".format(steps))
+            print("Loss: {}".format(reconstruction_loss.data[0]))
+            # check reconstructed sentence
+            if steps % args.log_interval == 0:
+                print("Test!!")
+                input_data = feature[0]
+                single_data = prob[0]
+                _, predict_index = torch.max(single_data, 1)
+                input_sentence = util.transform_id2word(input_data.data, train_loader.dataset.index2word, lang="en")
+                predict_sentence = util.transform_id2word(predict_index.data, train_loader.dataset.index2word, lang="en")
+                print("Input Sentence:")
+                print(input_sentence)
+                print("Output Sentence:")
+                print(predict_sentence)
+
+        if epoch % args.test_interval == 0:
+            eval_reconstruction(encoder, decoder, test_loader, args)
+
+        if epoch % args.lr_decay_interval == 0:
+            # decrease learning rate
+            lr = lr / 5
+            encoder_opt = torch.optim.Adam(encoder.parameters(), lr=lr)
+            decoder_opt = torch.optim.Adam(decoder.parameters(), lr=lr)
+            encoder.train()
+            decoder.train()
+
+        if epoch % args.save_interval == 0:
+            util.save_models(encoder, args.save_dir, "encoder", steps)
+            util.save_models(decoder, args.save_dir, "decoder", steps)
+
+    # finalization
+    # save vocabulary
+    with open("word2index", "wb") as w2i, open("index2word", "wb") as i2w:
+        pickle.dump(train_loader.dataset.word2index, w2i)
+        pickle.dump(train_loader.dataset.index2word, i2w)
+
+    # save models
+    util.save_models(encoder, args.save_dir, "encoder", "final")
+    util.save_models(decoder, args.save_dir, "decoder", "final")
+
+    print("Finish!!!")
+
+
 def compute_cross_entropy(log_prob, target):
     # compute reconstruction loss using cross entropy
     loss = [F.nll_loss(sentence_emb_matrix, word_ids, size_average=False) for sentence_emb_matrix, word_ids in zip(log_prob, target)]
     average_loss = sum([torch.sum(l) for l in loss]) / log_prob.size()[0]
     return average_loss

-def eval_model(encoder, mlp, feature, label):
+def eval_classification(encoder, mlp, feature, label):
     encoder.eval()
     mlp.eval()
     h = encoder(feature)

@@ -110,3 +186,45 @@ def eval_model(encoder, mlp, feature, label):
     encoder.train()
     mlp.train()

+
+def eval_reconstruction(encoder, decoder, data_iter, args):
+    print("Eval")
+    encoder.eval()
+    decoder.eval()
+    avg_loss = 0
+    rouge_1 = 0.0
+    rouge_2 = 0.0
+    index2word = data_iter.dataset.index2word
+    for batch in data_iter:
+        feature = Variable(batch)
+        if args.use_cuda:
+            feature = feature.cuda()
+        h = encoder(feature)
+        prob = decoder(h)
+        _, predict_index = torch.max(prob, 2)
+        original_sentences = [util.transform_id2word(sentence, index2word, "en") for sentence in batch]
+        predict_sentences = [util.transform_id2word(sentence, index2word, "en") for sentence in predict_index.data]
+        r1, r2 = calc_rouge(original_sentences, predict_sentences)
+        rouge_1 += r1
+        rouge_2 += r2
+        reconstruction_loss = compute_cross_entropy(prob, feature)
+        avg_loss += reconstruction_loss.data[0]
+    avg_loss = avg_loss / len(data_iter.dataset)
+    rouge_1 = rouge_1 / len(data_iter.dataset)
+    rouge_2 = rouge_2 / len(data_iter.dataset)
+    print("Evaluation - loss: {} Rouge1: {} Rouge2: {}".format(avg_loss, rouge_1, rouge_2))
+    encoder.train()
+    decoder.train()
+
+
+def calc_rouge(original_sentences, predict_sentences):
+    rouge_1 = 0.0
+    rouge_2 = 0.0
+    for original, predict in zip(original_sentences, predict_sentences):
+        # Remove padding
+        original, predict = original.replace("<PAD>", "").strip(), predict.replace("<PAD>", "").strip()
+        rouge = RougeCalculator(stopwords=True, lang="en")
+        r1 = rouge.rouge_1(summary=predict, references=original)
+        r2 = rouge.rouge_2(summary=predict, references=original)
+        rouge_1 += r1
+        rouge_2 += r2
+    return rouge_1, rouge_2
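eval_reconstruction and calc_rouge lean on sumeval's RougeCalculator. A standalone sanity check of that dependency (the sentences are made up; rouge_n is the generic form of the per-n helpers used above):

    from sumeval.metrics.rouge import RougeCalculator

    rouge = RougeCalculator(stopwords=True, lang="en")
    original = "the room was clean and the staff were friendly"
    predict = "the room was clean and the staff friendly"
    print(rouge.rouge_n(summary=predict, references=original, n=1))
    print(rouge.rouge_n(summary=predict, references=original, n=2))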

util.py

Lines changed: 5 additions & 2 deletions
@@ -2,8 +2,11 @@
 import math
 import os

-def transform_id2word(index, id2word):
-    return "".join([id2word[idx.data[0]] for idx in index])
+def transform_id2word(index, id2word, lang):
+    if lang == "ja":
+        return "".join([id2word[idx] for idx in index])
+    else:
+        return " ".join([id2word[idx] for idx in index])

 def sigmoid_annealing_schedule(step, max_step, param_init=1.0, param_final=0.01, gain=0.3):
     return ((param_init - param_final) / (1 + math.exp(gain * (step - (max_step / 2))))) + param_final
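A quick illustration of the new lang switch in transform_id2word (toy vocabulary, not the dataset's): Japanese text is joined without separators, everything else word by word with spaces:

    id2word = {0: "the", 1: "room", 2: "was", 3: "clean"}
    print(transform_id2word([0, 1, 2, 3], id2word, lang="en"))  # the room was clean
    print(transform_id2word([0, 1, 2, 3], id2word, lang="ja"))  # theroomwasclean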
