-
Notifications
You must be signed in to change notification settings - Fork 86
/
initialize_weights.m
75 lines (71 loc) · 2.63 KB
/
initialize_weights.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
function [ stack, W_t ] = initialize_weights( eI )
%INITIALIZE_WEIGHTS Random weight structures for a network architecture
%   [stack, W_t] = INITIALIZE_WEIGHTS(eI) builds randomly initialized
%   weights for the network described by the struct eI.
%
%   Fields of eI read here:
%     layerSizes    - vector of layer widths; layerSizes(end) is the last
%                     layer (the one that receives W_ss when shortCircuit
%                     is set).
%     inputDim      - width of the network input.
%     temporalLayer - index of the recurrent layer; 0 means no recurrent
%                     weights are created.
%     temporalInit  - 'zero' | 'rand' | 'eye' (case-insensitive). Only read
%                     when temporalLayer is nonzero, and required then.
%     tieWeights    - optional, default 0. When true, decoder layers are
%                     overwritten with the transpose of the mirrored
%                     encoder layer.
%     shortCircuit  - optional, default 0. When true, stack{end}.W_ss is
%                     created.
%
%   Outputs:
%     stack - 1 x numel(layerSizes) cell array of structs; stack{l}.W is
%             layerSizes(l) x (previous layer width, or inputDim for l = 1)
%             and stack{l}.b is a zero column of length layerSizes(l).
%     W_t   - square recurrent matrix of side layerSizes(temporalLayer),
%             or [] when temporalLayer is 0.
%
%   Errors when temporalInit is not one of the recognized strings, or when
%   tied layers have incompatible sizes.
%
%   Uses Xavier's weight initialization tricks for better backprop.
%   See: X. Glorot, Y. Bengio. Understanding the difficulty of training
%   deep feedforward neural networks. AISTATS 2010.

%% initialize hidden layers
numLayers = numel(eI.layerSizes);
stack = cell(1, numLayers);
for l = 1 : numLayers
    if l > 1
        prevSize = eI.layerSizes(l-1);
    else
        prevSize = eI.inputDim;
    end
    curSize = eI.layerSizes(l);
    % Xavier's scaling factor: weights are uniform on (-s, s)
    s = sqrt(6) / sqrt(prevSize + curSize);
    % Ilya suggests smaller scaling for the recurrent layer (the extra
    % curSize term presumably accounts for the recurrent inputs)
    if l == eI.temporalLayer
        s = sqrt(6) / sqrt(prevSize + 2*curSize);
    end
    stack{l}.W = rand(curSize, prevSize)*2*s - s;
    stack{l}.b = zeros(curSize, 1);
end

%% weight tying
% default weight tying to false
if ~isfield(eI, 'tieWeights')
    eI.tieWeights = 0;
end
% Overwrite decoder layers for tied weights. Each decoder layer takes the
% transpose of the encoder layer mirrored about the middle of the stack; the
% last layer is never tied. E.g. with 4 layers: stack{3}.W = stack{2}.W';
% with 6 layers: stack{4}.W = stack{3}.W', stack{5}.W = stack{2}.W'.
% NOTE(review): assumes numel(layerSizes) is even. With an odd count the
% range below starts at a non-integer, so indexing fails (5+ layers) or the
% range is empty and no tying happens (3 layers).
if eI.tieWeights
    decList = (numLayers/2)+1 : numLayers-1;
    for l = 1:numel(decList)
        lDec = decList(l);
        lEnc = decList(1) - l;
        assert( isequal(size(stack{lEnc}.W'), size(stack{lDec}.W)), ...
            'Layersizes dont match for tied weights');
        stack{lDec}.W = stack{lEnc}.W';
    end
end

%% initialize temporal weights if they should exist
W_t = [];
if eI.temporalLayer
    % assumes eI.temporalInit is set whenever temporalLayer is nonzero
    tSize = eI.layerSizes(eI.temporalLayer);
    if strcmpi(eI.temporalInit, 'zero')
        W_t = zeros(tSize);
    elseif strcmpi(eI.temporalInit, 'rand')
        % Ilya's modification to Xavier's update rule
        s = sqrt(6) / sqrt(3*tSize);
        W_t = rand(tSize)*2*s - s;
    elseif strcmpi(eI.temporalInit, 'eye')
        W_t = eye(tSize);
    else
        error('unrecognized temporal initialization: %s', eI.temporalInit);
    end
end

%% init short circuit connections
% default short circuits to false
if ~isfield(eI, 'shortCircuit')
    eI.shortCircuit = 0;
end
if eI.shortCircuit
    % Commented-out identity-style initialization, kept for reference:
    %padSize = (eI.winSize-1) / 2;
    %stack{end}.W_ss = [zeros(eI.featDim, padSize*eI.featDim), eye(eI.featDim),...
    %    zeros(eI.featDim, padSize*eI.featDim)];
    % use random init since input might contain noise estimate
    % NOTE(review): W_ss is allocated inputDim x layerSizes(end), the opposite
    % orientation from every stack{l}.W (layer width x previous width).
    % Confirm which orientation the forward pass expects.
    s = sqrt(6) / sqrt(eI.inputDim + eI.layerSizes(end));
    stack{end}.W_ss = rand(eI.inputDim, eI.layerSizes(end))*2*s - s;
end