-
Notifications
You must be signed in to change notification settings - Fork 3
/
run_pred_step.m
44 lines (28 loc) · 1.2 KB
/
run_pred_step.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
function [testyhat, testys2] = run_pred_step(kern, data_var, kern_hyp, sis, ...
    PredModel, trainx, testx, nt)
%RUN_PRED_STEP Predictive outputs for test inputs via DTC_PRED.
%   [TESTYHAT, TESTYS2] = RUN_PRED_STEP(KERN, DATA_VAR, KERN_HYP, SIS,
%   PREDMODEL, TRAINX, TESTX, NT) builds the train/test kernel block
%   K_star and the test kernel term K_starstar for the kernel family named
%   by KERN.TYPE, then forwards them to DTC_PRED.
%
%   Inputs
%     kern      - struct whose TYPE field is 'histIntKern', 'SQExpKern' or
%                 'preComputedKern'. The first two families also use
%                 KERN.KFunc and KERN.diagKFunc; 'histIntKern' additionally
%                 uses KERN.cellDataPrepFunc to prepare inputs.
%     data_var  - multiplier applied to the kernel blocks of the
%                 'histIntKern' and 'SQExpKern' families only.
%     kern_hyp  - kernel hyperparameters. For 'preComputedKern' these are
%                 the weight(s) applied to the precomputed matrices.
%     sis       - row indices selecting training points (rows of TRAINX,
%                 or rows of each Knt matrix).
%     PredModel - passed through unchanged to DTC_PRED.
%     trainx    - training inputs, indexed by row (not read for
%                 'preComputedKern').
%     testx     - test inputs. For 'preComputedKern' this is a struct, or a
%                 cell array of structs, with fields Knt and Ktt.
%     nt        - passed to KERN.diagKFunc (presumably the number of test
%                 points -- TODO confirm).
%
%   Outputs
%     testyhat  - first output of DTC_PRED.
%     testys2   - second output of DTC_PRED after SAFEGUARDPOSVALUES.
%
%   An unrecognised KERN.TYPE raises an error.
%
%   NOTE(review): DATA_VAR is not applied in the 'preComputedKern' branch;
%   confirm the precomputed weights already account for it.

switch kern.type
    case 'histIntKern'
        % The kernel's own preparation step is applied to both the test
        % inputs and the selected training rows before evaluation.
        testCells = kern.cellDataPrepFunc(testx);
        trainCells = kern.cellDataPrepFunc(trainx(sis,:));
        K_star = data_var * kern.KFunc(trainCells, testCells, kern_hyp);
        K_starstar = data_var * kern.diagKFunc(testCells, nt, kern_hyp);

    case 'SQExpKern'
        K_star = data_var * kern.KFunc(trainx(sis,:), testx, kern_hyp);
        K_starstar = data_var * kern.diagKFunc(testx, nt, kern_hyp);

    case 'preComputedKern'
        if ~iscell(testx)
            % A single precomputed kernel, scaled by kern_hyp.
            K_star = kern_hyp * testx.Knt(sis, :);
            K_starstar = kern_hyp * testx.Ktt;
        else
            % Weighted sum of several precomputed kernels, using one
            % entry of kern_hyp per cell element.
            K_star = 0;
            K_starstar = 0;
            for k = 1:length(testx)
                w = kern_hyp(k);
                K_star = K_star + w*testx{k}.Knt(sis, :);
                K_starstar = K_starstar + w*testx{k}.Ktt;
            end
        end

    otherwise
        error('unsupported');
end

[testyhat, testys2] = dtc_pred(K_star, K_starstar, PredModel, false);
% Post-process the second output (presumably to keep the variances
% positive -- see safeGuardPosValues).
testys2 = safeGuardPosValues(testys2);
end