-
Notifications
You must be signed in to change notification settings - Fork 0
/
update_functions.jl
72 lines (66 loc) · 2.24 KB
/
update_functions.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
    update_u(X, tau, v, v2, precision_type)

Compute the updated first and second moments `(u, u2)` of the row factor.

# Arguments
- `X`: data matrix.
- `tau`: precision term; a scalar, a column, a row or a full matrix depending on
  `precision_type` (it is broadcast against `X`).
- `v`, `v2`: current first and second moments of the column factor
  (assumed to be single-column matrices -- TODO confirm against callers).
- `precision_type`: `"rowwise_constant"` and `"constant"` collapse `v2` with a
  column sum before weighting by `tau`; every other value uses the matrix
  product `tau * v2`.

# Returns
A tuple `(u, u2)` where `u = x ./ (s2 ./ sum(x.^2) .+ 1)` and
`u2 = u.^2 .+ s2 ./ (s2 ./ sum(x.^2) .+ 1)`, with `s2` the reciprocal of the
`tau`-weighted `v2` and `x = ((X .* tau) * v) .* s2`.
"""
function update_u(X, tau, v, v2, precision_type)
    # Written with explicit broadcasting (`1 ./`, `.+`) and the `dims` keyword so
    # the function parses and runs on Julia >= 1.0.
    if precision_type == "rowwise_constant" || precision_type == "constant"
        s2 = 1 ./ (tau * sum(v2, dims=1))
    else
        s2 = 1 ./ (tau * v2)
    end
    x = ((X .* tau) * v) .* s2
    # Shared denominator of both moments; computed once instead of twice.
    shrink = s2 ./ sum(abs2, x) .+ 1
    u = x ./ shrink
    u2 = u .^ 2 .+ s2 ./ shrink
    return u, u2
end
"""
    update_v(X, tau, u, u2, precision_type, nv, nullprior; alpha = 0)

Build the per-column observations `x` and variances `s2` for the column factor
and pass them to `ash` (`alpha == 0`) or `ash2` (`alpha == 1`).

# Arguments
- `X`: data matrix.
- `tau`: precision term matching `precision_type` (scalar, column, row or full
  matrix; it is broadcast against `X`).
- `u`, `u2`: current first and second moments of the row factor
  (assumed to be single-column matrices -- TODO confirm against callers).
- `precision_type`: one of `"columnwise_constant"`, `"rowwise_constant"`,
  `"constant"` or `"elementwise"`.
- `nv`, `nullprior`: forwarded unchanged as keyword arguments to `ash`/`ash2`.

# Keywords
- `alpha`: `0` selects `ash`, `1` selects `ash2`.

# Returns
Whatever `ash`/`ash2` returns.

# Throws
- `ArgumentError` if `precision_type` is not one of the four supported values.
- `ErrorException` if `alpha` is neither `0` nor `1`.
"""
function update_v(X, tau, u, u2, precision_type, nv, nullprior; alpha = 0)
    # Written with `1 ./`, `repeat` and the `dims` keyword so the function parses
    # and runs on Julia >= 1.0 (`1./`, `repmat` and positional dims were removed).
    if precision_type == "columnwise_constant"
        s2 = 1 ./ (tau' * sum(u2, dims=1))
    elseif precision_type == "rowwise_constant"
        # A single variance shared by all columns, tiled to one row per column of X.
        s2 = repeat(1 ./ (tau' * u2), size(X, 2), 1)
    elseif precision_type == "constant"
        s2 = repeat(1 ./ (tau' * sum(u2, dims=1)), size(X, 2), 1)
    elseif precision_type == "elementwise"
        s2 = 1 ./ (tau' * u2)
    else
        # Previously this fell through and failed later with an undefined `s2`.
        throw(ArgumentError(
            "unknown precision_type $(repr(precision_type)); expected " *
            "\"columnwise_constant\", \"rowwise_constant\", \"constant\" or \"elementwise\"",
        ))
    end
    x = ((X .* tau)' * u) .* s2
    if alpha == 0
        return ash(x, s2; nv = nv, nullprior = nullprior)
    elseif alpha == 1
        return ash2(x, s2; nv = nv, nullprior = nullprior)
    else
        error("Error: \"alpha\" should be 0 or 1")
    end
end
"""
    update_v_group(X, tau, u, u2, precision_type, nv, nullprior; alpha = 0)

Build the observations `x` and variances `s2` for the column factor, flatten
both to vectors, and pass them to `ash` (`alpha == 0`) or `ash2` (`alpha == 1`).

# Arguments
- `X`: data matrix.
- `tau`: precision term matching `precision_type` (scalar, column, row or full
  matrix; it is broadcast against `X`).
- `u`, `u2`: current first and second moments of the row factor.
- `precision_type`: one of `"columnwise_constant"`, `"rowwise_constant"`,
  `"constant"` or `"elementwise"`.
- `nv`, `nullprior`: forwarded unchanged as keyword arguments to `ash`/`ash2`.

# Keywords
- `alpha`: `0` selects `ash`, `1` selects `ash2`.

# Returns
Whatever `ash`/`ash2` returns for the flattened inputs.

# Throws
- `ArgumentError` if `precision_type` is not one of the four supported values.
- `ErrorException` if `alpha` is neither `0` nor `1`.
"""
function update_v_group(X, tau, u, u2, precision_type, nv, nullprior; alpha = 0)
    # Written with `1 ./`, `repeat` and the `dims` keyword so the function parses
    # and runs on Julia >= 1.0 (`1./`, `repmat` and positional dims were removed).
    if precision_type == "columnwise_constant"
        s2 = 1 ./ (tau' * sum(u2, dims=1))
    elseif precision_type == "rowwise_constant"
        # A single variance per factor column, tiled to one row per column of X.
        s2 = repeat(1 ./ (tau' * u2), size(X, 2), 1)
    elseif precision_type == "constant"
        s2 = repeat(1 ./ (tau' * sum(u2, dims=1)), size(X, 2), 1)
    elseif precision_type == "elementwise"
        s2 = 1 ./ (tau' * u2)
    else
        # Previously this fell through and failed later with an undefined `s2`.
        throw(ArgumentError(
            "unknown precision_type $(repr(precision_type)); expected " *
            "\"columnwise_constant\", \"rowwise_constant\", \"constant\" or \"elementwise\"",
        ))
    end
    x = ((X .* tau)' * u) .* s2
    # All entries are pooled into one vector before being handed to ash/ash2.
    x_flat = x[:]
    s2_flat = s2[:]
    if alpha == 0
        return ash(x_flat, s2_flat; nv = nv, nullprior = nullprior)
    elseif alpha == 1
        return ash2(x_flat, s2_flat; nv = nv, nullprior = nullprior)
    else
        error("Error: \"alpha\" should be 0 or 1")
    end
end
"""
    update_tau(R2, precision_type)

Return the precision `tau` as the reciprocal of the mean of `R2`, averaged over
the dimensions implied by `precision_type`.

# Arguments
- `R2`: matrix of expected squared residuals (as produced by `update_R2`,
  presumably -- confirm against callers).
- `precision_type`:
  - `"rowwise_constant"`: one value per row (result has one column);
  - `"columnwise_constant"`: one value per column (result has one row);
  - `"constant"`: a single scalar;
  - `"elementwise"`: one value per entry, `1 ./ R2`.

# Throws
- `ArgumentError` if `precision_type` is not one of the four supported values
  (previously `nothing` was returned silently).
"""
function update_tau(R2, precision_type)
    # Means are spelled as sum / count: `mean` lives in the Statistics stdlib on
    # Julia >= 1.0 and this file does not import it.
    if precision_type == "rowwise_constant"
        return 1 ./ (sum(R2, dims=2) ./ size(R2, 2))
    elseif precision_type == "columnwise_constant"
        return 1 ./ (sum(R2, dims=1) ./ size(R2, 1))
    elseif precision_type == "constant"
        return 1 / (sum(R2) / length(R2))
    elseif precision_type == "elementwise"
        return 1 ./ R2
    else
        throw(ArgumentError(
            "unknown precision_type $(repr(precision_type)); expected " *
            "\"rowwise_constant\", \"columnwise_constant\", \"constant\" or \"elementwise\"",
        ))
    end
end
"""
    update_R2(X, X2, u, u2, v, v2)

Return the matrix `(X - u*v').^2 + u2*v2' - (u.^2)*(v.^2)'`.

`u`/`v` are the first moments of the two factors and `u2`/`v2` their second
moments. `X2` is accepted for interface compatibility but is not used; for
single-column factors the result equals the expanded form
`u2*v2' + X2 - 2*(u*v') .* X` when `X2 == X.^2` (presumably what callers
pass -- confirm).
"""
function update_R2(X, X2, u, u2, v, v2)
    fitted = u * v'
    squared_error = (X - fitted) .^ 2
    second_moment = u2 * v2'
    squared_mean = (u .^ 2) * (v .^ 2)'
    return squared_error + second_moment - squared_mean
end