Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor wsos matrix oracles for faster grad #656

Merged
merged 3 commits into from
Mar 24, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 43 additions & 44 deletions src/Cones/wsosinterppossemideftri.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,60 +142,59 @@ end

# Dual feasibility query for this cone type: unconditionally answers `true`,
# i.e. no dual feasibility test is carried out in this method.
function is_dual_feas(cone::WSOSInterpPosSemidefTri)
    return true
end

"""
    block_diag_prod!(vect, mat, U, R, rt2, scal = 1)

Accumulate into `vect` the diagonal of every `U`-by-`U` block in the upper triangle of
`mat' * mat`, where the columns of `mat` are viewed as `R` consecutive groups of `U`
columns.

For block `(i, j)` with `i <= j` and position `u` in `1:U`, the value added is the inner
product of columns `(i - 1) * U + u` and `(j - 1) * U + u` of `mat`. Off-diagonal blocks
(`i < j`) are multiplied by `rt2`, and every contribution is multiplied by `scal`.
Blocks are laid out in `vect` in the order `(1,1), (1,2), (2,2), (1,3), ...`, each one
occupying `U` consecutive entries.

`vect` is updated in place (added to, not overwritten). Returns `nothing`.

# Arguments
- `vect`: accumulator with at least `U * R * (R + 1) / 2` entries.
- `mat`: matrix with at least `U * R` columns.
- `U`, `R`: block size and number of blocks per side.
- `rt2`: factor applied to off-diagonal blocks.
- `scal`: overall integer factor (e.g. `-1` to subtract).

# Throws
- `DimensionMismatch` if `mat` or `vect` is too small for the given `U` and `R`.
"""
function block_diag_prod!(
    vect::Vector{T},
    mat::Matrix{T},
    U::Int,
    R::Int,
    rt2::T,
    scal::Int = 1,
) where T
    # nothing to accumulate; same no-op as the empty loops would give
    (U <= 0 || R <= 0) && return

    # the loops below run under @inbounds, so check every extent they touch up front
    ncols_needed = U * R
    nvect_needed = U * div(R * (R + 1), 2)
    if size(mat, 2) < ncols_needed
        throw(DimensionMismatch(
            "mat has $(size(mat, 2)) columns but U * R = $(ncols_needed) are required"))
    end
    if length(vect) < nvect_needed
        throw(DimensionMismatch(
            "vect has length $(length(vect)) but $(nvect_needed) entries are required"))
    end

    @inbounds for u in 1:U
        idx = u      # position in vect of the current block's u-th diagonal entry
        j_idx = u    # column of mat for block column j
        for j in 1:R
            col_j = view(mat, :, j_idx)
            i_idx = u    # column of mat for block row i
            for _ in 1:(j - 1)
                # off-diagonal block (i, j) with i < j
                vect[idx] += dot(view(mat, :, i_idx), col_j) * rt2 * scal
                idx += U
                i_idx += U
            end
            # diagonal block (j, j)
            vect[idx] += sum(abs2, col_j) * scal
            j_idx += U
            idx += U
        end
    end
    return
end

function update_grad(cone::WSOSInterpPosSemidefTri)
@assert is_feas(cone)
U = cone.U
R = cone.R
cone.grad .= 0

# update PΛiP
@inbounds for k in eachindex(cone.PΛiP)
L = size(cone.Ps[k], 2)
ΛFL = cone.ΛFL[k].L
ΛFLP = cone.ΛFLP[k]

# given cholesky L factor ΛFL, get ΛFLP = ΛFL \ kron(I, P')
@inbounds for p in 1:R
for p in 1:R
block_U_p_idxs = block_idxs(U, p)
block_L_p_idxs = block_idxs(L, p)
@views ΛFLP_pp = ΛFLP[block_L_p_idxs, block_U_p_idxs]
# ΛFLP_pp = ΛFL_pp \ P'
@views ldiv!(ΛFLP_pp, LowerTriangular(ΛFL[block_L_p_idxs, block_L_p_idxs]), cone.Ps[k]')
# to get off-diagonals in ΛFLP, subtract known blocks aggregated in ΛFLP_qp
@inbounds for q in (p + 1):R
for q in (p + 1):R
block_L_q_idxs = block_idxs(L, q)
@views ΛFLP_qp = ΛFLP[block_L_q_idxs, block_U_p_idxs]
ΛFLP_qp .= 0
@inbounds for p2 in p:(q - 1)
for p2 in p:(q - 1)
block_L_p2_idxs = block_idxs(L, p2)
@views mul!(ΛFLP_qp, ΛFL[block_L_q_idxs, block_L_p2_idxs], ΛFLP[block_L_p2_idxs, block_U_p_idxs], -1, 1)
end
@views ldiv!(LowerTriangular(ΛFL[block_L_q_idxs, block_L_q_idxs]), ΛFLP_qp)
end
end

# PΛiP = ΛFLP' * ΛFLP
PΛiPk = cone.PΛiP[k]
@inbounds for p in 1:R, q in p:R
block_p_idxs = block_idxs(U, p)
block_q_idxs = block_idxs(U, q)
# ΛFLP is block lower triangular, so only the rows from block max(p, q) onward (here q, since q >= p) contribute nonzeros to the product
row_range = ((q - 1) * L + 1):(L * R)
@inbounds @views mul!(PΛiPk[block_p_idxs, block_q_idxs], ΛFLP[row_range, block_p_idxs]', ΛFLP[row_range, block_q_idxs])
end
LinearAlgebra.copytri!(PΛiPk, 'U')
end

# update gradient
for p in 1:cone.R, q in 1:p
scal = (p == q ? -1 : -cone.rt2)
idx = (svec_idx(p, q) - 1) * U
block_p = (p - 1) * U
block_q = (q - 1) * U
for i in 1:U
block_p_i = block_p + i
block_q_i = block_q + i
@inbounds cone.grad[idx + i] = scal * sum(PΛiPk[block_q_i, block_p_i] for PΛiPk in cone.PΛiP)
end
# update grad
block_diag_prod!(cone.grad, ΛFLP, U, R, cone.rt2, -1)
end

cone.grad_updated = true
Expand All @@ -208,6 +207,22 @@ function update_hess(cone::WSOSInterpPosSemidefTri)
U = cone.U
H = cone.hess.data
H .= 0

@inbounds for k in eachindex(cone.PΛiP)
L = size(cone.Ps[k], 2)
# PΛiP = ΛFLP' * ΛFLP
PΛiPk = cone.PΛiP[k]
ΛFLP = cone.ΛFLP[k]
for p in 1:R, q in p:R
block_p_idxs = block_idxs(U, p)
block_q_idxs = block_idxs(U, q)
# ΛFLP is block lower triangular, so only the rows from block max(p, q) onward (here q, since q >= p) contribute nonzeros to the product
row_range = ((q - 1) * L + 1):(L * R)
@views mul!(PΛiPk[block_p_idxs, block_q_idxs], ΛFLP[row_range, block_p_idxs]', ΛFLP[row_range, block_q_idxs])
end
LinearAlgebra.copytri!(PΛiPk, 'U')
end

@inbounds for p in 1:R, q in 1:p
block = svec_idx(p, q)
idxs = block_idxs(U, block)
Expand Down Expand Up @@ -266,23 +281,7 @@ function correction(cone::WSOSInterpPosSemidefTri{T}, primal_dir::AbstractVector
end

big_mat_half = mul!(cone.tempLRUR2[k], ΛFLP_dir, Symmetric(cone.PΛiP[k], :U))
# diagonal from each (i, j) block in big_mat_half' * big_mat_half
for u in 1:U
idx = u
j_idx = u
for j in 1:R
i_idx = u
for i in 1:(j - 1)
@views corr[idx] += dot(big_mat_half[:, i_idx], big_mat_half[:, j_idx]) * cone.rt2
idx += U
i_idx += U
end
@views corr[idx] += sum(abs2, big_mat_half[:, j_idx])
j_idx += U
idx += U
end
end

block_diag_prod!(corr, big_mat_half, U, R, cone.rt2)
end

return corr
Expand Down