Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: improve matrix rel entr cone oracles #718

Merged
merged 7 commits into from
May 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Cones/Cones.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using SparseArrays
import Hypatia.RealOrComplex
import Hypatia.chol_inv!
import Hypatia.update_eigen!
import Hypatia.spectral_outer!
import Hypatia.DenseSymCache
import Hypatia.DensePosDefCache
import Hypatia.load_matrix
Expand Down
41 changes: 41 additions & 0 deletions src/Cones/arrayutilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,44 @@ function spectral_kron_element!(
end
return skr
end

# compute a real symmetric Kronecker-like outer product of a real or complex
# matrix of eigenvectors and a real symmetric matrix
function eig_dot_kron!(
skr::AbstractMatrix{T},
inner::Matrix{T},
vecs::Matrix{R},
temp1::Matrix{R},
temp2::Matrix{R},
temp3::Matrix{R},
V::Matrix{R},
rt2::T,
) where {T <: Real, R <: RealOrComplex{T}}
@assert issymmetric(inner) # must be symmetric (wrapper is less efficient)
rt2i = inv(rt2)
d = size(inner, 1)
copyto!(V, vecs') # allows fast column slices
V_views = [view(V, :, i) for i in 1:size(inner, 1)]
scals = (R <: Complex{T} ? (rt2i, rt2i * im) : (rt2i,)) # real and imag parts

col_idx = 1
@inbounds for (j, V_j) in enumerate(V_views)
for i in 1:(j - 1), scal in scals
mul!(temp3, V_j, V_views[i]', scal, false)
@. temp2 = inner * (temp3 + temp3')
mul!(temp1, Hermitian(temp2, :U), V)
mul!(temp2, V', temp1)
@views smat_to_svec!(skr[:, col_idx], temp2, rt2)
col_idx += 1
end

mul!(temp2, V_j, V_j')
temp2 .*= inner
mul!(temp1, Hermitian(temp2, :U), V)
mul!(temp2, V', temp1)
@views smat_to_svec!(skr[:, col_idx], temp2, rt2)
col_idx += 1
end

return skr
end
4 changes: 1 addition & 3 deletions src/Cones/epipersepspectral/epipersepspectral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ mutable struct EpiPerSepSpectral{Q <: ConeOfSquares, T <: Real} <: Cone{T}
grad_updated::Bool
hess_updated::Bool
inv_hess_updated::Bool
dder3_updated::Bool
hess_aux_updated::Bool
inv_hess_aux_updated::Bool
dder3_aux_updated::Bool
Expand Down Expand Up @@ -69,8 +68,7 @@ end

reset_data(cone::EpiPerSepSpectral) = (cone.feas_updated = cone.grad_updated =
cone.hess_updated = cone.inv_hess_updated = cone.hess_aux_updated =
cone.inv_hess_aux_updated = cone.dder3_updated =
cone.dder3_aux_updated = false)
cone.inv_hess_aux_updated = cone.dder3_aux_updated = false)

use_sqrt_hess_oracles(cone::EpiPerSepSpectral) = false

Expand Down
61 changes: 6 additions & 55 deletions src/Cones/epipersepspectral/matrixcsqr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ function update_hess_aux(cone::EpiPerSepSpectral{<:MatrixCSqr{T}}) where T

ζivi = cache.ζi / cone.point[2]
@. cache.θ = ζivi * Δh + w_λi * w_λi'
copytri!(cache.θ, 'U')

cone.hess_aux_updated = true
end
Expand Down Expand Up @@ -262,7 +263,7 @@ function update_hess(cone::EpiPerSepSpectral{<:MatrixCSqr{T}}) where T

# Hww
@views Hww = H[3:end, 3:end]
eig_kron!(Hww, cache.θ, cone)
eig_dot_kron!(Hww, cache.θ, viw_X, w1, w2, cache.w3, cache.w4, rt2)
mul!(Hww, Hwu, Hwu', true, true)

cone.hess_updated = true
Expand Down Expand Up @@ -356,6 +357,7 @@ function update_inv_hess(cone::EpiPerSepSpectral{<:MatrixCSqr{T}}) where T
isdefined(cone, :inv_hess) || alloc_inv_hess!(cone)
Hi = cone.inv_hess.data
cache = cone.cache
rt2 = cache.rt2
viw_X = cache.viw_X
c4 = cache.c4
wT = cache.wT
Expand All @@ -372,17 +374,17 @@ function update_inv_hess(cone::EpiPerSepSpectral{<:MatrixCSqr{T}}) where T
@views γ_vec = Hi[3:end, 2]
mul!(w2, Diagonal(cache.γ), viw_X')
mul!(w1, viw_X, w2)
smat_to_svec!(γ_vec, w1, cache.rt2)
smat_to_svec!(γ_vec, w1, rt2)
@. Hi[2, 3:end] = c4 * γ_vec
mul!(w2, Diagonal(cache.α), viw_X')
mul!(w1, viw_X, w2)
smat_to_svec!(HiuW, w1, cache.rt2)
smat_to_svec!(HiuW, w1, rt2)
@. HiuW += Hiuv * γ_vec

# Hiww
@views Hiww = Hi[3:end, 3:end]
@. wT = inv(cache.θ)
eig_kron!(Hiww, wT, cone)
eig_dot_kron!(Hiww, wT, viw_X, w1, w2, cache.w3, cache.w4, rt2)
mul!(Hiww, γ_vec, γ_vec', c4, true)

cone.inv_hess_updated = true
Expand Down Expand Up @@ -560,54 +562,3 @@ function dder3(

return dder3
end

function eig_kron!(
Hww::AbstractMatrix{T},
dot_mat::Matrix{T},
cone::EpiPerSepSpectral{<:MatrixCSqr{T}},
) where T
rt2 = sqrt(T(2))
rt2i = inv(rt2)
d = cone.d
cache = cone.cache
w1 = cache.w1
w2 = cache.w2
w3 = cache.w3
V = cache.w4
copyto!(V, cache.viw_X') # allows column slices

col_idx = 1
@inbounds for j in 1:d
@views V_j = V[:, j]
for i in 1:(j - 1)
@views V_i = V[:, i]
mul!(w2, V_j, V_i', rt2i, zero(T))

@. w3 = w2 + w2'
w3 .*= dot_mat
mul!(w1, Hermitian(w3, :U), V)
mul!(w3, V', w1)
@views smat_to_svec!(Hww[:, col_idx], w3, rt2)
col_idx += 1

if cache.is_complex
w2 *= im
@. w3 = w2 + w2'
w3 .*= dot_mat
mul!(w1, Hermitian(w3, :U), V)
mul!(w3, V', w1)
@views smat_to_svec!(Hww[:, col_idx], w3, rt2)
col_idx += 1
end
end

mul!(w3, V_j, V_j')
w3 .*= dot_mat
mul!(w1, Hermitian(w3, :U), V)
mul!(w3, V', w1)
@views smat_to_svec!(Hww[:, col_idx], w3, rt2)
col_idx += 1
end

return Hww
end
Loading