-
Notifications
You must be signed in to change notification settings - Fork 11
/
viModel.py
254 lines (183 loc) · 8.88 KB
/
viModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon May 17 13:05:55 2021
@author: laurent
"""
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.distributions.normal import Normal
class VIModule(nn.Module):
    """
    A mixin class to attach loss functions to a layer. This is useful when doing
    variational inference with deep learning.

    Loss terms are callables registered with ``addLoss``; each is called with the
    module as its only argument. ``evalAllLosses`` sums them over the whole
    module tree. VIModules nested inside plain containers (``nn.Sequential``,
    ``nn.ModuleList``, any non-VI ``nn.Module``) are included.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Callables ``f(module) -> loss`` registered through addLoss.
        self._internalLosses = []
        # Multiplier applied to this module's own losses and to the totals of its
        # VIModule descendants (descendants also apply their own factor).
        self.lossScaleFactor = 1

    def addLoss(self, func):
        """Register ``func``; evalLosses calls it as ``func(self)``."""
        self._internalLosses.append(func)

    def evalLosses(self):
        """Return the sum of this module's own registered losses (0 if there are none)."""
        t_loss = 0
        for l in self._internalLosses:
            t_loss = t_loss + l(self)
        return t_loss

    @staticmethod
    def _nearestVIModules(module):
        """Yield the closest VIModule descendants of ``module``, looking through plain containers."""
        for child in module.children():
            if isinstance(child, VIModule):
                # Stop here: this child's evalAllLosses already covers its own subtree.
                yield child
            else:
                # Plain containers (e.g. nn.Sequential) have no losses of their own
                # but may hold VIModules whose losses must not be dropped.
                yield from VIModule._nearestVIModules(child)

    def evalAllLosses(self):
        """
        Return the scaled sum of the losses of this module and of all its VIModule descendants.

        This module's own losses and each descendant's ``evalAllLosses()`` result are
        multiplied by ``self.lossScaleFactor``. Because descendants apply their own
        factor as well, factors compound down the tree. When every child is a
        VIModule, the summation order is the same as iterating ``self.children()``.
        """
        t_loss = self.evalLosses() * self.lossScaleFactor
        for m in self._nearestVIModules(self):
            t_loss = t_loss + m.evalAllLosses() * self.lossScaleFactor
        return t_loss
class MeanFieldGaussianFeedForward(VIModule):
    """
    A feed forward layer with a Gaussian prior distribution and a Gaussian variational posterior.

    Every weight (and bias) entry has an independent Gaussian posterior with a
    learned mean and a learned log standard deviation. Each forward pass draws a
    fresh sample ``w = mean + exp(log_sigma) * eps`` with ``eps ~ N(0, 1)``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: if True, also learn a posterior over the bias.
        groups: the weight has ``int(in_features / groups)`` columns.
            NOTE(review): F.linear needs that to equal the input's last
            dimension, so presumably only groups=1 is meaningful here.
        weightPriorMean, weightPriorSigma: mean and std of the Gaussian prior on weights.
        biasPriorMean, biasPriorSigma: mean and std of the Gaussian prior on the bias.
        initMeanZero: initialize weight means at 0 instead of uniformly in [-0.5, 0.5).
        initBiasMeanZero: same, for the bias means.
        initPriorSigmaScale: initial posterior std, as a fraction of the prior std.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 groups=1,
                 weightPriorMean=0,
                 weightPriorSigma=1.,
                 biasPriorMean=0,
                 biasPriorSigma=1.,
                 initMeanZero=False,
                 initBiasMeanZero=False,
                 initPriorSigmaScale=0.01):
        super(MeanFieldGaussianFeedForward, self).__init__()
        # Last drawn parameters and the standard-normal noise used to draw them;
        # the loss terms read these, so they reflect the most recent forward pass.
        self.samples = {'weights': None, 'bias': None, 'wNoiseState': None, 'bNoiseState': None}

        self.in_features = in_features
        self.out_features = out_features
        self.has_bias = bias

        wShape = (out_features, int(in_features / groups))
        self.weights_mean = Parameter((0. if initMeanZero else 1.) * (torch.rand(*wShape) - 0.5))
        self.lweights_sigma = Parameter(torch.log(initPriorSigmaScale * weightPriorSigma * torch.ones(*wShape)))
        self.noiseSourceWeights = Normal(torch.zeros(*wShape), torch.ones(*wShape))

        # Negative log prior of the sampled weights (up to an additive constant).
        # Previously weightPriorMean was ignored, i.e. the prior was always centered at 0.
        self.addLoss(lambda s: 0.5 * (s.getSampledWeights() - weightPriorMean).pow(2).sum() / weightPriorSigma**2)
        # Log density of the variational posterior at the sampled weights.
        # NOTE(review): the normalization constant uses out_features rather than the
        # number of weight entries; it carries no gradient, so training is unaffected,
        # but the reported loss is offset by a constant.
        self.addLoss(lambda s: -s.out_features / 2 * np.log(2 * np.pi)
                     - 0.5 * s.samples['wNoiseState'].pow(2).sum()
                     - s.lweights_sigma.sum())

        if self.has_bias:
            self.bias_mean = Parameter((0. if initBiasMeanZero else 1.) * (torch.rand(out_features) - 0.5))
            self.lbias_sigma = Parameter(torch.log(initPriorSigmaScale * biasPriorSigma * torch.ones(out_features)))
            self.noiseSourceBias = Normal(torch.zeros(out_features), torch.ones(out_features))

            # Negative log prior of the sampled bias; honors biasPriorMean.
            self.addLoss(lambda s: 0.5 * (s.getSampledBias() - biasPriorMean).pow(2).sum() / biasPriorSigma**2)
            # Log density of the variational posterior at the sampled bias (same constant caveat).
            self.addLoss(lambda s: -s.out_features / 2 * np.log(2 * np.pi)
                         - 0.5 * s.samples['bNoiseState'].pow(2).sum()
                         - s.lbias_sigma.sum())

    def sampleTransform(self, stochastic=True):
        """
        Draw new weights (and bias) and store them in ``self.samples``.

        Noise is always drawn (on CPU, then moved to the parameters' device). When
        ``stochastic`` is False the stored weights are the posterior means, but the
        noise state used by the posterior loss terms is still refreshed.
        """
        self.samples['wNoiseState'] = self.noiseSourceWeights.sample().to(device=self.weights_mean.device)
        self.samples['weights'] = self.weights_mean + (torch.exp(self.lweights_sigma) * self.samples['wNoiseState'] if stochastic else 0)
        if self.has_bias:
            self.samples['bNoiseState'] = self.noiseSourceBias.sample().to(device=self.bias_mean.device)
            self.samples['bias'] = self.bias_mean + (torch.exp(self.lbias_sigma) * self.samples['bNoiseState'] if stochastic else 0)

    def getSampledWeights(self):
        """Return the weights drawn by the last sampleTransform (None before the first call)."""
        return self.samples['weights']

    def getSampledBias(self):
        """Return the bias drawn by the last sampleTransform (None if never drawn)."""
        return self.samples['bias']

    def forward(self, x, stochastic=True):
        """Sample parameters, then apply ``nn.functional.linear`` to ``x``."""
        self.sampleTransform(stochastic=stochastic)
        return nn.functional.linear(x, self.samples['weights'], bias=self.samples['bias'] if self.has_bias else None)
class MeanFieldGaussian2DConvolution(VIModule):
    """
    A Bayesian module that fits a Gaussian variational posterior on the kernel (and
    bias) of a 2D convolution, with a zero-mean Gaussian prior.

    Each forward pass draws ``w = mean + exp(log_sigma) * eps`` with ``eps ~ N(0, 1)``
    and applies ``nn.functional.conv2d``.

    Args:
        in_channels, out_channels, stride, padding, dilation, groups, bias: passed to
            ``nn.functional.conv2d`` with the same meaning as in ``nn.Conv2d``.
        kernel_size: an int, or a pair (tuple or list) ``(kH, kW)``.
        padding_mode: 'zeros' (default) lets conv2d zero-pad natively; any other value
            is forwarded as ``mode`` to ``nn.functional.pad`` (e.g. 'reflect',
            'replicate', 'circular').
        wPriorSigma, bPriorSigma: std of the prior on the kernel / bias.
        initMeanZero, initBiasMeanZero: initialize posterior means at 0 instead of
            uniformly in [-0.5, 0.5).
        initPriorSigmaScale: initial posterior std, as a fraction of the prior std.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 padding_mode='zeros',
                 wPriorSigma=1.,
                 bPriorSigma=1.,
                 initMeanZero=False,
                 initBiasMeanZero=False,
                 initPriorSigmaScale=0.01):
        super(MeanFieldGaussian2DConvolution, self).__init__()
        # Last drawn parameters and the standard-normal noise used to draw them.
        self.samples = {'weights': None, 'bias': None, 'wNoiseState': None, 'bNoiseState': None}

        self.in_channels = in_channels
        self.out_channels = out_channels
        if isinstance(kernel_size, (tuple, list)):
            self.kernel_size = tuple(kernel_size)
        else:
            self.kernel_size = (kernel_size, kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.has_bias = bias
        self.padding_mode = padding_mode

        wShape = (out_channels, int(in_channels / groups), self.kernel_size[0], self.kernel_size[1])
        self.weights_mean = Parameter((0. if initMeanZero else 1.) * (torch.rand(*wShape) - 0.5))
        self.lweights_sigma = Parameter(torch.log(initPriorSigmaScale * wPriorSigma * torch.ones(*wShape)))
        self.noiseSourceWeights = Normal(torch.zeros(*wShape), torch.ones(*wShape))

        # Negative log of the zero-mean prior on the sampled kernel (up to a constant).
        self.addLoss(lambda s: 0.5 * s.getSampledWeights().pow(2).sum() / wPriorSigma**2)
        # Log density of the variational posterior at the sampled kernel.
        # NOTE(review): the constant uses out_channels, not the number of kernel
        # entries; it has no gradient, so only the reported loss value is offset.
        self.addLoss(lambda s: -s.out_channels / 2 * np.log(2 * np.pi)
                     - 0.5 * s.samples['wNoiseState'].pow(2).sum()
                     - s.lweights_sigma.sum())

        if self.has_bias:
            self.bias_mean = Parameter((0. if initBiasMeanZero else 1.) * (torch.rand(out_channels) - 0.5))
            self.lbias_sigma = Parameter(torch.log(initPriorSigmaScale * bPriorSigma * torch.ones(out_channels)))
            self.noiseSourceBias = Normal(torch.zeros(out_channels), torch.ones(out_channels))

            self.addLoss(lambda s: 0.5 * s.getSampledBias().pow(2).sum() / bPriorSigma**2)
            self.addLoss(lambda s: -s.out_channels / 2 * np.log(2 * np.pi)
                         - 0.5 * s.samples['bNoiseState'].pow(2).sum()
                         - s.lbias_sigma.sum())

    def sampleTransform(self, stochastic=True):
        """
        Draw a new kernel (and bias) and store it in ``self.samples``.

        Noise is always drawn (on CPU, then moved to the parameters' device). When
        ``stochastic`` is False the stored kernel is the posterior mean, but the
        noise state read by the posterior loss terms is still refreshed.
        """
        self.samples['wNoiseState'] = self.noiseSourceWeights.sample().to(device=self.weights_mean.device)
        self.samples['weights'] = self.weights_mean + (torch.exp(self.lweights_sigma) * self.samples['wNoiseState'] if stochastic else 0)
        if self.has_bias:
            self.samples['bNoiseState'] = self.noiseSourceBias.sample().to(device=self.bias_mean.device)
            self.samples['bias'] = self.bias_mean + (torch.exp(self.lbias_sigma) * self.samples['bNoiseState'] if stochastic else 0)

    def getSampledWeights(self):
        """Return the kernel drawn by the last sampleTransform (None before the first call)."""
        return self.samples['weights']

    def getSampledBias(self):
        """Return the bias drawn by the last sampleTransform (None if never drawn)."""
        return self.samples['bias']

    def forward(self, x, stochastic=True):
        """Sample parameters, pad ``x`` according to ``padding_mode``, and convolve."""
        self.sampleTransform(stochastic=stochastic)
        if self.padding_mode == 'zeros':
            # nn.functional.pad has no 'zeros' mode (it would raise); like nn.Conv2d,
            # let conv2d zero-pad natively instead.
            mx, convPadding = x, self.padding
        elif self.padding != 0 and self.padding != (0, 0):
            # F.pad orders the pad widths from the last dim: (left, right, top, bottom).
            if isinstance(self.padding, int):
                padkernel = (self.padding, self.padding, self.padding, self.padding)
            else:
                padkernel = (self.padding[1], self.padding[1], self.padding[0], self.padding[0])
            mx = nn.functional.pad(x, padkernel, mode=self.padding_mode, value=0)
            convPadding = 0
        else:
            mx, convPadding = x, 0
        return nn.functional.conv2d(mx,
                                    self.samples['weights'],
                                    bias=self.samples['bias'] if self.has_bias else None,
                                    stride=self.stride,
                                    padding=convPadding,
                                    dilation=self.dilation,
                                    groups=self.groups)
class BayesianMnistNet(VIModule):
    """
    Small Bayesian convolutional classifier.

    Layers: two mean-field Gaussian 5x5 convolutions (1 -> 16 -> 32 channels), then
    two mean-field Gaussian linear layers (512 -> 128 -> 10), with optional MC-dropout
    after the second convolution and after the first linear layer.

    Args:
        convWPriorSigma, convBPriorSigma: prior std of convolution kernels / biases.
        linearWPriorSigma, linearBPriorSigma: prior std of linear weights / biases.
        p_mc_dropout: dropout probability, or None to disable dropout entirely.
    """

    def __init__(self,
                 convWPriorSigma=1.,
                 convBPriorSigma=5.,
                 linearWPriorSigma=1.,
                 linearBPriorSigma=5.,
                 p_mc_dropout=0.5):
        super().__init__()
        self.p_mc_dropout = p_mc_dropout

        # All variational layers start with a posterior std of 1e-7 x the prior std.
        convOptions = {'wPriorSigma': convWPriorSigma,
                       'bPriorSigma': convBPriorSigma,
                       'kernel_size': 5,
                       'initPriorSigmaScale': 1e-7}
        linearOptions = {'weightPriorSigma': linearWPriorSigma,
                         'biasPriorSigma': linearBPriorSigma,
                         'initPriorSigmaScale': 1e-7}

        # Construction order (conv1, conv2, linear1, linear2) fixes both the random
        # initialization sequence and the sub-module registration order.
        self.conv1 = MeanFieldGaussian2DConvolution(1, 16, **convOptions)
        self.conv2 = MeanFieldGaussian2DConvolution(16, 32, **convOptions)
        self.linear1 = MeanFieldGaussianFeedForward(512, 128, **linearOptions)
        self.linear2 = MeanFieldGaussianFeedForward(128, 10, **linearOptions)

    def forward(self, x, stochastic=True):
        """
        Return log-probabilities over the 10 outputs along the last dimension.

        With ``stochastic`` True, parameters are sampled and MC-dropout is active;
        with False, the layers use their posterior means and dropout is disabled.
        The 512-wide flatten is consistent with 1x28x28 inputs — presumably MNIST
        images; TODO confirm with callers.
        """
        F = nn.functional
        dropoutEnabled = self.p_mc_dropout is not None

        out = F.relu(F.max_pool2d(self.conv1(x, stochastic=stochastic), 2))

        out = self.conv2(out, stochastic=stochastic)
        if dropoutEnabled:
            out = F.dropout2d(out, p=self.p_mc_dropout, training=stochastic)  # MC-Dropout
        out = F.relu(F.max_pool2d(out, 2))

        out = F.relu(self.linear1(out.view(-1, 512), stochastic=stochastic))
        if dropoutEnabled:
            out = F.dropout(out, p=self.p_mc_dropout, training=stochastic)  # MC-Dropout

        return F.log_softmax(self.linear2(out, stochastic=stochastic), dim=-1)