-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathnninit.lua
364 lines (296 loc) · 11.5 KB
/
nninit.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
local nn = require 'nn'
local hasSignal, signal = pcall(require, 'signal')
-- Helper functions

--- Returns the fan-in and fan-out of a supported layer.
-- Linear-style layers and temporal convolutions read both values off their
-- 2D weight (second dimension = fan-in, first dimension = fan-out); spatial
-- and volumetric convolutions scale the plane counts by the kernel volume.
-- @param module an nn layer instance (its kind is taken from torch.type)
-- @return fanIn, fanOut
-- @raise "Unsupported module" for any other layer type
local function calcFan(module)
  local kind = torch.type(module)
  local isLinear = kind == 'nn.Linear'
    or kind == 'nn.LinearNoBias'
    or kind == 'nn.LookupTable'

  if isLinear or kind:find('TemporalConvolution') then
    local weight = module.weight
    return weight:size(2), weight:size(1)
  end

  if kind:find('SpatialConvolution') or kind:find('SpatialFullConvolution') then
    local area = module.kW * module.kH
    return module.nInputPlane * area, module.nOutputPlane * area
  end

  if kind:find('VolumetricConvolution') or kind:find('VolumetricFullConvolution') then
    local volume = module.kT * module.kW * module.kH
    return module.nInputPlane * volume, module.nOutputPlane * volume
  end

  error("Unsupported module")
end
--- Returns the gain or calculates it from a gain type (with optional args).
-- @param gain one of:
--   * a number, returned unchanged;
--   * a string naming a non-linearity: 'linear', 'sigmoid', 'tanh', 'relu'
--     or 'lrelu';
--   * a table whose first element is such a string and whose named fields
--     carry extra arguments, e.g. {'lrelu', leakiness = 0.3}.
--   nil or an unrecognised name yields 1.
-- @return the gain as a number
local function calcGain(gain)
  -- Return gain if a number already
  if type(gain) == 'number' then
    return gain
  end
  -- Extract gain string if table
  local args
  if type(gain) == 'table' then
    args = gain
    gain = gain[1]
  end
  -- Process gain strings with optional args
  if gain == 'linear' or gain == 'sigmoid' then
    return 1
  elseif gain == 'tanh' then
    return 5 / 3
  elseif gain == 'relu' then
    return math.sqrt(2)
  elseif gain == 'lrelu' then
    -- Without an explicit leakiness (bare 'lrelu' string, or a table lacking
    -- the field) fall back to nn.LeakyReLU's default slope of 1/100 instead
    -- of indexing a nil argument table. x * x avoids math.pow, which is
    -- missing from Lua 5.3+.
    local leakiness = (args and args.leakiness) or 0.01
    return math.sqrt(2 / (1 + leakiness * leakiness))
  end
  -- Return 1 by default
  return 1
end
-- init method
-- Add init to nn.Module

--- Initialises one tensor of this module in place and returns the module.
-- @param accessor selects the tensor to initialise:
--   * a field name such as 'weight' (resolves to self[accessor]);
--   * a two-element table {field, key} (resolves to self[field][key]);
--   * a function called as accessor(self) that returns the tensor.
-- @param initialiser function called as initialiser(self, tensor, ...),
--   e.g. one of the nninit.* functions
-- @param ... extra arguments forwarded to the initialiser
-- @return self, so calls can be chained
-- @raise "Unsupported accessor" for any other accessor type, or an error
--   when the accessor does not resolve to a value (e.g. a layer without bias)
nn.Module.init = function(self, accessor, initialiser, ...)
  -- Extract tensor to initialise
  local tensor
  local accessorType = type(accessor)
  if accessorType == 'string' then
    tensor = self[accessor]
  elseif accessorType == 'table' then
    local container = self[accessor[1]]
    if container ~= nil then
      tensor = container[accessor[2]]
    end
  elseif accessorType == 'function' then
    tensor = accessor(self)
  else
    error("Unsupported accessor")
  end
  -- Report a missing target at the caller, rather than letting the
  -- initialiser fail later with an opaque "attempt to index a nil value"
  if tensor == nil then
    error("Accessor did not resolve to a tensor", 2)
  end
  -- Initialise tensor (given module and options)
  initialiser(self, tensor, ...)
  -- Return module for chaining
  return self
end
-- nninit
-- Table collecting the initialisers defined below; it is the value this file
-- returns. Each initialiser is called as f(module, tensor, ...), modifies
-- `tensor` in place and returns `module`.
local nninit = {}
--- Initialiser: overwrites `tensor` with the values held in `init`.
-- @param module owner of `tensor`; returned unchanged for chaining
-- @param tensor destination, modified in place
-- @param init source tensor whose values are copied
-- @return module
function nninit.copy(module, tensor, init)
  tensor:copy(init)
  return module
end
--- Initialiser: sets every element of `tensor` to `val`.
-- @param module owner of `tensor`; returned unchanged for chaining
-- @param tensor tensor filled in place
-- @param val constant written to every element
-- @return module
function nninit.constant(module, tensor, val)
  tensor:fill(val)
  return module
end
--- Initialiser: shifts every element of `tensor` by `val` (in place).
-- @param module owner of `tensor`; returned unchanged for chaining
-- @param tensor tensor updated in place
-- @param val constant added to each element
-- @return module
function nninit.addConstant(module, tensor, val)
  tensor:add(val)
  return module
end
--- Initialiser: scales every element of `tensor` by `val` (in place).
-- @param module owner of `tensor`; returned unchanged for chaining
-- @param tensor tensor updated in place
-- @param val scale factor
-- @return module
function nninit.mulConstant(module, tensor, val)
  tensor:mul(val)
  return module
end
--- Initialiser: overwrites `tensor` with samples from N(mean, stdv).
-- @param module owner of `tensor`; returned unchanged for chaining
-- @param tensor tensor filled in place
-- @param mean mean of the normal distribution
-- @param stdv standard deviation of the normal distribution
-- @return module
function nninit.normal(module, tensor, mean, stdv)
  tensor:normal(mean, stdv)
  return module
end
--- Initialiser: adds i.i.d. N(mean, stdv) noise to `tensor` in place.
-- The noise buffer is allocated with tensor.new() so it has the same tensor
-- type as `tensor` (e.g. FloatTensor or CudaTensor); a buffer built with
-- torch.Tensor would use the default type and make :add fail on a mismatch.
-- @param module owner of `tensor`; returned unchanged for chaining
-- @param tensor tensor updated in place
-- @param mean mean of the added noise
-- @param stdv standard deviation of the added noise
-- @return module
nninit.addNormal = function(module, tensor, mean, stdv)
  local noise = tensor.new():resizeAs(tensor):normal(mean, stdv)
  tensor:add(noise)
  return module
end
--- Initialiser: overwrites `tensor` with samples from U(a, b).
-- @param module owner of `tensor`; returned unchanged for chaining
-- @param tensor tensor filled in place
-- @param a lower bound of the uniform distribution
-- @param b upper bound of the uniform distribution
-- @return module
function nninit.uniform(module, tensor, a, b)
  tensor:uniform(a, b)
  return module
end
--- Initialiser: adds i.i.d. U(a, b) noise to `tensor` in place.
-- The noise buffer is allocated with tensor.new() so it has the same tensor
-- type as `tensor` (e.g. FloatTensor or CudaTensor); a buffer built with
-- torch.Tensor would use the default type and make :add fail on a mismatch.
-- @param module owner of `tensor`; returned unchanged for chaining
-- @param tensor tensor updated in place
-- @param a lower bound of the added noise
-- @param b upper bound of the added noise
-- @return module
nninit.addUniform = function(module, tensor, a, b)
  local noise = tensor.new():resizeAs(tensor):uniform(a, b)
  tensor:add(noise)
  return module
end
-- Fills weights with the identity matrix (for linear layers)
-- Fills filters with the Dirac delta function (for convolutional layers)
-- TODO: Generalise for arbitrary tensors?
--
-- For convolutions every kernel is zero except for a single centre tap of
-- 1/nInputPlane (1/inputFrameSize for temporal convolutions).
-- @param module the layer whose `weight` is initialised
-- @param tensor must be module.weight
-- @return module
-- @raise if `tensor` is not module.weight or the layer type is unsupported
nninit.eye = function(module, tensor)
  if module.weight ~= tensor then
    error("nninit.eye only supports 'weight' tensor")
  end
  local typename = torch.type(module)
  if typename == 'nn.Linear' or
    typename == 'nn.LinearNoBias' or
    typename == 'nn.LookupTable' then
    local I = torch.eye(tensor:size(1), tensor:size(2))
    tensor:copy(I)
  elseif typename:find('TemporalConvolution') then
    -- The weight is (outputFrameSize, kW * inputFrameSize) and the kW frames
    -- of an input window are laid out one after another, so column
    -- (k - 1) * inputFrameSize + j weights feature j of window frame k.
    -- The centre frame therefore owns one contiguous run of columns.
    local frameSize = module.inputFrameSize
    local centre = math.ceil(module.kW / 2)
    tensor:zero()
    tensor:narrow(2, (centre - 1) * frameSize + 1, frameSize):fill(1 / frameSize)
  elseif typename:find('SpatialConvolution') or typename:find('SpatialFullConvolution') then
    -- Kernels are stored as (planes, planes, kH, kW); the two plane dimensions
    -- are collapsed since every kernel receives the same centre tap
    local kernels = tensor:zero():view(module.nInputPlane * module.nOutputPlane, module.kH, module.kW)
    kernels[{{}, math.ceil(module.kH / 2), math.ceil(module.kW / 2)}]:fill(1 / module.nInputPlane)
  elseif typename:find('VolumetricConvolution') or typename:find('VolumetricFullConvolution') then
    -- Kernels are stored as (planes, planes, kT, kH, kW)
    local kernels = tensor:zero():view(module.nInputPlane * module.nOutputPlane, module.kT, module.kH, module.kW)
    kernels[{{}, math.ceil(module.kT / 2), math.ceil(module.kH / 2), math.ceil(module.kW / 2)}]:fill(1 / module.nInputPlane)
  else
    error("Unsupported module for 'eye'")
  end
  return module
end
--[[
-- Glorot, X., & Bengio, Y. (2010)
-- Understanding the difficulty of training deep feedforward neural networks
-- In International Conference on Artificial Intelligence and Statistics
--
-- Also known as Glorot initialisation
--
-- Samples with standard deviation gain * sqrt(2 / (fanIn + fanOut)).
-- options.gain  gain specification passed to calcGain (default 1)
-- options.dist  'uniform' (default) or 'normal'; the uniform bound is
--               sqrt(3) * stdv so both choices have the same variance
-- Raises an error for any other options.dist value.
--]]
nninit.xavier = function(module, tensor, options)
  local fanIn, fanOut = calcFan(module)
  options = options or {}
  -- Locals: these previously leaked into the global table
  local gain = calcGain(options.gain)
  local dist = options.dist or 'uniform' -- Uniform by default
  local stdv = gain * math.sqrt(2 / (fanIn + fanOut))
  if dist == 'uniform' then
    local b = stdv * math.sqrt(3)
    tensor:uniform(-b, b)
  elseif dist == 'normal' then
    tensor:normal(0, stdv)
  else
    -- Fail loudly instead of silently leaving the tensor untouched
    error("nninit.xavier only supports 'uniform' and 'normal' distributions")
  end
  return module
end
--[[
-- He, K., Zhang, X., Ren, S., & Sun, J. (2015)
-- Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification
-- arXiv preprint arXiv:1502.01852
--
-- Also known as He initialisation
--
-- Samples with standard deviation gain * sqrt(1 / fanIn).
-- options.gain  gain specification passed to calcGain (default 1)
-- options.dist  'normal' (default) or 'uniform'; the uniform bound is
--               sqrt(3) * stdv so both choices have the same variance
-- Raises an error for any other options.dist value.
--]]
nninit.kaiming = function(module, tensor, options)
  local fanIn = calcFan(module)
  options = options or {}
  -- Locals: these previously leaked into the global table
  local gain = calcGain(options.gain)
  local dist = options.dist or 'normal' -- Normal by default
  local stdv = gain * math.sqrt(1 / fanIn)
  if dist == 'uniform' then
    local b = stdv * math.sqrt(3)
    tensor:uniform(-b, b)
  elseif dist == 'normal' then
    tensor:normal(0, stdv)
  else
    -- Fail loudly instead of silently leaving the tensor untouched
    error("nninit.kaiming only supports 'uniform' and 'normal' distributions")
  end
  return module
end
--[[
-- Saxe, A. M., McClelland, J. L., & Ganguli, S. (2013)
-- Exact solutions to the nonlinear dynamics of learning in deep linear neural networks
-- arXiv preprint arXiv:1312.6120
--
-- Fills a tensor of 2+ dimensions with a (semi-)orthogonal matrix scaled by
-- gain. The tensor is treated as a fanOut x fanIn matrix, where fanOut is its
-- first dimension and fanIn the product of all remaining dimensions.
-- options.gain  gain specification passed to calcGain (default 1)
--]]
nninit.orthogonal = function(module, tensor, options)
  local sizes = tensor:size()
  if #sizes < 2 then
    error("nninit.orthogonal only supports tensors with 2 or more dimensions")
  end
  -- Calculate "fan in" and "fan out" for arbitrary tensors based on module conventions
  local fanIn = sizes[2]
  local fanOut = sizes[1]
  for d = 3, #sizes do
    fanIn = fanIn * sizes[d]
  end
  options = options or {}
  local gain = calcGain(options.gain)
  -- Construct random matrix
  local randMat = torch.Tensor(fanOut, fanIn):normal(0, 1)
  -- Thin SVD: randMat = U * diag(S) * V^T with U (fanOut x k) and
  -- V (fanIn x k), k = min(fanOut, fanIn)
  local U, _, V = torch.svd(randMat, 'S')
  -- Pick out a fanOut x fanIn orthogonal matrix
  local W
  if fanOut > fanIn then
    W = U -- orthonormal columns
  else
    W = V:t() -- orthonormal rows
  end
  -- Multiply by gain
  W:mul(gain)
  -- copy works elementwise in row-major order and only needs matching
  -- element counts, so no resize (which would reinterpret the storage of a
  -- possibly non-contiguous SVD result) is required
  tensor:copy(W)
  return module
end
--[[
-- Martens, J. (2010)
-- Deep learning via Hessian-free optimization
-- In Proceedings of the 27th International Conference on Machine Learning (ICML-10)
--
-- Zeroes floor(sparsity * nElement) randomly chosen entries of the tensor and
-- leaves the remaining entries untouched.
-- sparsity  fraction of entries to zero, a number in [0, 1]
--]]
nninit.sparse = function(module, tensor, sparsity)
  if type(sparsity) ~= 'number' or sparsity < 0 or sparsity > 1 then
    error("nninit.sparse requires a sparsity between 0 and 1")
  end
  local nElements = tensor:nElement()
  local nSparseElements = math.floor(sparsity * nElements)
  -- Nothing to zero; torch7's narrow also rejects a zero-length slice
  if nSparseElements == 0 then
    return module
  end
  local sparseIndices = torch.randperm(nElements):long():narrow(1, 1, nSparseElements)
  -- Zero out selected indices; view needs contiguous memory, so go through a
  -- contiguous clone when the tensor is strided
  if tensor:isContiguous() then
    tensor:view(nElements):indexFill(1, sparseIndices, 0)
  else
    local buffer = tensor:clone()
    buffer:view(nElements):indexFill(1, sparseIndices, 0)
    tensor:copy(buffer)
  end
  return module
end
--[[
-- Aghajanyan, A. (2017)
-- Convolution Aware Initialization
-- arXiv preprint arXiv:1702.06295
--
-- Initialises a 4D 2d-convolution weight (filterCount, filterStacks, rows,
-- cols) with square filters: for every filter, orthogonal bases are drawn in
-- the frequency domain and mapped back with an inverse real FFT plus a little
-- Gaussian noise; finally the whole tensor is rescaled to variance
-- gain^2 / fanIn.
-- options.gain  gain specification passed to calcGain (default 1)
-- options.std   std of the noise added to each filter (default 0.05)
-- Requires the torch-signal library.
--]]
nninit.convolutionAware = function(module, tensor, options)
  -- The author of the paper provided a reference implementation for Keras: https://github.com/farizrahman4u/keras-contrib/pull/60
  -- Make sure that the signal library is available, which provides the Fourier transform
  if hasSignal == false then
    error("nninit.convolutionAware requires the signal library, please make sure to install it: https://github.com/soumith/torch-signal")
  end
  -- Check the size of the convolution tensor, right now, only 2d convolution tensors are supported
  local sizes = tensor:size()
  if #sizes ~= 4 then
    error("nninit.convolutionAware only supports 2d convolutions, feel free to issue a pull request to extend this implementation")
  end
  -- Store the sizes of the convolution tensor to make the implementation easier to read
  local filterCount = sizes[1]
  local filterStacks = sizes[2]
  local filterRows = sizes[3]
  local filterCols = sizes[4]
  -- Due to the irfft2 interface of the signal library, we currently have to restrict the filter size
  if filterRows ~= filterCols then
    error("nninit.convolutionAware requires the filters to have the same number of rows and columns, feel free to issue a pull request to extend this implementation")
  end
  -- "fanIn" of a 2d convolution tensor based on module conventions (used for the final rescaling)
  local fanIn = filterStacks * filterRows * filterCols
  -- Setup options where "std" specifies the noise to break symmetry in the inverse Fourier transform
  options = options or {}
  local gain = calcGain(options.gain)
  local std = options.std or 0.05
  -- Specify the variables for the frequency domain tensor
  local fourierTensor = signal.rfft2(torch.zeros(filterRows, filterCols))
  local fourierRows = fourierTensor:size(1)
  local fourierCols = fourierTensor:size(2)
  local fourierSize = fourierRows * fourierCols
  -- Specify the variables for the orthogonal tensor buffer
  local orthogonalIndex = fourierSize
  local orthogonalTensor = nil
  -- For each filter, create a suitable basis tensor and perform an inverse Fourier transform to obtain the filter coefficients
  for filterIndex = 1, filterCount do
    local basisTensor = torch.zeros(filterStacks, fourierSize)
    -- Create a suitable basis tensor using the orthogonal tensor buffer, making sure to refill the buffer should it be empty
    for basisIndex = 1, filterStacks do
      if orthogonalIndex == fourierSize then
        local randomTensor = torch.zeros(fourierSize, fourierSize):normal(0.0, 1.0)
        local symmetricTensor = randomTensor + randomTensor:t() - torch.diag(randomTensor:diag())
        orthogonalIndex = 0
        -- Only the left singular vectors (first return value) are used
        orthogonalTensor = torch.svd(symmetricTensor)
      end
      -- Copy a column from the orthogonal tensor buffer into the basis tensor
      orthogonalIndex = orthogonalIndex + 1
      basisTensor[{ { basisIndex }, {} }] = orthogonalTensor[{ {}, { orthogonalIndex } }]
    end
    basisTensor = basisTensor:view(filterStacks, fourierRows, fourierCols)
    -- Perform the inverse Fourier transform from the basis tensor to obtain the filter coefficients, making sure to break the symmetry
    for basisIndex = 1, filterStacks do
      fourierTensor[{ {}, {}, { 1 } }] = basisTensor[{ { basisIndex }, {} }]
      fourierTensor[{ {}, {}, { 2 } }]:zero()
      -- Unlike the Numpy implementation, the inverse Fourier transform in the signal library does sadly only support a single size argument
      tensor[{ { filterIndex }, { basisIndex }, {}, {} }] = signal.irfft2(fourierTensor, filterRows) + torch.zeros(filterRows, filterCols):normal(0.0, std)
    end
    -- Clear the orthogonal tensor buffer, we do not want to reuse it for the next filter
    orthogonalIndex = fourierSize
    orthogonalTensor = nil
  end
  -- Scale the filter variance to match the variance scheme defined by He-normal initialization
  tensor:mul(gain * math.sqrt((1.0 / fanIn) * (1.0 / tensor:var())))
  return module
end
-- Export the table of initialisers as the value of this module
return nninit