Skip to content

Commit

Permalink
simplify inv hess oracles for sepspectral cone (#755)
Browse files Browse the repository at this point in the history
* simplify inv hess oracles for sepspectral cone

* add inbounds

* update matrix version
  • Loading branch information
lkapelevich authored Oct 2, 2021
1 parent 9e14878 commit c7b4420
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 38 deletions.
33 changes: 15 additions & 18 deletions src/Cones/epipersepspectral/matrixcsqr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ mutable struct MatrixCSqrCache{T <: Real, R <: RealOrComplex{T}} <: CSqrCache{T}
w4::Matrix{R}
α::Vector{T}
γ::Vector{T}
c0::T
c1::T
c4::T
c5::T
ζ2β::T

MatrixCSqrCache{T, R}() where {T <: Real, R <: RealOrComplex{T}} = new{T, R}()
end
Expand Down Expand Up @@ -339,15 +339,12 @@ function update_inv_hess_aux(cone::EpiPerSepSpectral{<:MatrixCSqr{T}}) where T
@. γ = wd / diag_θ

ζ2β = abs2(cache.ζ) + dot(∇h, α)
c0 = σ + dot(∇h, γ)
c1 = c0 / ζ2β
@inbounds sum1 = sum((viw_λ[i] + c1 * α[i] - γ[i]) * wd[i] for i in 1:cone.d)
c3 = v^-2 + σ * c1 + sum1
c4 = inv(c3 - c0 * c1)
c5 = ζ2β * c3
cache.c0 = c0
c1 = σ + dot(∇h, γ)
@inbounds c4 = inv(v^-2 + sum((viw_λ[i] - γ[i]) * wd[i] for i in 1:cone.d))

cache.c1 = c1
cache.c4 = c4
cache.c5 = c5
cache.ζ2β = ζ2β

cone.inv_hess_aux_updated = true
end
Expand All @@ -359,14 +356,15 @@ function update_inv_hess(cone::EpiPerSepSpectral{<:MatrixCSqr{T}}) where T
cache = cone.cache
rt2 = cache.rt2
viw_X = cache.viw_X
c1 = cache.c1
c4 = cache.c4
ζ2β = cache.ζ2β
wT = cache.wT
w1 = cache.w1
w2 = cache.w2

# Hiuu, Hiuv, Hivv
Hi[1, 1] = c4 * cache.c5
Hiuv = Hi[1, 2] = c4 * cache.c0
Hi[1, 1] = ζ2β + c1 * c4 * c1
Hiuv = Hi[1, 2] = c4 * c1
Hi[2, 2] = c4

# Hiuw, Hivw
Expand Down Expand Up @@ -402,9 +400,9 @@ function inv_hess_prod!(
viw_X = cache.viw_X
α = cache.α
γ = cache.γ
c0 = cache.c0
c1 = cache.c1
c4 = cache.c4
c5 = cache.c5
ζ2β = cache.ζ2β
r_X = Hermitian(cache.w1, :U)
w2 = cache.w2

Expand All @@ -416,10 +414,9 @@ function inv_hess_prod!(
mul!(r_X.data, viw_X', w2)

qγr = q + sum(γ[i] * r_X[i, i] for i in 1:d)
cu = c4 * (c5 * p + c0 * qγr)
cv = c4 * (c0 * p + qγr)
cv = c4 * (c1 * p + qγr)

prod[1, j] = cu + sum(α[i] * r_X[i, i] for i in 1:d)
prod[1, j] = ζ2β * p + c1 * cv + sum(α[i] * r_X[i, i] for i in 1:d)
prod[2, j] = cv

w_prod = r_X
Expand Down
36 changes: 16 additions & 20 deletions src/Cones/epipersepspectral/vectorcsqr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ mutable struct VectorCSqrCache{T <: Real} <: CSqrCache{T}
m::Vector{T}
α::Vector{T}
γ::Vector{T}
c0::T
c1::T
c4::T
c5::T
ζ2β::T

VectorCSqrCache{T}() where {T <: Real} = new{T}()
end
Expand Down Expand Up @@ -224,15 +224,12 @@ function update_inv_hess_aux(cone::EpiPerSepSpectral{<:VectorCSqr})
@. γ = m * w1

ζ2β = abs2(cache.ζ) + dot(∇h, α)
c0 = σ + dot(∇h, γ)
c1 = c0 / ζ2β
@inbounds sum1 = sum((viw[i] + c1 * α[i] - γ[i]) * w1[i] for i in 1:cone.d)
c3 = v^-2 + σ * c1 + sum1
c4 = inv(c3 - c0 * c1)
c5 = ζ2β * c3
cache.c0 = c0
c1 = σ + dot(∇h, γ)
@inbounds c4 = inv(v^-2 + sum((viw[i] - γ[i]) * w1[i] for i in 1:cone.d))

cache.c1 = c1
cache.c4 = c4
cache.c5 = c5
cache.ζ2β = ζ2β

cone.inv_hess_aux_updated = true
end
Expand All @@ -245,13 +242,13 @@ function update_inv_hess(cone::EpiPerSepSpectral{<:VectorCSqr})
m = cache.m
α = cache.α
γ = cache.γ
c0 = cache.c0
c1 = cache.c1
c4 = cache.c4
c5 = cache.c5
ζ2β = cache.ζ2β

# Hiuu, Hiuv, Hivv
Hi[1, 1] = c4 * c5
Hi[1, 2] = c4 * c0
Hi[1, 1] = ζ2β + c1 * c4 * c1
Hi[1, 2] = c4 * c1
Hi[2, 2] = c4

@inbounds for j in 1:cone.d
Expand All @@ -261,7 +258,7 @@ function update_inv_hess(cone::EpiPerSepSpectral{<:VectorCSqr})
Hivj = Hi[2, j2] = c4 * γ[j]

# Hiuw
Hi[1, j2] = α[j] + c0 * Hivj
Hi[1, j2] = α[j] + c1 * Hivj

# Hiww
for i in 1:j
Expand All @@ -284,20 +281,19 @@ function inv_hess_prod!(
m = cache.m
α = cache.α
γ = cache.γ
c0 = cache.c0
c1 = cache.c1
c4 = cache.c4
c5 = cache.c5
ζ2β = cache.ζ2β

@inbounds for j in 1:size(arr, 2)
p = arr[1, j]
q = arr[2, j]
@views r = arr[3:end, j]

qγr = q + dot(γ, r)
cu = c4 * (c5 * p + c0 * qγr)
cv = c4 * (c0 * p + qγr)
cv = c4 * (c1 * p + qγr)

prod[1, j] = cu + dot(α, r)
prod[1, j] = ζ2β * p + c1 * cv + dot(α, r)
prod[2, j] = cv
@. @views prod[3:end, j] = p * α + cv * γ + m * r
end
Expand Down

0 comments on commit c7b4420

Please sign in to comment.