Skip to content

Commit

Permalink
finish cleanup of log/logdet
Browse files Browse the repository at this point in the history
  • Loading branch information
chriscoey committed Oct 6, 2021
1 parent d871ba1 commit 5a9120b
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 102 deletions.
75 changes: 38 additions & 37 deletions src/Cones/hypoperlog.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,15 @@ function update_hess(cone::HypoPerLog)
v = cone.point[2]
@views w = cone.point[3:end]
H = cone.hess.data
g = cone.grad
ζ = cone.ζ
wivζi = cone.tempw
d = length(w)
σζi = (cone.ϕ - d) / ζ
vζi = v / ζ
@. wivζi = vζi / w

# u, v
H[1, 1] = ζ^-2
H[1, 2] = -σζi / ζ
H[2, 2] = v^-2 + abs2(σζi) + d / ζ / v
H[2, 2] = v^-2 + abs2(σζi) + d / (ζ * v)

# u, v, w
vζi2 = -vζi / ζ
Expand All @@ -131,13 +128,14 @@ function update_hess(cone::HypoPerLog)
@. H[2, 3:end] = c1 / w

# w, w
@inbounds for j in eachindex(wivζi)
@inbounds for j in eachindex(w)
j2 = 2 + j
wivζij = wivζi[j]
for i in 1:j
H[2 + i, j2] = wivζi[i] * wivζij
wj = w[j]
c2 = vζi / wj
for i in 1:(j - 1)
H[2 + i, j2] = vζi / w[i] * c2
end
H[j2, j2] -= g[j2] / w[j]
H[j2, j2] = (c2 + c2 * vζi + inv(wj)) / wj
end

cone.hess_updated = true
Expand All @@ -151,24 +149,25 @@ function hess_prod!(
)
v = cone.point[2]
@views w = cone.point[3:end]
ζ = cone.ζ
d = length(w)
ζ = cone.ζ
σ = cone.ϕ - d
rwi = cone.tempw
vζi1 = v / ζ + 1
rwi = cone.tempw

@inbounds for j in 1:size(arr, 2)
p = arr[1, j]
q = arr[2, j]
@. @views rwi = arr[3:end, j] / w

qζi = q / ζ
c0 = sum(rwi) / ζ
# ∇ϕ[r] = v * c0
c1 = (v * c0 - p / ζ + σ * qζi) / ζ
c3 = c1 * v - qζi
prod[1, j] = -c1
prod[2, j] = c1 * σ - c0 + (qζi * d + q / v) / v
c1 = sum(rwi) / ζ
# ∇ϕ[r] = v * c1
c2 = (v * c1 - p / ζ + σ * qζi) / ζ
c3 = c2 * v - qζi

prod[1, j] = -c2
prod[2, j] = c2 * σ - c1 + (qζi * d + q / v) / v
@. prod[3:end, j] = (c3 + vζi1 * rwi) / w
end

Expand All @@ -188,23 +187,23 @@ function update_inv_hess(cone::HypoPerLog)
ζv = ζ + v
ζζvi = ζ / ζv
c1 = v / (ζv + d * v) * v
c3 = c1 / ζv
c2 = c1 / ζv

# u, v
Hi12 = Hi[1, 2] = c1 * (ζv * ϕ - d * ζ)
c2 = (v * ζ + Hi12) / ζv
Hi[1, 1] = ζζvi * d * v^2 + ζ^2 + (ϕ - ζζvi * d) * Hi12
Hi[2, 2] = c1 * ζv

# u, v, w
@. Hi[1, 3:end] = c2 * w
c3 = (v * ζ + Hi12) / ζv
@. Hi[1, 3:end] = c3 * w
@. Hi[2, 3:end] = c1 * w

# w, w
@inbounds for j in eachindex(w)
j2 = 2 + j
wj = w[j]
c4 = c3 * wj
c4 = c2 * wj
for i in 1:j
Hi[2 + i, j2] = c4 * w[i]
end
Expand All @@ -228,17 +227,18 @@ function inv_hess_prod!(
ϕ = cone.ϕ
ζv = ζ + v
ζζvi = ζ / ζv
rw = cone.tempw
c1 = v / (ζv + d * v) * v
rw = cone.tempw

@inbounds for j in 1:size(arr, 2)
p = arr[1, j]
q = arr[2, j]
@. @views rw = arr[3:end, j] * w
trrw = sum(rw)

trrw = sum(rw)
c2 = c1 * (ζv * (ϕ * p + q) - d * ζ * p + trrw)
c3 = (c2 + ζ * v * p) / ζv

prod[1, j] = ζ * ((v * (d * p * v + trrw) - d * c2) / ζv + ζ * p) + ϕ * c2
prod[2, j] = c2
@. prod[3:end, j] = (c3 + ζζvi * rw) * w
Expand All @@ -259,24 +259,25 @@ function dder3(cone::HypoPerLog{T}, dir::AbstractVector{T}) where {T <: Real}
σ = cone.ϕ - d
viq = q / v
viq2 = abs2(viq)
rwi = cone.tempw
vζi = v / ζ
vζi1 = vζi + 1
rwi = cone.tempw

@. @views rwi = dir[3:end] / w
c0 = sum(rwi)
c7 = sum(abs2, rwi)
ζiχ = (-p + σ * q + c0 * v) / ζ
c4 = (viq * (-viq * d + 2 * c0) - c7) / ζ / 2
c1 = (abs2(ζiχ) - v * c4) / ζ
c3 = -(ζiχ + viq) / ζ
c5 = c3 * q + vζi * viq2
c6 = -2 * vζi * viq - c3 * v
c8 = c5 + c1 * v

dder3[1] = -c1
dder3[2] = c1 * σ + (viq2 - (d * c5 + c6 * c0 + vζi * c7)) / v - c4
@. dder3[3:end] = (c8 + rwi * (c6 + vζi1 * rwi)) / w
tr1 = sum(rwi)
tr2 = sum(abs2, rwi)

χ = (-p + σ * q + tr1 * v) / ζ
c1 = (viq * (-viq * d + 2 * tr1) - tr2) / ζ / 2
c2 = (abs2(χ) - v * c1) / ζ
c3 = -(χ + viq) / ζ
c4 = c3 * q + vζi * viq2
c5 = -2 * vζi * viq - c3 * v
c6 = c4 + c2 * v

dder3[1] = -c2
dder3[2] = c2 * σ + (viq2 - (d * c4 + c5 * tr1 + vζi * tr2)) / v - c1
@. dder3[3:end] = (c6 + rwi * (c5 + vζi1 * rwi)) / w

return dder3
end
Expand Down
126 changes: 61 additions & 65 deletions src/Cones/hypoperlogdettri.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ function update_grad(cone::HypoPerLogdetTri)
inv_fact!(cone.Wi, cone.fact_W)
smat_to_svec!(cone.Wi_vec, cone.Wi, cone.rt2)
# ∇ϕ = cone.Wi_vec * v
ζvζi = -1 - v / ζ
@. g[3:end] = ζvζi * cone.Wi_vec
vζi1 = -1 - v / ζ
@. g[3:end] = vζi1 * cone.Wi_vec

cone.grad_updated = true
return cone.grad
Expand All @@ -151,33 +151,31 @@ function update_hess(cone::HypoPerLogdetTri)
H = cone.hess.data
d = cone.d
ζ = cone.ζ
ζi = inv(ζ)
σ = cone.ϕ - d
Wi_vec = cone.Wi_vec
ζiσ = σ / ζ
σζi = (cone.ϕ - d) / ζ
vζi = v / ζ

# u, v
H[1, 1] = abs2(ζi)
H[1, 2] = -ζi * ζiσ
H[2, 2] = v^-2 + abs2(ζiσ) + d / (v * ζ)
H[1, 1] = ζ^-2
H[1, 2] = -σζi / ζ
H[2, 2] = v^-2 + abs2(σζi) + d / (ζ * v)

# u, v, w
c1 = -vζi / ζ
@. H[1, 3:end] = c1 * Wi_vec
c2 = (σ * vζi - 1) / ζ
@. H[2, 3:end] = c2 * Wi_vec
vζi2 = -vζi / ζ
c1 = ((cone.ϕ - d) * vζi - 1) / ζ
@. H[1, 3:end] = vζi2 * Wi_vec
@. H[2, 3:end] = c1 * Wi_vec

# w, w
copytri!(cone.Wi, 'U', true)
@views symm_kron!(H[3:end, 3:end], cone.Wi, cone.rt2)

@inbounds for j in eachindex(Wi_vec)
j2 = 2 + j
c3 = vζi * Wi_vec[j]
c2 = vζi * Wi_vec[j]
for i in 1:j
i2 = 2 + i
H[i2, j2] += vζi * (H[i2, j2] + c3 * Wi_vec[i])
H[i2, j2] += vζi * (H[i2, j2] + c2 * Wi_vec[i])
end
end

Expand All @@ -192,12 +190,12 @@ function hess_prod!(
)
@assert cone.grad_updated
v = cone.point[2]
FU = cone.fact_W.U
d = cone.d
ζ = cone.ζ
w_aux = cone.mat3
FU = cone.fact_W.U
σ = cone.ϕ - d
vζi1 = v / ζ + 1
w_aux = cone.mat3

@inbounds for j in 1:size(arr, 2)
p = arr[1, j]
Expand All @@ -208,20 +206,20 @@ function hess_prod!(
ldiv!(FU', w_aux)

qζi = q / ζ
c0 = tr(Hermitian(w_aux, :U)) / ζ
# ∇ϕ[r] = v * c0
c1 = (v * c0 - p / ζ + σ * qζi) / ζ
c3 = c1 * v - qζi
prod[1, j] = -c1
prod[2, j] = c1 * σ - c0 + (qζi * d + q / v) / v
c1 = tr(Hermitian(w_aux, :U)) / ζ
# ∇ϕ[r] = v * c1
c2 = (v * c1 - p / ζ + σ * qζi) / ζ
c3 = c2 * v - qζi

prod[1, j] = -c2
prod[2, j] = c2 * σ - c1 + (qζi * d + q / v) / v

lmul!(vζi1, w_aux)
for i in diagind(w_aux)
w_aux[i] += c3
end
rdiv!(w_aux, FU')
ldiv!(FU, w_aux)

@views smat_to_svec!(prod[3:end, j], w_aux, cone.rt2)
end

Expand All @@ -242,22 +240,23 @@ function update_inv_hess(cone::HypoPerLogdetTri)
ϕ = cone.ϕ
ζv = ζ + v
ζζvi = ζ / ζv
c3 = v / (ζv + d * v)
c0 = ϕ - d * ζζvi
c2 = v * c3
c4 = c2 * ζv
c1 = v * ζζvi + c0 * c2
c1 = v / (ζv + d * v) * v
c2 = c1 / ζv

Hi[1, 1] = abs2(v * ϕ) + ζ * (ζ + d * v) - d * abs2(ζ + v * ϕ) * c3
Hi[1, 2] = c0 * c4
Hi[2, 2] = c4
# u, v
Hi12 = Hi[1, 2] = c1 * (ζv * ϕ - d * ζ)
Hi[1, 1] = ζζvi * d * v^2 + ζ^2 + (ϕ - ζζvi * d) * Hi12
Hi[2, 2] = c1 * ζv

@. Hi[1, 3:end] = c1 * w
@. Hi[2, 3:end] = c2 * w
# u, v, w
c3 = (v * ζ + Hi12) / ζv
@. Hi[1, 3:end] = c3 * w
@. Hi[2, 3:end] = c1 * w

# w, w
@views Hiww = Hi[3:end, 3:end]
symm_kron!(Hiww, W, cone.rt2)
mul!(Hiww, w, w', c2 / ζv, ζζvi)
mul!(Hiww, w, w', c2, ζζvi)

cone.inv_hess_updated = true
return cone.inv_hess
Expand All @@ -278,12 +277,7 @@ function inv_hess_prod!(
ϕ = cone.ϕ
ζv = ζ + v
ζζvi = ζ / ζv
c3 = v / (ζv + d * v)
c0 = ϕ - d * ζζvi
c4 = v * c3 * ζv
c6 = abs2(v * ϕ) + ζ * (ζ + d * v) - d * abs2(ζ + v * ϕ) * c3
c7 = c4 * c0
c8 = c7 + v * ζ
c1 = v / (ζv + d * v) * v
w_aux = cone.mat2
w_aux2 = cone.mat3

Expand All @@ -294,17 +288,18 @@ function inv_hess_prod!(
svec_to_smat!(w_aux, r, cone.rt2)
copytri!(w_aux, 'U', true)

c1 = dot(w, r) / ζv
c5 = c0 * p + q + c1
c2 = v * (ζζvi * p + c3 * c5)
prod[1, j] = c6 * p + c7 * q + c8 * c1
prod[2, j] = c4 * c5
trrw = dot(w, r)
c2 = c1 * (ζv * (ϕ * p + q) - d * ζ * p + trrw)
c3 = (c2 + ζ * v * p) / ζv

prod[1, j] = ζ * ((v * (d * p * v + trrw) - d * c2) / ζv + ζ * p) + ϕ * c2
prod[2, j] = c2

mul!(w_aux2, w_aux, W)
mul!(w_aux, W, w_aux2)
@views prod_w = prod[3:end, j]
smat_to_svec!(prod_w, w_aux, cone.rt2)
axpby!(c2, w, ζζvi, prod_w)
axpby!(c3, w, ζζvi, prod_w)
end

return prod
Expand All @@ -316,41 +311,42 @@ function dder3(cone::HypoPerLogdetTri, dir::AbstractVector)
dder3 = cone.dder3
p = dir[1]
q = dir[2]
@views r = dir[3:end]
d = cone.d
ζ = cone.ζ
FU = cone.fact_W.U
rwi = cone.mat2
w_aux = cone.mat3
w_aux2 = cone.mat4
σ = cone.ϕ - d
viq = q / v
viq2 = abs2(viq)
vζi = v / ζ
vζi1 = vζi + 1
rwi = cone.mat2
w_aux = cone.mat3
w_aux2 = cone.mat4

svec_to_smat!(rwi, r, cone.rt2)
@views svec_to_smat!(rwi, dir[3:end], cone.rt2)
copytri!(rwi, 'U', true)
rdiv!(rwi, FU)
ldiv!(FU', rwi)
c0 = tr(Hermitian(rwi, :U))
c7 = sum(abs2, rwi)
ζiχ = (-p + σ * q + c0 * v) / ζ
c4 = (viq * (-viq * d + 2 * c0) - c7) / ζ / 2
c1 = (abs2(ζiχ) - v * c4) / ζ
c3 = -(ζiχ + viq) / ζ
c5 = c3 * q + vζi * viq2
c6 = -2 * vζi * viq - c3 * v
c8 = c5 + c1 * v

dder3[1] = -c1
dder3[2] = c1 * σ + (viq2 - (d * c5 + c6 * c0 + vζi * c7)) / v - c4

tr1 = tr(Hermitian(rwi, :U))
tr2 = sum(abs2, rwi)

χ = (-p + σ * q + tr1 * v) / ζ
c1 = (viq * (-viq * d + 2 * tr1) - tr2) / ζ / 2
c2 = (abs2(χ) - v * c1) / ζ
c3 = -(χ + viq) / ζ
c4 = c3 * q + vζi * viq2
c5 = -2 * vζi * viq - c3 * v
c6 = c4 + c2 * v

dder3[1] = -c2
dder3[2] = c2 * σ + (viq2 - (d * c4 + c5 * tr1 + vζi * tr2)) / v - c1

copyto!(w_aux2, I)
axpby!(vζi1, rwi, c6, w_aux2)
axpby!(vζi1, rwi, c5, w_aux2)
mul!(w_aux, Hermitian(rwi, :U), w_aux2)
@inbounds for i in diagind(w_aux)
w_aux[i] += c8
w_aux[i] += c6
end
rdiv!(w_aux, FU')
ldiv!(FU, w_aux)
Expand Down

0 comments on commit 5a9120b

Please sign in to comment.