-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathweight_drop.py
103 lines (90 loc) · 4.14 KB
/
weight_drop.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
from torch.nn import Parameter
from functools import wraps
import functools
class WeightDrop(torch.nn.Module):
    """Wrap a module and apply dropout to selected weights on every forward call.

    For each name in ``weights`` the wrapped module's parameter is removed and
    re-registered as ``<name>_raw``. Before every forward call a dropped-out
    tensor derived from the raw parameter is assigned back under the original
    name. That tensor is deliberately NOT wrapped in ``Parameter``: wrapping
    would create a new leaf tensor and cut the autograd graph, so the raw
    parameters would never receive gradients.

    Args:
        module: The module to wrap; its ``forward`` is invoked directly.
        weights: Iterable of parameter names on ``module`` to apply dropout to.
        dropout: Drop probability handed to ``torch.nn.functional.dropout``.
        variational: If true, one mask value is drawn per row (first dimension)
            of the weight and broadcast across that row. This mask is sampled
            with ``training=True`` on every call, regardless of the wrapper's
            training mode.

    NOTE(review): the derived weights are non-leaf tensors; confirm that no
    caller relies on ``copy.deepcopy`` of the wrapper after a forward pass.
    """

    def __init__(self, module, weights, dropout=0, variational=False):
        super(WeightDrop, self).__init__()
        self.module = module
        self.weights = weights
        self.dropout = dropout
        self.variational = variational
        self._setup()

    def widget_demagnetizer_y2k_edition(*args, **kwargs):
        # We need to replace flatten_parameters with a nothing function.
        # It must be a function rather than a lambda as otherwise pickling explodes.
        # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION!
        return

    def _setup(self):
        """Move every selected parameter of the wrapped module to ``<name>_raw``."""
        # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
        if isinstance(self.module, torch.nn.RNNBase):
            self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition

        for name_w in self.weights:
            w = getattr(self.module, name_w)
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', Parameter(w.data))

    def _setweights(self):
        """Recompute the dropped-out weights from their ``_raw`` parameters."""
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            if self.variational:
                # new_ones keeps the mask on the same device and dtype as the weight.
                mask = raw_w.new_ones((raw_w.size(0), 1))
                mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
                w = mask.expand_as(raw_w) * raw_w
            else:
                w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)

            if isinstance(w, Parameter):
                # dropout may hand its input back unchanged; take a view so the
                # raw Parameter is not registered a second time under name_w
                # while gradients still flow back to it.
                w = w.view_as(w)

            # Drop any stale registration under this name (for example from a
            # model pickled by an older version of this class); otherwise
            # assigning a plain tensor to a registered parameter name raises.
            self.module._parameters.pop(name_w, None)
            setattr(self.module, name_w, w)

    def forward(self, *args, **kwargs):
        """Refresh the dropped-out weights, then run the wrapped module's forward."""
        self._setweights()
        return self.module.forward(*args, **kwargs)
def rsetattr(obj, attr, val):
    """Assign ``val`` to a possibly dotted attribute path on ``obj``.

    Everything before the last ``.`` in ``attr`` is resolved with ``rgetattr``;
    the final component is then set on the resolved object. With no dot in
    ``attr`` the attribute is set on ``obj`` itself.

    Returns:
        The result of the underlying ``setattr`` call (``None``).
    """
    owner_path, _, leaf = attr.rpartition('.')
    if owner_path:
        target = rgetattr(obj, owner_path)
    else:
        target = obj
    return setattr(target, leaf, val)
def rgetattr(obj, attr, *args):
    """Resolve a possibly dotted attribute path starting from ``obj``.

    ``attr`` is split on ``.`` and each component is looked up on the result
    of the previous lookup. Any extra positional arguments are forwarded to
    every ``getattr`` call, so an optional default applies at each step.
    """
    current = obj
    for component in attr.split('.'):
        current = getattr(current, component, *args)
    return current
class ParameterListWeightDrop(torch.nn.Module):
    """Apply weight dropout to selected entries of parameter-list attributes.

    Each entry of ``weights`` is a dotted path whose last component is an
    integer index into a parameter list on the wrapped module; the part before
    the last dot is the path of the list itself. During setup every such list
    is re-registered as ``<path>_raw``. Before each forward call a plain Python
    list is assigned at ``<path>``: selected indices hold dropped-out tensors
    derived from the raw parameters, all other indices hold the raw parameters
    themselves.

    The dropped-out tensors are deliberately NOT wrapped in ``Parameter`` (and
    therefore not stored in a ``ParameterList``): wrapping creates a new leaf
    tensor and cuts the autograd graph, so the raw parameters would never
    receive gradients.

    Args:
        module: The module to wrap; its ``forward`` is invoked directly.
        weights: Iterable of dotted paths of the form ``'<list path>.<index>'``.
        dropout: Drop probability handed to ``torch.nn.functional.dropout``.
        variational: If true, one mask value is drawn per row (first dimension)
            of a weight and broadcast across that row. This mask is sampled
            with ``training=True`` on every call, regardless of the wrapper's
            training mode.

    NOTE(review): the wrapped module now sees a ``list`` instead of a
    ``ParameterList`` at ``<path>``; confirm it only indexes or iterates it.
    """

    def __init__(self, module, weights, dropout=0, variational=False):
        super(ParameterListWeightDrop, self).__init__()
        self.module = module
        self.weights = weights
        # Maps the path of each parameter list to the indices to apply dropout to.
        self.parents = {}
        for w in self.weights:
            list_path, _, index_text = w.rpartition('.')
            index = int(index_text)
            self.parents.setdefault(list_path, []).append(index)
        self.dropout = dropout
        self.variational = variational
        self._setup()

    def _setup(self):
        """Move every referenced parameter list to ``<path>_raw``."""
        for name_w in self.parents:
            ws = rgetattr(self.module, name_w)
            rsetattr(self.module, name_w, None)
            rsetattr(self.module, name_w + '_raw', torch.nn.ParameterList(ws))

    def _drop(self, raw_w):
        """Return a dropped-out tensor that stays connected to ``raw_w`` in autograd."""
        if self.variational:
            # new_ones keeps the mask on the same device and dtype as the weight.
            mask = raw_w.new_ones((raw_w.size(0), 1))
            mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
            return mask.expand_as(raw_w) * raw_w
        return torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)

    def _setweights(self):
        """Rebuild the per-call weight lists from their ``_raw`` parameter lists."""
        for name_w, indices in self.parents.items():
            raw_ws = rgetattr(self.module, name_w + '_raw')
            selected = set(indices)
            ws = [
                self._drop(raw_w) if i in selected else raw_w
                for i, raw_w in enumerate(raw_ws)
            ]

            owner_path, _, leaf = name_w.rpartition('.')
            owner = rgetattr(self.module, owner_path) if owner_path else self.module
            # The name is still registered as a child module (None after setup,
            # or a ParameterList left by an older version of this class);
            # remove that entry so a plain list can be assigned under the name.
            registered = getattr(owner, '_modules', None)
            if registered is not None:
                registered.pop(leaf, None)
            setattr(owner, leaf, ws)

    def forward(self, *args, **kwargs):
        """Refresh the dropped-out weights, then run the wrapped module's forward."""
        self._setweights()
        return self.module.forward(*args, **kwargs)