
Commit 52e658b
Code reconstruction:
1. Move all decoders into a single file.
2. Build the decoder separately from the RNN model and pass it to the RNN model as a parameter.
3. Move the duplicated mask_select code out of the decoders and into the RNN model.
1 parent eaeea62 commit 52e658b

File tree

3 files changed: +247 -247 lines changed
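
Only decoder.py, shown below, is excerpted here. As a rough illustration of the restructuring the commit message describes, the decoder is now constructed on its own and handed to the RNN model as a parameter. The following is a minimal sketch, not code from this commit: the RNNModel stand-in and its constructor signature are assumptions, since the model file is not shown.

import torch.nn as nn
from decoder import SMDecoder

class RNNModel(nn.Module):
    # Hypothetical stand-in for the repo's RNN model: it no longer builds
    # its own decoder, it receives one and delegates the loss to it.
    def __init__(self, ntoken, ninp, nhid, decoder):
        super(RNNModel, self).__init__()
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid)
        self.decoder = decoder  # built outside, passed in as a parameter

    def forward_with_loss(self, input, target):
        emb = self.encoder(input)
        output, _ = self.rnn(emb)
        # flatten (seq_len, batch, nhid) to (seq_len * batch, nhid)
        output = output.view(-1, output.size(2))
        return self.decoder.forward_with_loss(output, target)

# any decoder exposing forward_with_loss(rnn_output, target) plugs in
decoder = SMDecoder(nhid=200, ntoken=10000)
decoder.init_weights()
model = RNNModel(ntoken=10000, ninp=200, nhid=200, decoder=decoder)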

decoder.py

Lines changed: 207 additions & 0 deletions
#!/usr/bin/env python
# encoding: utf-8

import torch
import torch.nn as nn
from torch.autograd import Variable

from collections import defaultdict
from utils import list2longtensor, map_dict_value
from alias_multinomial import AliasMethod

class ListModule(nn.Module):
    def __init__(self, *args):
        super(ListModule, self).__init__()
        idx = 0
        for module in args:
            self.add_module(str(idx), module)
            idx += 1

    def __getitem__(self, idx):
        if idx < 0 or idx >= len(self._modules):
            raise IndexError('index {} is out of range'.format(idx))
        it = iter(self._modules.values())
        for i in range(idx):
            next(it)
        return next(it)

    def __iter__(self):
        return iter(self._modules.values())

    def __len__(self):
        return len(self._modules)

class SMDecoder(nn.Module):
    """Plain softmax decoder: a single linear projection from the hidden
    state to vocabulary logits, trained with cross-entropy."""

    def __init__(self, nhid, ntoken):
        super(SMDecoder, self).__init__()
        self.nhid = nhid
        self.decoder = nn.Linear(nhid, ntoken)
        self.CE = nn.CrossEntropyLoss()

    def init_weights(self):
        initrange = 0.1
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input):
        return self.decoder(input)

    def forward_with_loss(self, rnn_output, target):
        output = self(rnn_output)
        return self.CE(output, target)

class ClassBasedSMDecoder(nn.Module):
    """Class-based softmax decoder: factorizes P(w | h) into
    P(class(w) | h) * P(w | class(w), h), so each softmax runs over a much
    smaller set than the full vocabulary."""

    def __init__(self, nhid, ncls, word2cls, class_chunks):
        super(ClassBasedSMDecoder, self).__init__()
        self.nhid = nhid
        self.cls_decoder = nn.Linear(nhid, ncls)

        words_decoders = []
        for c in class_chunks:
            words_decoders.append(nn.Linear(nhid, c))
        self.words_decoders = ListModule(*words_decoders)

        self.CELoss = nn.CrossEntropyLoss(size_average=False)

        # collect the words belonging to each class
        cls_cluster = defaultdict(lambda: [])

        # the within-class index of each word in its class cluster
        within_cls_idx = []
        for i, c in enumerate(word2cls):
            within_cls_idx.append(len(cls_cluster[c]))
            cls_cluster[c].append(i)

        self.word2cls = list2longtensor(word2cls)
        self.within_cls_idx = list2longtensor(within_cls_idx)
        self.cls_cluster = map_dict_value(list2longtensor, cls_cluster)

    def init_weights(self):
        r = .1
        self.cls_decoder.weight.data.uniform_(-r, r)
        self.cls_decoder.bias.data.fill_(0)
        for word_decoder in self.words_decoders:
            word_decoder.weight.data.uniform_(-r, r)
            word_decoder.bias.data.fill_(0)

    def build_labels(self, target):
        # TODO: too much time is spent in this function

        # class index of each target word
        cls_idx = self.word2cls.index_select(0, target)
        # within-class index of each target word
        within_cls_idx = self.within_cls_idx.index_select(0, target)

        cls_idx_ = cls_idx.data.cpu()
        wci = within_cls_idx.data.cpu()

        # batch positions of the words falling in each class
        within_batch_idx_dic = defaultdict(lambda: [])
        # within-class indices of the words falling in each class
        within_cls_idx_dic = defaultdict(lambda: [])

        for i, (c, w) in enumerate(zip(cls_idx_, wci)):
            within_batch_idx_dic[c].append(i)
            within_cls_idx_dic[c].append(w)

        within_batch_idx_dic = map_dict_value(list2longtensor, within_batch_idx_dic)
        within_cls_idx_dic = map_dict_value(list2longtensor, within_cls_idx_dic)

        return cls_idx, within_cls_idx_dic, within_batch_idx_dic

    def forward(self, input, within_batch_idx):
        p_class = self.cls_decoder(input)
        p_words = {}

        for c in within_batch_idx:
            d = input.index_select(0, within_batch_idx[c])
            p_words[c] = self.words_decoders[c](d)

        return p_class, p_words

    def forward_with_loss(self, rnn_output, target):
        cls_idx, within_cls_idx, within_batch_idx = self.build_labels(target)

        p_class, p_words = self(rnn_output, within_batch_idx)

        # Taking the log turns the product of the class probability and the
        # within-class word probability into a sum, so the class-level and
        # word-level cross-entropy losses can be computed separately and added.
        closs = self.CELoss(p_class, cls_idx)
        wloss = []
        for c in p_words:
            wloss.append(self.CELoss(p_words[c], within_cls_idx[c]))

        return (closs + sum(wloss)) / len(cls_idx)

class NCEDecoder(nn.Module):
    """Noise-contrastive estimation decoder: trains by discriminating the
    target word from nsample words drawn from a noise distribution,
    avoiding a full softmax over the vocabulary."""

    def __init__(self, nhid, ntoken, noise_dist, nsample=10):
        super(NCEDecoder, self).__init__()
        self.nhid = nhid
        self.word_embeddings = nn.Embedding(ntoken, nhid)
        self.word_bias = nn.Embedding(ntoken, 1)

        noise_dist = noise_dist / noise_dist.sum()
        self.noise_dist = noise_dist.cuda()
        self.alias = AliasMethod(self.noise_dist)
        self.nsample = nsample
        # treat the log-partition function as the constant 9 instead of
        # normalizing, a common self-normalization trick in NCE
        self.norm = 9

        self.CE = nn.CrossEntropyLoss()
        self.valid = False

    def init_weights(self):
        initrange = 0.1
        self.word_embeddings.weight.data.uniform_(-initrange, initrange)
        self.word_bias.weight.data.fill_(0)

    def _get_noise_prob(self, indices):
        return Variable(self.noise_dist[indices.data.view(-1)].view_as(indices))

    def forward(self, input, target):
        # model probabilities for the target word and the sampled words

        sample = Variable(self.alias.draw(input.size(0), self.nsample).cuda())
        indices = torch.cat([target.unsqueeze(1), sample], dim=1)

        embed = self.word_embeddings(indices)
        bias = self.word_bias(indices)

        # unnormalized logits, then divide by the fixed partition function
        score = torch.baddbmm(1, bias, 1, embed, input.unsqueeze(2)).squeeze()
        score = score.sub(self.norm).exp()
        target_prob, sample_prob = score[:, 0], score[:, 1:]

        return target_prob, sample_prob, sample

    def nce_loss(self, target_prob, sample_prob, target, sample):
        target_noise_prob = self._get_noise_prob(target)
        sample_noise_prob = self._get_noise_prob(sample)

        def log(tensor):
            EPSILON = 1e-10
            return torch.log(EPSILON + tensor)

        # log-probability that the target word came from the model
        target_loss = log(
            target_prob / (target_prob + self.nsample * target_noise_prob)
        )

        # log-probability that each sampled word came from the noise distribution
        sample_loss = log(
            self.nsample * sample_noise_prob / (sample_prob + self.nsample * sample_noise_prob)
        )

        return - (target_loss + torch.sum(sample_loss, -1).squeeze())

    def forward_with_loss(self, rnn_output, target):
        if self.training:
            target_prob, sample_prob, sample = self(rnn_output, target)
            loss = self.nce_loss(target_prob, sample_prob, target, sample)
            return loss.mean()
        else:
            # at evaluation time, fall back to the exact softmax cross-entropy
            output = torch.addmm(
                1, self.word_bias.weight.view(-1), 1, rnn_output, self.word_embeddings.weight.t()
            )
            return self.CE(output, target)
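
All three decoders expose the same forward_with_loss(rnn_output, target) interface, which is what lets the RNN model treat them interchangeably after this change. A minimal smoke test of that contract, assuming rnn_output has shape (batch, nhid) and target has shape (batch,); the data here is made up, not from the repo:

import torch
from torch.autograd import Variable
from decoder import SMDecoder

nhid, ntoken, batch = 32, 100, 8
decoder = SMDecoder(nhid, ntoken)
decoder.init_weights()

# fake RNN hidden states and target word indices
rnn_output = Variable(torch.randn(batch, nhid))
target = Variable(torch.LongTensor(batch).random_(0, ntoken))

loss = decoder.forward_with_loss(rnn_output, target)  # scalar loss
loss.backward()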
