-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_pca_ica.m
142 lines (109 loc) · 4.19 KB
/
run_pca_ica.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
function [data1, aux, isi] = run_pca_ica(X, S, M, num_pc, num_iter, seed, varargin)
% RUN_PCA_ICA  Unimodal pipeline: PCA + ICA + CO (combinatorial optimization).
%
%   [data1, aux, isi] = run_pca_ica(X, S, M, num_pc, num_iter, seed)
%   [data1, aux, isi] = run_pca_ica(X, S, M, num_pc, num_iter, seed, Y, A)
%
% Inputs:
%   X        - cell array of data matrices, one cell per modality.
%   S        - cell array of subspace matrices forwarded to MISAK; the
%              number of rows of S{1} sets the number of subspaces K.
%   M        - modality indices to process (iterated with "for mm = M",
%              so a row vector is expected).
%   num_pc   - number of principal components kept per modality.
%   num_iter - number of combinatorial-optimization rounds.
%   seed     - seed handed to rng.
%   varargin - optional {Y, A}. Only A is used (passed to data1.MISI);
%              Y is accepted for call compatibility and ignored.
%
% Outputs:
%   data1 - MISAK object built on X, last evaluated at the stored
%           iteration with the smallest data2 objective value.
%   aux   - 3-by-(num_iter+1) cell: row 1 holds data2.W, row 2 the data2
%           objective value, row 3 the data3 objective (scale control off)
%           evaluated at the same W.
%   isi   - 1-by-(num_iter+1) vector of data1.MISI(A) values; stays all
%           zeros when A is not supplied.

rng(seed);

% Optional ground truth. Requiring both entries avoids an index error when
% only Y is passed.
hasA = numel(varargin) >= 2;
if hasA
    A = varargin{2};
end

ut = utils;

gradtype = 'relative'; % Use relative gradient
sc = 1;                % Enable scale control
preX = false;          % Turn off preprocessing (still removes the mean of the data)

whtM_all = cell(1,2);
Wr = cell(1,2); % reduced W

for mm = M
    [whtM, H] = ut.doMMGPCA(X(mm), num_pc, 'WT');
    whtM_all(mm) = whtM;

    w0 = ut.stackW({diag(pi/sqrt(3)./std(H,[],2))*eye(size(H,1))});

    gica1 = MISAK(w0, 1, {eye(size(H,1))}, {H}, ...
        0.5*ones(num_pc,1), ones(num_pc,1), ones(num_pc,1), ...
        gradtype, sc, preX);

    % the whitening matrix is an identity matrix, different from the whitening matrix from PCA
    % sphering turns off PCA
    % Second output (wht) is ignored: Infomax run with 'sphering' 'off' --> wht = eye(comps)
    [W1,~] = icatb_runica(H,'weights',gica1.W{1},'ncomps',size(H,1),'sphering', 'off', 'verbose', 'off', 'posact', 'off', 'bias', 'on');
    std_W1 = std(W1*H,[],2);
    W1 = diag(pi/sqrt(3) ./ std_W1) * W1;

    % RUN GICA using MISA: continuing from Infomax above...
    % Could use stochastic optimization, but not doing so because MISA does not implement bias weights (yet)...
    % gica1.stochastic_opt('verbose', 'off', 'weights', gica1.W{1}, 'bias', 'off');%, 'block', 1100);
    [~,~,~,~] = ut.run_MISA(gica1,{W1});

    std_gica1_W1 = std(gica1.Y{1},[],2);
    gica1.objective(ut.stackW({diag(pi/sqrt(3) ./ std_gica1_W1)*gica1.W{1}})); % update gica1.W{1}

    Wr{mm} = gica1.W{1};
end

% Combine MISA GICA with whitening matrices to initialize multimodal model
W = cellfun(@(w, wr) wr*w, whtM_all, Wr, 'UniformOutput', false);
W = cellfun(@(w,x) diag(pi/sqrt(3) ./ std(w*x,[],2))*w, W, X, 'UniformOutput', false);
w0_new = ut.stackW(W(M));

% Set Kotz parameters to multivariate laplace
% (the lambda argument of MISAK is passed as [] below)
K = size(S{1},1);
eta = ones(K,1);
beta = ones(K,1);

data1 = MISAK(w0_new, M, S, X, ...
    0.5*beta, eta, [], ...
    gradtype, sc, preX);

% Identity starting point for the reduced-space models
W0 = cell(1, max(M));
for mm = M
    W0{mm} = eye(num_pc);
end
w0_short = ut.stackW(W0);

% 1: data1.Y = data1.W * X
% 2: data2.Y = data2.W * data1.Y
% By 1 and 2: data2.Y = data2.W * data1.W * X
data2 = MISAK(w0_short, data1.M, data1.S, data1.Y, ...
    0.5*beta, eta, [], ...
    gradtype, sc, preX);

data3 = MISAK(w0_short, data1.M, data1.S, data1.Y, ...
    0.5*beta, eta, [], ...
    gradtype, false, preX); % turn off scale control

% Prep starting point: optimize RE to ensure initial W is in the feasible region
woutW0 = data2.stackW(data2.W);

% Define objective parameters and run optimization
f = @(x) data2.objective(x);
c = [];
barr = 1; % Barrier parameter
m = 1;    % Number of past gradients to use for LBFGS-B (m = 1 is equivalent to conjugate gradient)
% Brace indexing is required here: X(M(1)) is a 1x1 cell whose second
% dimension is always 1, which would pin N (and hence Tol) regardless of
% the data size.
N = size(X{M(1)},2); % Number of observations
Tol = .5*N*1e-9;     % Tolerance for stopping criteria

isi = zeros(1, num_iter+1);

% Set optimization parameters and run
% Skip fmincon to check effect of ICA (full separation of sources, no residual correlations left)
% NOTE(review): the fmincon outputs are discarded and data2.W is read
% afterwards, so this relies on data2.objective leaving the optimum in
% data2.W -- confirm against the MISAK implementation.
optprob = ut.getop(woutW0, f, c, barr, {'lbfgs' m}, Tol);
[~,~,~,~] = fmincon(optprob);

% Prep and run combinatorial optimization
aux = {data2.W; data2.objective(ut.stackW(data2.W)); data3.objective(ut.stackW(data2.W))};

data1.objective(ut.stackW(local_compose_W(data2.W, W, M)));
if hasA
    isi(1) = data1.MISI(A);
end

for ct = 2:num_iter+1
    data2.combinatorial_optim();
    optprob = ut.getop(ut.stackW(data2.W), f, c, barr, {'lbfgs' m}, Tol);
    [~,~,~,~] = fmincon(optprob);
    aux(:,ct) = {data2.W; data2.objective_(); data3.objective(ut.stackW(data2.W))};

    data1.objective(ut.stackW(local_compose_W(data2.W, W, M)));
    if hasA
        isi(ct) = data1.MISI(A);
    end
end

% Leave data1 evaluated at the iteration with the lowest data2 objective
[~, ix] = min([aux{2,:}]);
data1.objective(ut.stackW(local_compose_W(aux{1,ix}, W, M)));

end

function final_W = local_compose_W(W_short, W, M)
% Compose, per modality in M, the reduced-space matrix with the data-space
% matrix: final_W{mm} = W_short{mm} * W{mm}
% (e.g. data2.W is 12x12, W is 12x20k).
final_W = cell(1,2);
for mm = M
    final_W{mm} = W_short{mm} * W{mm};
end
end