-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathwhitening.lua
235 lines (217 loc) · 7.69 KB
/
whitening.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
-- ZCA-Whitening
--
-- Input:
-- - data: tensor M x N1 [x N2 x ...] (required); at least 2D. The first
--   dimension indexes the M samples.
-- - means: 1D tensor of size N = N1 x N2 x ... (flattened).
-- - P: ZCA-transform matrix of size N x N.
-- - invP: inverse of P (N x N), as consumed by unsup.zca_colour.
-- - epsilon: value added to the values returned by unsup.pcacov before their
--   square root is taken and inverted, so near-zero values do not blow up
--   (default 1e-5). Only used when P is computed here.
--
-- Behavior:
-- - if both means and P are provided, the ZCA-transformed data is returned,
--   alongside means and P (unchanged). invP is returned unchanged as well,
--   or computed as torch.inverse(P) when it is not provided.
-- - otherwise, means, P and invP are computed from data and returned,
--   preceded by the transformed data.
--
-- Each flattened sample x (a row vector) is mapped to (x - means) * P, and
-- the result has the same shape as data.
--
-- Input arguments are never changed.
--
function unsup.zca_whiten(data, means, P, invP, epsilon)
   assert(data:dim() >= 2, 'zca_whiten: data must be at least 2D')
   epsilon = epsilon or 1e-5
   local nsamples = data:size(1)
   local n_dimensions = data:nElement() / nsamples
   -- flattened M x N working copy; cloning keeps data itself untouched
   local auxdata = data:clone():view(nsamples, n_dimensions)
   if not means or not P then
      -- fit the transform: per-feature means over the samples ...
      means = torch.mean(auxdata, 1):squeeze()
      -- ... and P = cv * diag(1 / sqrt(ce + epsilon)) * cv^T, where (ce, cv)
      -- come from unsup.pcacov (assumed: covariance eigenvalues/eigenvectors)
      local ce, cv = unsup.pcacov(auxdata)
      ce:add(epsilon):sqrt()
      local invdiag = torch.diag(ce:clone():pow(-1))
      P = torch.mm(torch.mm(cv, invdiag), cv:t())
      -- inverse transform: cv * diag(sqrt(ce + epsilon)) * cv^T
      invP = torch.mm(torch.mm(cv, torch.diag(ce)), cv:t())
   elseif not invP then
      -- means and P were supplied: honour them, and only derive the inverse
      invP = torch.inverse(P)
   end
   -- remove the means from every sample (row)
   local xmeans = means:new():view(1, n_dimensions):expand(nsamples, n_dimensions)
   auxdata:add(-1, xmeans)
   -- transform into ZCA space
   auxdata = torch.mm(auxdata, P)
   -- restore the original (possibly > 2D) shape
   auxdata:resizeAs(data)
   return auxdata, means, P, invP
end
-- ZCA-Colouring: the inverse of unsup.zca_whiten. Maps whitened data back
-- to the original space.
--
-- Input:
-- - data: whitened tensor M x N1 [x N2 x ...]; at least 2D.
-- - means: 1D tensor of size N = N1 x N2 x ... (required).
-- - P: not used here; returned unchanged.
-- - invP: inverse ZCA transform of size N x N (required).
--
-- Each flattened sample y (a row vector) is mapped to y * invP + means, and
-- the result has the same shape as data. Returns the coloured data followed
-- by means, P and invP (unchanged). Input arguments are never changed.
function unsup.zca_colour(data, means, P, invP)
   assert(means, 'zca_colour: means must be provided')
   assert(invP, 'zca_colour: invP must be provided')
   assert(data:dim() >= 2, 'zca_colour: data must be at least 2D')
   local nsamples = data:size(1)
   local n_dimensions = data:nElement() / nsamples
   -- flattened M x N working copy; cloning keeps data itself untouched
   local auxdata = data:clone():view(nsamples, n_dimensions)
   -- transform back out of ZCA space
   auxdata = torch.mm(auxdata, invP)
   -- add back the means to every sample (row)
   local xmeans = means:new():view(1, n_dimensions):expand(nsamples, n_dimensions)
   auxdata:add(xmeans)
   -- restore the original (possibly > 2D) shape
   auxdata:resizeAs(data)
   return auxdata, means, P, invP
end
-- Returns an nn module that applies a ZCA transform to its input using a
-- precomputed (static) transformation matrix.
--
-- Input:
-- - data: tensor M x N1 [x N2 x ...]; fixes the per-sample shape, and is used
--   to fit the transform when parameters are missing.
-- - means, P, invP: optional precomputed parameters (see unsup.zca_whiten).
--   If any of them is missing, all three are recomputed from data.
--
-- The layer maps each flattened sample x to P^T (x - means), the
-- column-vector form of zca_whiten's (x - means) * P. When data has more
-- than 2 dimensions, the input is flattened per sample and reshaped back.
-- The batch size is inferred from the input rather than fixed to data's M.
--
-- Returns the layer, followed by means, P and invP.
function unsup.zca_layer(data, means, P, invP)
   if not means or not P or not invP then
      local _
      _, means, P, invP = unsup.zca_whiten(data)
   end
   local n_dimensions = data:nElement() / data:size(1)
   local linear = nn.Linear(n_dimensions, n_dimensions)
   -- nn.Linear computes W x + b on column vectors, hence the transpose
   linear.weight:copy(P:t())
   -- b = -P^T means, so that W x + b = P^T (x - means)
   linear.bias:fill(0)
   linear.bias:copy(linear:forward(means):mul(-1))
   if data:nDimension() <= 2 then
      return linear, means, P, invP
   end
   -- per-sample shape (N1, N2, ...) without the batch dimension, so the
   -- Reshape modules infer the batch size instead of hard-coding data:size(1)
   local sample_dims = data:nDimension() - 1
   local sample_size = torch.LongStorage(sample_dims)
   for i = 1, sample_dims do
      sample_size[i] = data:size(i + 1)
   end
   local layer = nn.Sequential()
   layer:add(nn.Reshape(n_dimensions))
   layer:add(linear)
   layer:add(nn.Reshape(sample_size))
   return layer, means, P, invP
end
-- Returns an nn module that inverts a ZCA transform of its input using a
-- precomputed (static) transformation matrix.
--
-- Input:
-- - data: tensor M x N1 [x N2 x ...]; fixes the per-sample shape, and is used
--   to fit the transform when parameters are missing.
-- - means, P, invP: optional precomputed parameters (see unsup.zca_whiten).
--   If any of them is missing, all three are recomputed from data.
--
-- The layer maps each flattened sample y to invP^T y + means, the
-- column-vector form of zca_colour's y * invP + means. When data has more
-- than 2 dimensions, the input is flattened per sample and reshaped back.
-- The batch size is inferred from the input rather than fixed to data's M.
--
-- Returns the layer, followed by means, P and invP.
function unsup.inv_zca_layer(data, means, P, invP)
   if not means or not P or not invP then
      local _
      _, means, P, invP = unsup.zca_whiten(data)
   end
   local n_dimensions = data:nElement() / data:size(1)
   local linear = nn.Linear(n_dimensions, n_dimensions)
   -- nn.Linear computes W x + b on column vectors, hence the transpose
   linear.weight:copy(invP:t())
   linear.bias:copy(means)
   if data:nDimension() <= 2 then
      return linear, means, P, invP
   end
   -- per-sample shape (N1, N2, ...) without the batch dimension, so the
   -- Reshape modules infer the batch size instead of hard-coding data:size(1)
   local sample_dims = data:nDimension() - 1
   local sample_size = torch.LongStorage(sample_dims)
   for i = 1, sample_dims do
      sample_size[i] = data:size(i + 1)
   end
   local layer = nn.Sequential()
   layer:add(nn.Reshape(n_dimensions))
   layer:add(linear)
   layer:add(nn.Reshape(sample_size))
   return layer, means, P, invP
end
-- PCA-Whitening
--
-- Input:
-- - data: tensor M x N1 [x N2 x ...] (required); at least 2D. The first
--   dimension indexes the M samples.
-- - means: 1D tensor of size N = N1 x N2 x ... (flattened).
-- - P: PCA-transform matrix of size N x N.
-- - invP: inverse of P (N x N), as consumed by unsup.pca_colour.
-- - epsilon: value added to the values returned by unsup.pcacov before their
--   square root is taken and inverted, so near-zero values do not blow up
--   (default 1e-5). Only used when P is computed here.
--
-- Behavior:
-- - if both means and P are provided, the PCA-transformed data is returned,
--   alongside means and P (unchanged). invP is returned unchanged as well,
--   or computed as torch.inverse(P) when it is not provided.
-- - otherwise, means, P and invP are computed from data and returned,
--   preceded by the transformed data.
--
-- Each flattened sample x (a row vector) is mapped to (x - means) * P, and
-- the result has the same shape as data.
--
-- Input arguments are never changed.
--
function unsup.pca_whiten(data, means, P, invP, epsilon)
   assert(data:dim() >= 2, 'pca_whiten: data must be at least 2D')
   epsilon = epsilon or 1e-5
   local nsamples = data:size(1)
   local n_dimensions = data:nElement() / nsamples
   -- flattened M x N working copy; cloning keeps data itself untouched
   local auxdata = data:clone():view(nsamples, n_dimensions)
   if not means or not P then
      -- fit the transform: per-feature means over the samples ...
      means = torch.mean(auxdata, 1):squeeze()
      -- ... and P = cv * diag(1 / sqrt(ce + epsilon)), where (ce, cv) come
      -- from unsup.pcacov (assumed: covariance eigenvalues/eigenvectors)
      local ce, cv = unsup.pcacov(auxdata)
      ce:add(epsilon):sqrt()
      P = torch.mm(cv, torch.diag(ce:clone():pow(-1)))
      -- inverse transform: diag(sqrt(ce + epsilon)) * cv^T
      invP = torch.mm(torch.diag(ce), cv:t())
   elseif not invP then
      -- means and P were supplied: honour them, and only derive the inverse
      invP = torch.inverse(P)
   end
   -- remove the means from every sample (row)
   local xmeans = means:new():view(1, n_dimensions):expand(nsamples, n_dimensions)
   auxdata:add(-1, xmeans)
   -- transform into PCA space
   auxdata = torch.mm(auxdata, P)
   -- restore the original (possibly > 2D) shape
   auxdata:resizeAs(data)
   return auxdata, means, P, invP
end
-- PCA-Colouring: the inverse of unsup.pca_whiten. Maps whitened data back
-- to the original space.
--
-- Input:
-- - data: whitened tensor M x N1 [x N2 x ...]; at least 2D.
-- - means: 1D tensor of size N = N1 x N2 x ... (required).
-- - P: not used here; returned unchanged.
-- - invP: inverse PCA transform of size N x N (required).
--
-- Each flattened sample y (a row vector) is mapped to y * invP + means, and
-- the result has the same shape as data. Returns the coloured data followed
-- by means, P and invP (unchanged). Input arguments are never changed.
function unsup.pca_colour(data, means, P, invP)
   assert(means, 'pca_colour: means must be provided')
   assert(invP, 'pca_colour: invP must be provided')
   assert(data:dim() >= 2, 'pca_colour: data must be at least 2D')
   local nsamples = data:size(1)
   local n_dimensions = data:nElement() / nsamples
   -- flattened M x N working copy; cloning keeps data itself untouched
   local auxdata = data:clone():view(nsamples, n_dimensions)
   -- transform back out of PCA space
   auxdata = torch.mm(auxdata, invP)
   -- add back the means to every sample (row)
   local xmeans = means:new():view(1, n_dimensions):expand(nsamples, n_dimensions)
   auxdata:add(xmeans)
   -- restore the original (possibly > 2D) shape
   auxdata:resizeAs(data)
   return auxdata, means, P, invP
end
-- Returns an nn module that applies a PCA transform to its input using a
-- precomputed (static) transformation matrix.
--
-- Input:
-- - data: tensor M x N1 [x N2 x ...]; fixes the per-sample shape, and is used
--   to fit the transform when parameters are missing.
-- - means, P, invP: optional precomputed parameters (see unsup.pca_whiten).
--   If any of them is missing, all three are recomputed from data.
--
-- The layer maps each flattened sample x to P^T (x - means), the
-- column-vector form of pca_whiten's (x - means) * P. When data has more
-- than 2 dimensions, the input is flattened per sample and reshaped back.
-- The batch size is inferred from the input rather than fixed to data's M.
--
-- Returns the layer, followed by means, P and invP.
function unsup.pca_layer(data, means, P, invP)
   if not means or not P or not invP then
      local _
      _, means, P, invP = unsup.pca_whiten(data)
   end
   local n_dimensions = data:nElement() / data:size(1)
   local linear = nn.Linear(n_dimensions, n_dimensions)
   -- nn.Linear computes W x + b on column vectors, hence the transpose
   linear.weight:copy(P:t())
   -- b = -P^T means, so that W x + b = P^T (x - means)
   linear.bias:fill(0)
   linear.bias:copy(linear:forward(means):mul(-1))
   if data:nDimension() <= 2 then
      return linear, means, P, invP
   end
   -- per-sample shape (N1, N2, ...) without the batch dimension, so the
   -- Reshape modules infer the batch size instead of hard-coding data:size(1)
   local sample_dims = data:nDimension() - 1
   local sample_size = torch.LongStorage(sample_dims)
   for i = 1, sample_dims do
      sample_size[i] = data:size(i + 1)
   end
   local layer = nn.Sequential()
   layer:add(nn.Reshape(n_dimensions))
   layer:add(linear)
   layer:add(nn.Reshape(sample_size))
   return layer, means, P, invP
end
-- Returns an nn module that inverts a PCA transform of its input using a
-- precomputed (static) transformation matrix.
--
-- Input:
-- - data: tensor M x N1 [x N2 x ...]; fixes the per-sample shape, and is used
--   to fit the transform when parameters are missing.
-- - means, P, invP: optional precomputed parameters (see unsup.pca_whiten).
--   If any of them is missing, all three are recomputed from data.
--
-- The layer maps each flattened sample y to invP^T y + means, the
-- column-vector form of pca_colour's y * invP + means. When data has more
-- than 2 dimensions, the input is flattened per sample and reshaped back.
-- The batch size is inferred from the input rather than fixed to data's M.
--
-- Returns the layer, followed by means, P and invP.
function unsup.inv_pca_layer(data, means, P, invP)
   if not means or not P or not invP then
      local _
      _, means, P, invP = unsup.pca_whiten(data)
   end
   local n_dimensions = data:nElement() / data:size(1)
   local linear = nn.Linear(n_dimensions, n_dimensions)
   -- nn.Linear computes W x + b on column vectors, hence the transpose
   linear.weight:copy(invP:t())
   linear.bias:copy(means)
   if data:nDimension() <= 2 then
      return linear, means, P, invP
   end
   -- per-sample shape (N1, N2, ...) without the batch dimension, so the
   -- Reshape modules infer the batch size instead of hard-coding data:size(1)
   local sample_dims = data:nDimension() - 1
   local sample_size = torch.LongStorage(sample_dims)
   for i = 1, sample_dims do
      sample_size[i] = data:size(i + 1)
   end
   local layer = nn.Sequential()
   layer:add(nn.Reshape(n_dimensions))
   layer:add(linear)
   layer:add(nn.Reshape(sample_size))
   return layer, means, P, invP
end