Skip to content

Commit

Permalink
simplify math for grad of logm
Browse files Browse the repository at this point in the history
  • Loading branch information
lkapelevich committed Dec 26, 2020
1 parent 6e6f996 commit cbf6402
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 29 deletions.
24 changes: 12 additions & 12 deletions src/Cones/Cones.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,32 +336,32 @@ vec_copy_to!(v1::AbstractVecOrMat{Complex{T}}, v2::AbstractVecOrMat{T}) where {T
# utilities for hessians for cones with PSD parts

# TODO parallelize
function symm_kron(H::AbstractMatrix{T}, mat::AbstractMatrix{T}, rt2::T) where {T <: Real}
function symm_kron(H::AbstractMatrix{T}, mat::AbstractMatrix{T}, rt2::T; upper_only::Bool = true) where {T <: Real}
side = size(mat, 1)
k = 1
@inbounds for i in 1:side
for j in 1:(i - 1)
@inbounds for j in 1:side
for i in 1:(j - 1)
k2 = 1
for i2 in 1:side
k2 > k && continue
for j2 in 1:(i2 - 1)
for j2 in 1:side
upper_only && k2 > k && continue
for i2 in 1:(j2 - 1)
scal = (i == j ? 1 : rt2) * (i2 == j2 ? 1 : rt2) / 2
H[k2, k] = scal * (mat[i, i2] * mat[j, j2] + mat[i, j2] * mat[j, i2])
k2 += 1
end
H[k2, k] = rt2 * mat[i, i2] * mat[j, i2]
H[k2, k] = rt2 * mat[i, j2] * mat[j, j2]
k2 += 1
end
k += 1
end
k2 = 1
for i2 in 1:side
k2 > k && continue
for j2 in 1:(i2 - 1)
H[k2, k] = rt2 * mat[i, j2] * mat[i, i2]
for j2 in 1:side
upper_only && k2 > k && continue
for i2 in 1:(j2 - 1)
H[k2, k] = rt2 * mat[j, j2] * mat[j, i2]
k2 += 1
end
H[k2, k] = abs2(mat[i, i2])
H[k2, k] = abs2(mat[j, j2])
k2 += 1
end
k += 1
Expand Down
39 changes: 22 additions & 17 deletions src/Cones/epitracerelentropytri.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ mutable struct EpiTraceRelEntropyTri{T <: Real} <: Cone{T}
use_dual::Bool = false,
hess_fact_cache = hessian_cache(T),
) where {T <: Real}
@assert dim > 1
@assert dim > 2
cone = new{T}()
cone.use_dual_barrier = use_dual
cone.dim = dim
Expand Down Expand Up @@ -265,24 +265,29 @@ function update_hess(cone::EpiTraceRelEntropyTri{T}) where {T <: Real}
return cone.hess
end

function grad_logm!(mat, vecs, diff_mat, rt2)
d = size(vecs, 1)
row_idx = 1
for j in 1:d, i in 1:j
col_idx = 1
for l in 1:d, k in 1:l
mat[row_idx, col_idx] += sum(diff_mat[m, n] * (
vecs[i, m] * vecs[k, m] * vecs[l, n] * vecs[j, n] +
vecs[j, m] * vecs[k, m] * vecs[l, n] * vecs[i, n] +
vecs[i, m] * vecs[l, m] * vecs[k, n] * vecs[j, n] +
vecs[j, m] * vecs[l, m] * vecs[k, n] * vecs[i, n]
) * (m == n ? 1 : 2) * (i == j ? 1 : rt2) * (k == l ? 1 : rt2)
for m in 1:d for n in 1:m)
col_idx += 1
function partial_symm_kron(H::AbstractMatrix{T}, mat::AbstractMatrix{T}, rt2::T, p::Int) where T
side = size(mat, 1)
idx1 = 1
for j in 1:side, i in 1:j
idx2 = 1
for l in 1:side, k in 1:l
if p == 1
H[idx1, idx2] = mat[i, k] * mat[j, l] * (i == j ? 1 : rt2) * (k == l ? 1 : rt2)
elseif p == 2
H[idx1, idx2] = mat[i, l] * mat[j, k] * (i == j ? 1 : rt2) * (k == l ? 1 : rt2)
end
idx2 += 1
end
row_idx += 1
idx1 += 1
end
mat ./= 4
return H
end

# TODO optimize
function grad_logm!(mat::Matrix{T}, vecs::Matrix{T}, diff_mat::Hermitian{T, Matrix{T}}, rt2::T) where T
A = symm_kron(similar(mat), vecs, rt2, upper_only = false)
l = smat_to_svec!(zeros(T, size(mat, 1)), diff_mat, one(T))
mat .= A' * Diagonal(l) * A
return mat
end

Expand Down

0 comments on commit cbf6402

Please sign in to comment.