Skip to content

Commit

Permalink
cleanup geomean cone
Browse files Browse the repository at this point in the history
  • Loading branch information
chriscoey committed Oct 4, 2021
1 parent 8ab90eb commit 6da5ec0
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 58 deletions.
103 changes: 53 additions & 50 deletions src/Cones/hypogeomean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ mutable struct HypoGeoMean{T <: Real} <: Cone{T}
di::T
ϕ::T
ζ::T
ϕζidi::T
η::T
tempw::Vector{T}

function HypoGeoMean{T}(
Expand Down Expand Up @@ -65,12 +65,11 @@ end

function update_feas(cone::HypoGeoMean{T}) where {T <: Real}
@assert !cone.feas_updated
u = cone.point[1]
@views w = cone.point[2:end]

if all(>(eps(T)), w)
cone.ϕ = exp(cone.di * sum(log, w))
cone.ζ = cone.ϕ - u
cone.ζ = cone.ϕ - cone.point[1]
cone.is_feas = (cone.ζ > eps(T))
else
cone.is_feas = false
Expand All @@ -85,7 +84,7 @@ function is_dual_feas(cone::HypoGeoMean{T}) where {T <: Real}
@views w = cone.dual_point[2:end]

if (u < -eps(T)) && all(>(eps(T)), w)
return (length(w) * exp(cone.di * sum(log, w)) + u > eps(T))
return (sum(log, w) - length(w) * log(-u * cone.di) > eps(T))
end

return false
Expand All @@ -96,11 +95,11 @@ function update_grad(cone::HypoGeoMean)
@views w = cone.point[2:end]
g = cone.grad
ζ = cone.ζ
cone.ϕζidi = cone.ϕ / ζ * cone.di
ϕζidi1 = -cone.ϕζidi - 1
cone.η = cone.ϕ / ζ * cone.di
= -1 - cone.η

g[1] = inv(ζ)
@. g[2:end] = ϕζidi1 / w
@. g[2:end] = / w

cone.grad_updated = true
return cone.grad
Expand All @@ -112,23 +111,23 @@ function update_hess(cone::HypoGeoMean)
@views w = cone.point[2:end]
H = cone.hess.data
ζ = cone.ζ
ϕζidi = cone.ϕζidi
c4 = ϕζidi - cone.di
c1 = ϕζidi * (1 + c4) + 1
η = cone.η
c1 = η - cone.di
c2 = η * (1 + c1) + 1

H[1, 1] = ζ^-2

@inbounds for j in eachindex(w)
j1 = j + 1
j1 = 1 + j
w_j = w[j]
c3 = ϕζidi / w_j
c3 = η / w_j
H[1, j1] = -c3 / ζ

c2 = c3 * c4
c4 = c3 * c1
for i in 1:(j - 1)
H[i + 1, j1] = c2 / w[i]
H[i + 1, j1] = c4 / w[i]
end
H[j1, j1] = c1 / w_j / w_j
H[j1, j1] = c2 / w_j / w_j
end

cone.hess_updated = true
Expand All @@ -144,20 +143,21 @@ function hess_prod!(
@views w = cone.point[2:end]
di = cone.di
ζ = cone.ζ
ϕζidi = cone.ϕζidi
η = cone.η
θ = 1 + η
rwi = cone.tempw
ϕζidi1 = ϕζidi + 1

@inbounds for j in 1:size(arr, 2)
p = arr[1, j]
@views r = arr[2:end, j]
@. rwi = r / w

c0 = ϕζidi * sum(rwi)
c1 = c0 - p / ζ
prod[1, j] = c1 / -ζ
c2 = ϕζidi * c1 - di * c0
@. prod[2:end, j] = (c2 + ϕζidi1 * rwi) / w
c1 = η * sum(rwi)
χ = c1 - p / ζ
ητ = η * χ - di * c1

prod[1, j] = χ / -ζ
@. prod[2:end, j] = (ητ + θ * rwi) / w
end

return prod
Expand All @@ -173,23 +173,24 @@ function update_inv_hess(cone::HypoGeoMean{T}) where {T <: Real}
ζ = cone.ζ
ϕ = cone.ϕ
di = cone.di
ϕdi = ϕ * di
c2 = inv(cone.ϕζidi + 1)
c3 = c2 / ζ * di
η = cone.η
diϕ = ϕ * di
θi = inv(1 + η)
c1 = θi / ζ * di

Hi[1, 1] = abs2(ζ) + ϕdi * ϕ
Hi[1, 1] = abs2(ζ) + diϕ * ϕ

@inbounds for j in eachindex(w)
j1 = j + 1
w_j = w[j]
c5 = ϕdi * w_j
Hi[1, j1] = c5
c2 = diϕ * w_j
Hi[1, j1] = c2

c4 = c3 * c5
c3 = c1 * c2
for i in 1:(j - 1)
Hi[i + 1, j1] = c4 * w[i]
Hi[i + 1, j1] = c3 * w[i]
end
Hi[j1, j1] = (c4 + c2 * w_j) * w_j
Hi[j1, j1] = (c3 + θi * w_j) * w_j
end

cone.inv_hess_updated = true
Expand All @@ -206,21 +207,23 @@ function inv_hess_prod!(
ζ = cone.ζ
ϕ = cone.ϕ
di = cone.di
η = cone.η
diϕ = ϕ * di
θi = inv(1 + η)
c1 = di * η * θi
c2 = abs2(ζ) + diϕ * ϕ
rw = cone.tempw
ϕdi = ϕ * di
c2 = inv(cone.ϕζidi + 1)
c3 = c2 / ζ * di
c4 = abs2(ζ) + ϕdi * ϕ

@inbounds for j in 1:size(prod, 2)
p = arr[1, j]
@views r = arr[2:end, j]
@. rw = r * w

c5 = sum(rw)
prod[1, j] = ϕdi * c5 + c4 * p
c6 = ϕdi * (c3 * c5 + p)
@. prod[2:end, j] = (c6 + c2 * rw) * w
c3 = dot(w, r)
c4 = diϕ * p + c1 * c3

prod[1, j] = c2 * p + diϕ * c3
@. prod[2:end, j] = (c4 + θi * rw) * w
end

return prod
Expand All @@ -235,27 +238,27 @@ function dder3(cone::HypoGeoMean{T}, dir::AbstractVector{T}) where {T <: Real}
ζ = cone.ζ
di = cone.di
ϕ = cone.ϕ
ϕζidi = cone.ϕζidi
η = cone.η
θ = 1 + η
rwi = cone.tempw

@. rwi = r / w
c0 = sum(rwi) * di
c6 = sum(abs2, rwi) * di
ζiχ = (p - ϕ * c0) / ζ
c1 = ζiχ^2 + ϕ / ζ * (c6 - abs2(c0)) / 2
c7 = ϕζidi * (c1 - c6 / 2 + c0 * (ζiχ + c0 / 2))
c8 = -ϕζidi * (ζiχ + c0)
c9 = ϕζidi + 1

dder3[1] = c1 / -ζ
@. dder3[2:end] = (c7 + rwi * (c8 + c9 * rwi)) / w
tr1 = sum(rwi)
χ = -p / ζ + η * tr1
ητ = η *- di * tr1)
ηυh = η * (sum(abs2, rwi) - di * abs2(tr1)) / 2
c1 = χ * ητ +- di) * ηυh

dder3[1] = (abs2(χ) + ηυh) / -ζ
@. dder3[2:end] = (c1 + rwi * (ητ + θ * rwi)) / w

return dder3
end

function get_central_ray_hypogeomean(::Type{T}, d::Int) where {T <: Real}
c = sqrt(T(d * (5 * d + 2) + 1))
u = -sqrt((-c + 3 * d + 1) / T(2 + 2 * d))
w = -u * (d + 1 + c) / (2 * d)
w = -u * (d + 1 + c) / T(2 * d)
return (u, w)
end
15 changes: 7 additions & 8 deletions src/Cones/hyporootdettri.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ end

function update_feas(cone::HypoRootdetTri{T}) where {T <: Real}
@assert !cone.feas_updated

@views svec_to_smat!(cone.mat, cone.point[2:end], cone.rt2)

fact = cone.fact_W = cholesky!(Hermitian(cone.mat, :U), check = false)
if isposdef(fact)
cone.ϕ = exp(logdet(fact) / cone.d)
cone.ϕ = exp(cone.di * logdet(fact))
cone.ζ = cone.ϕ - cone.point[1]
cone.is_feas = (cone.ζ > eps(T))
else
Expand All @@ -117,7 +117,7 @@ function is_dual_feas(cone::HypoRootdetTri{T}) where {T <: Real}
@views svec_to_smat!(cone.mat2, cone.dual_point[2:end], cone.rt2)
fact = cholesky!(Hermitian(cone.mat2, :U), check = false)
if isposdef(fact)
return (logdet(fact) - cone.d * log(-u / cone.d) > eps(T))
return (logdet(fact) - cone.d * log(-u * cone.di) > eps(T))
end
end

Expand Down Expand Up @@ -193,14 +193,13 @@ function hess_prod!(
χ = c1 - p / ζ
ητ = η * χ - di * c1

prod[1, j] = χ / -ζ
lmul!(θ, w_aux)
for i in diagind(w_aux)
w_aux[i] += ητ
end
rdiv!(w_aux, FU')
ldiv!(FU, w_aux)

prod[1, j] = χ / -ζ
@views smat_to_svec!(prod[2:end, j], w_aux, cone.rt2)
end

Expand Down Expand Up @@ -251,10 +250,10 @@ function inv_hess_prod!(
W = Hermitian(cone.mat4, :U)
ζ = cone.ζ
ϕ = cone.ϕ
η = cone.η
di = cone.di
η = cone.η
diϕ = ϕ * di
θi = inv(1 + cone.η)
θi = inv(1 + η)
c1 = di * η * θi
c2 = abs2(ζ) + diϕ * ϕ
w_aux = cone.mat2
Expand All @@ -268,8 +267,8 @@ function inv_hess_prod!(

c3 = dot(w, r)
c4 = diϕ * p + c1 * c3
prod[1, j] = c2 * p + diϕ * c3

prod[1, j] = c2 * p + diϕ * c3
mul!(w_aux2, w_aux, W)
mul!(w_aux, W, w_aux2)
@views prod_w = prod[2:end, j]
Expand Down

0 comments on commit 6da5ec0

Please sign in to comment.