Skip to content

Commit

Permalink
Merge pull request #656 from chriscoey/matwsosgrad
Browse files Browse the repository at this point in the history
refactor wsos matrix oracles for faster grad
  • Loading branch information
lkapelevich authored Mar 24, 2021
2 parents 32fd28c + 0e691db commit 79933d9
Showing 1 changed file with 43 additions and 44 deletions.
87 changes: 43 additions & 44 deletions src/Cones/wsosinterppossemideftri.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,60 +142,59 @@ end

is_dual_feas(cone::WSOSInterpPosSemidefTri) = true

# diagonal from each (i, j) block in mat' * mat
function block_diag_prod!(vect::Vector{T}, mat::Matrix{T}, U::Int, R::Int, rt2::T, scal::Int = 1) where T
@inbounds for u in 1:U
idx = u
j_idx = u
for j in 1:R
i_idx = u
for i in 1:(j - 1)
@views vect[idx] += dot(mat[:, i_idx], mat[:, j_idx]) * rt2 * scal
idx += U
i_idx += U
end
@views vect[idx] += sum(abs2, mat[:, j_idx]) * scal
j_idx += U
idx += U
end
end
return
end

function update_grad(cone::WSOSInterpPosSemidefTri)
@assert is_feas(cone)
U = cone.U
R = cone.R
cone.grad .= 0

# update PΛiP
@inbounds for k in eachindex(cone.PΛiP)
L = size(cone.Ps[k], 2)
ΛFL = cone.ΛFL[k].L
ΛFLP = cone.ΛFLP[k]

# given cholesky L factor ΛFL, get ΛFLP = ΛFL \ kron(I, P')
@inbounds for p in 1:R
for p in 1:R
block_U_p_idxs = block_idxs(U, p)
block_L_p_idxs = block_idxs(L, p)
@views ΛFLP_pp = ΛFLP[block_L_p_idxs, block_U_p_idxs]
# ΛFLP_pp = ΛFL_pp \ P'
@views ldiv!(ΛFLP_pp, LowerTriangular(ΛFL[block_L_p_idxs, block_L_p_idxs]), cone.Ps[k]')
# to get off-diagonals in ΛFLP, subtract known blocks aggregated in ΛFLP_qp
@inbounds for q in (p + 1):R
for q in (p + 1):R
block_L_q_idxs = block_idxs(L, q)
@views ΛFLP_qp = ΛFLP[block_L_q_idxs, block_U_p_idxs]
ΛFLP_qp .= 0
@inbounds for p2 in p:(q - 1)
for p2 in p:(q - 1)
block_L_p2_idxs = block_idxs(L, p2)
@views mul!(ΛFLP_qp, ΛFL[block_L_q_idxs, block_L_p2_idxs], ΛFLP[block_L_p2_idxs, block_U_p_idxs], -1, 1)
end
@views ldiv!(LowerTriangular(ΛFL[block_L_q_idxs, block_L_q_idxs]), ΛFLP_qp)
end
end

# PΛiP = ΛFLP' * ΛFLP
PΛiPk = cone.PΛiP[k]
@inbounds for p in 1:R, q in p:R
block_p_idxs = block_idxs(U, p)
block_q_idxs = block_idxs(U, q)
# since ΛFLP is block lower triangular rows only from max(p,q) start making a nonzero contribution to the product
row_range = ((q - 1) * L + 1):(L * R)
@inbounds @views mul!(PΛiPk[block_p_idxs, block_q_idxs], ΛFLP[row_range, block_p_idxs]', ΛFLP[row_range, block_q_idxs])
end
LinearAlgebra.copytri!(PΛiPk, 'U')
end

# update gradient
for p in 1:cone.R, q in 1:p
scal = (p == q ? -1 : -cone.rt2)
idx = (svec_idx(p, q) - 1) * U
block_p = (p - 1) * U
block_q = (q - 1) * U
for i in 1:U
block_p_i = block_p + i
block_q_i = block_q + i
@inbounds cone.grad[idx + i] = scal * sum(PΛiPk[block_q_i, block_p_i] for PΛiPk in cone.PΛiP)
end
# update grad
block_diag_prod!(cone.grad, ΛFLP, U, R, cone.rt2, -1)
end

cone.grad_updated = true
Expand All @@ -208,6 +207,22 @@ function update_hess(cone::WSOSInterpPosSemidefTri)
U = cone.U
H = cone.hess.data
H .= 0

@inbounds for k in eachindex(cone.PΛiP)
L = size(cone.Ps[k], 2)
# PΛiP = ΛFLP' * ΛFLP
PΛiPk = cone.PΛiP[k]
ΛFLP = cone.ΛFLP[k]
for p in 1:R, q in p:R
block_p_idxs = block_idxs(U, p)
block_q_idxs = block_idxs(U, q)
# since ΛFLP is block lower triangular rows only from max(p,q) start making a nonzero contribution to the product
row_range = ((q - 1) * L + 1):(L * R)
@views mul!(PΛiPk[block_p_idxs, block_q_idxs], ΛFLP[row_range, block_p_idxs]', ΛFLP[row_range, block_q_idxs])
end
LinearAlgebra.copytri!(PΛiPk, 'U')
end

@inbounds for p in 1:R, q in 1:p
block = svec_idx(p, q)
idxs = block_idxs(U, block)
Expand Down Expand Up @@ -266,23 +281,7 @@ function correction(cone::WSOSInterpPosSemidefTri{T}, primal_dir::AbstractVector
end

big_mat_half = mul!(cone.tempLRUR2[k], ΛFLP_dir, Symmetric(cone.PΛiP[k], :U))
# diagonal from each (i, j) block in big_mat_half' * big_mat_half
for u in 1:U
idx = u
j_idx = u
for j in 1:R
i_idx = u
for i in 1:(j - 1)
@views corr[idx] += dot(big_mat_half[:, i_idx], big_mat_half[:, j_idx]) * cone.rt2
idx += U
i_idx += U
end
@views corr[idx] += sum(abs2, big_mat_half[:, j_idx])
j_idx += U
idx += U
end
end

block_diag_prod!(corr, big_mat_half, U, R, cone.rt2)
end

return corr
Expand Down

0 comments on commit 79933d9

Please sign in to comment.