Skip to content

Commit

Permalink
more support for outofplace tr solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
longemen3000 committed Jan 22, 2024
1 parent 6fd621a commit 6873799
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 76 deletions.
2 changes: 1 addition & 1 deletion src/globalization/trs_solvers/TRS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function (ms::TRSolver)(∇f, H, Δ, p)
x, info = trs(H, ∇f, Δ)
p .= x[:, 1]

m = dot(∇f, p) + dot(p, H * p) / 2
m = dot(∇f, p) + dot(p, H, p) / 2
interior = norm(p, 2) Δ
return (
p = p,
Expand Down
40 changes: 36 additions & 4 deletions src/globalization/trs_solvers/root.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,24 @@
abstract type TRSPSolver end
abstract type NearlyExactTRSP <: TRSPSolver end

"""
    trs_supports_outofplace(trs)

Return whether the trust-region subproblem solver `trs` can be used with the
`OutOfPlace()` mutation style. This generic fallback answers `false`; solvers
that support it opt in by adding a more specific method.
"""
function trs_supports_outofplace(trs)
    return false
end

"""
    trs_outofplace_check(trs, prob)

Throw an `ErrorException` when `trs_supports_outofplace(trs)` is `false`;
otherwise return `nothing`. The error message names the wrapper types of both
`trs` and `prob`.
"""
function trs_outofplace_check(trs, prob)
    # Guard clause: nothing to report when the solver handles OutOfPlace().
    trs_supports_outofplace(trs) && return nothing
    solver_name = typeof(trs).name.wrapper
    problem_name = typeof(prob).name.wrapper
    msg = "solve() not defined for OutOfPlace() with $(solver_name) for $(problem_name)"
    throw(ErrorException(msg))
end

include("solvers/NWI.jl")
include("solvers/Dogleg.jl")
include("solvers/NTR.jl")
include("solvers/TCG.jl")
#include("subproblemsolvers/TRS.jl") just make an example instead of relying on TRS.jl

function tr_return(; λ, ∇f, H, s, interior, solved, hard_case, Δ, m = nothing)
m = m isa Nothing ? dot(∇f, s) + dot(s, H * s) / 2 : m
m = m isa Nothing ? dot(∇f, s) + dot(s, H, s) / 2 : m
(
p = s,
mz = m,
Expand All @@ -23,13 +33,35 @@ function tr_return(; λ, ∇f, H, s, interior, solved, hard_case, Δ, m = nothin
)
end

function update_H!(H, h, λ = nothing)
# Dispatch on the mutation style: `OutOfPlace` forwards to the non-mutating
# `_update_H`, `InPlace` forwards to the mutating `_update_H!`. When `λ` is
# omitted, `nothing` is passed on in its place.
update_H!(::OutOfPlace, H, h, λ = nothing) = _update_H(H, h, λ)
update_H!(::InPlace, H, h, λ = nothing) = _update_H!(H, h, λ)

"""
    _update_H!(H, h, λ)

Overwrite the diagonal of `H` in place and return `H`.

- `λ === nothing`: set `H[i, i] = h[i]`.
- `λ` different from zero: set `H[i, i] = h[i] + λ`.
- `λ` equal to zero: leave `H` untouched.

Only the diagonal entries indexed by `h` are written; off-diagonal entries are
never modified.
"""
function _update_H!(H, h, λ)
    T = eltype(h)
    if λ === nothing
        # No shift requested: write the stored diagonal `h` back into `H`.
        for i in eachindex(h)
            @inbounds H[i, i] = h[i]
        end
    elseif λ != zero(T)
        # Shifted diagonal: h + λ on every diagonal entry.
        for i in eachindex(h)
            @inbounds H[i, i] = h[i] + λ
        end
    end
    return H
end

"""
    _update_H(H, h, λ = nothing)

Non-mutating counterpart of `_update_H!`: return a matrix equal to `H` except
for its diagonal, without modifying `H`.

- `λ === nothing`: the diagonal of the result is `h`.
- `λ` different from zero: the diagonal of the result is `h .+ λ`.
- `λ` equal to zero: `H` itself is returned unchanged.

The diagonal is *replaced*, not added to, so the result matches what
`_update_H!` would write into a mutable `H`.
"""
function _update_H(H, h, λ = nothing)
    T = eltype(h)
    if λ === nothing
        # Zero out the current diagonal first so the new one is set exactly
        # (H[i, i] - H[i, i] == 0, then 0 + h[i] == h[i]).
        return (H - Diagonal(diag(H))) + Diagonal(h)
    elseif λ != zero(T)
        return (H - Diagonal(diag(H))) + Diagonal(h .+ λ)
    else
        return H
    end
end
4 changes: 3 additions & 1 deletion src/globalization/trs_solvers/solvers/Dogleg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ struct Dogleg{T} <: TRSPSolver
end
# Zero-argument convenience constructor: forwards `nothing` to the
# one-argument `Dogleg` constructor.
Dogleg() = Dogleg(nothing)

# Opt in: a `Dogleg` solver reports that it supports out-of-place use.
trs_supports_outofplace(::Dogleg) = true

function (dogleg::Dogleg)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)
T = eltype(p)
n = length(∇f)
Expand Down Expand Up @@ -80,7 +82,7 @@ function (dogleg::Dogleg)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxite
interior = false
end
end
m = dot(∇f, p) + dot(p, H * p) / 2
m = dot(∇f, p) + dot(p, H, p) / 2

return (
p = p,
Expand Down
65 changes: 41 additions & 24 deletions src/globalization/trs_solvers/solvers/NTR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,21 @@ function (ms::NTR)(
n = length(∇f)
h = H isa UniformScaling ? copy(∇f) .* 0 .+ 1 : diag(H)
H = H isa UniformScaling ? Diagonal(copy(∇f) .* 0 .+ 1) : H

inplace = mstyle == InPlace()
# Check for interior convergence
if λ == T(0)
F = cholesky(Symmetric(H); check = false)
s .= -∇f
s .= F \ s
if inplace
s .= -∇f
s .= F \ s
else
s = -∇f
s = F \ s
end
s₂ = norm(s, 2)

if issuccess(F) && s₂ < Δ
H = update_H!(H, h)
H = update_H!(mstyle, H, h)
return tr_return(;
λ = λ,
∇f = ∇f,
Expand All @@ -94,7 +99,7 @@ function (ms::NTR)(
λL, λU = isg.L, isg.U

for iter = 1:maxiter
H = update_H!(H, h, λ)
H = update_H!(mstyle, H, h, λ)
F = cholesky(Symmetric(H); check = false)
in𝓖, linpack = false, false
#===========================================================================
Expand All @@ -109,13 +114,17 @@ function (ms::NTR)(
# Algorithm 7.3.1 on p. 185 in [ConnGouldTointBook]
# Step 1 was factorizing
# Step 2
s .= -∇f
s .= F \ s

if inplace
s .= -∇f
s .= F \ s
else
s = -∇f
s = F \ s
end
# Check if step is approximately equal to the radius
s₂ = norm(s, 2)
if s₂ Δ
H = update_H!(H, h)
H = update_H!(mstyle, H, h)
return tr_return(;
λ = λ,
∇f = ∇f,
Expand Down Expand Up @@ -145,15 +154,18 @@ function (ms::NTR)(
if in𝓖
linpack = true
w, u = λL_with_linpack(F)
λL = max(λL, λ - dot(u, H * u))
λL = max(λL, λ - dot(u, H, u))

α, s_g, m_g = 𝓖_root(u, s, Δ, ∇f, H)
s .= s_g

if inplace
s .= s_g
else
s = s_g
end
s₂ = norm(s)
# check hard case convergence
if α^2 * dot(u, H * u) κhard * (dot(s, H * s) + λ * Δ^2)
H = update_H!(H, h)
if α^2 * dot(u, H, u) κhard * (dot(s, H, s) + λ * Δ^2)
H = update_H!(mstyle, H, h)
return tr_return(;
λ = λ,
∇f = ∇f,
Expand All @@ -167,20 +179,21 @@ function (ms::NTR)(
)
end
# If not the hard case solution, try to factorize H(λ⁺)
H = update_H!(H, h, λ⁺)
H = update_H!(mstyle, H, h, λ⁺)
F = cholesky(H; check = false)
if issuccess(F) # Then we're in L, great! lemma 7.3.2
λ = λ⁺
else # we landed in N, this is bad, so use bounds to approach L
λ = max(sqrt(λL * λU), λL + θ * (λU - λL))
λLλU = abs(λL * λU)
λ = max(sqrt(λLλU), λL + θ * (λU - λL))
end
else # in L, we can safely step
λ = λ⁺
end

# check for convergence
if in𝓖 && abs(s₂ - Δ) κeasy * Δ
H = update_H!(H, h)
H = update_H!(mstyle, H, h)
return tr_return(;
λ = λ,
∇f = ∇f,
Expand All @@ -194,9 +207,13 @@ function (ms::NTR)(
elseif abs(s₂ - Δ) κeasy * Δ # implicitly "if in 𝓕" since we're in that branch
# u and α comes from linpack
if linpack
if α^2 * dot(u, H * u) κhard * (dot(sλ, H * sλ) * Δ^2)
s .= s .+ α * u
H = update_H!(H, h)
if α^2 * dot(u, H, u) κhard * (dot(sλ, H, sλ) * Δ^2)
if inplace
s .= s .+ α * u
else
s = s + α * u
end
H = update_H!(mstyle, H, h)
return tr_return(;
λ = λ,
∇f = ∇f,
Expand All @@ -216,10 +233,10 @@ function (ms::NTR)(
# lower bound, we cannot apply the Newton step here.
δ, v = λL_in_𝓝(H, F)
λL = max(λL, λ + δ / dot(v, v)) # update lower bound
λ = max(sqrt(λL * λU), λL + θ * (λU - λL)) # no converence possible, so step in bracket
λ = max(sqrt(λLλU), λL + θ * (λU - λL)) # no convergence possible, so step in bracket
end
end
H = update_H!(H, h)
H = update_H!(mstyle, H, h)
tr_return(;
λ = λ,
∇f = ∇f,
Expand Down Expand Up @@ -272,9 +289,9 @@ function 𝓖_root(u, s, Δ, ∇f, H)
α₂ = (-pb - pd) / 2pa

s₁ = s + α₁ * u
m₁ = dot(∇f, s₁) + dot(s₁, H * s₁) / 2
m₁ = dot(∇f, s₁) + dot(s₁, H, s₁) / 2
s₂ = s + α₂ * u
m₂ = dot(∇f, s₂) + dot(s₂, H * s₂) / 2
m₂ = dot(∇f, s₂) + dot(s₂, H, s₂) / 2
α, s, m = m₁ m₂ ? (α₁, s₁, m₁) : (α₂, s₂, m₂)
α, s, m
end
49 changes: 32 additions & 17 deletions src/globalization/trs_solvers/solvers/NWI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct NWI{T} <: NearlyExactTRSP
end
# Zero-argument convenience constructor: builds an `NWI` around `eigen`.
NWI() = NWI(eigen)
# Short human-readable description of the solver.
summary(::NWI) = "Trust Region (Newton, eigen)"

# Opt in: an `NWI` solver reports that it supports out-of-place use.
trs_supports_outofplace(::NWI) = true
"""
initial_safeguards(B, h, g, Δ)
Expand Down Expand Up @@ -112,19 +112,27 @@ function is_maybe_hard_case(QΛQ, Qt∇f::AbstractVector{T}) where {T}
end

# Equation 4.38 in N&W (2006)
calc_p!(mstyle::MutateStyle, p, Qt∇f, QΛQ, λ) = calc_p!(mstyle, p, Qt∇f, QΛQ, λ, 1)

# Equation 4.45 in N&W (2006) since we allow for first_j > 1
"""
    calc_p!(mstyle, p, Qt∇f, QΛQ, λ[, first_j])

Compute the step

    p = -Σ_{j ≥ first_j} Qt∇f[j] / (Λ[j] + λ) * Q[:, j]

where `Λ = QΛQ.values` and `Q = QΛQ.vectors`. `first_j` defaults to `1` in the
five-argument method.

When `mstyle === InPlace()` the result is accumulated into `p` itself.
Otherwise a new value is built through non-mutating operations and the `p`
that was passed in is left unmodified. The step is returned in both cases, so
callers should always use the return value.
"""
function calc_p!(mstyle::MutateStyle, p, Qt∇f, QΛQ, λ::T, first_j) where {T}
    inplace = mstyle === InPlace()
    # Reset search direction to 0
    if inplace
        fill!(p, T(0))
    else
        p = T(0) .* p
    end
    # Unpack eigenvalues and eigenvectors
    Λ = QΛQ.values
    Q = QΛQ.vectors
    for j = first_j:length(Λ)
        κ = Qt∇f[j] / (Λ[j] + λ)
        if inplace
            @. p = p - κ * Q[:, j]
        else
            # Rebind instead of broadcasting into `p`, so immutable inputs work.
            p = p - κ * Q[:, j]
        end
    end
    return p
end
Expand Down Expand Up @@ -153,7 +161,7 @@ function (ms::NWI)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)
n = length(∇f)
H = H isa UniformScaling ? Diagonal(copy(∇f) .* 0 .+ 1) : H
h = diag(H)

inplace = mstyle == InPlace()
# Note that currently the eigenvalues are only sorted if H is perfectly
# symmetric. (Julia issue #17093)
if H isa Diagonal
Expand All @@ -176,14 +184,14 @@ function (ms::NWI)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)
# positive, so the Newton step, pN, is fine unless norm(pN, 2) > Δ.
if λmin >= sqrt(eps(T))
λ = T(0) # no amount of I is added yet
p = calc_p!(p, Qt∇f, QΛQ, λ) # calculate the Newton step
p = calc_p!(mstyle, p, Qt∇f, QΛQ, λ) # calculate the Newton step
if norm(p, 2) Δ
# No shrinkage is necessary: -(H \ ∇f) is the minimizer
interior = true
solved = true
hard_case = false

m = dot(∇f, p) + dot(p, H * p) / 2
m = dot(∇f, p) + dot(p, H, p) / 2

return (
p = p,
Expand Down Expand Up @@ -218,7 +226,7 @@ function (ms::NWI)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)

# The old p is discarded, and replaced with one that takes into account
# the first j such that λj ≠ λmin. Formula 4.45 in N&W (2006)
= calc_p!(p, Qt∇f, QΛQ, λ, first_j)
= calc_p!(mstyle, p, Qt∇f, QΛQ, λ, first_j)

# Check if the choice of λ leads to a solution inside the trust region.
# If it does, then we construct the "hard case solution".
Expand All @@ -228,9 +236,12 @@ function (ms::NWI)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)

tau = sqrt^2 - norm(pλ, 2)^2)

@. p = -+ tau * Q[:, 1]

m = dot(∇f, p) + dot(p, H * p) / 2
if inplace
@. p = -+ tau * Q[:, 1]
else
p = -+ tau * Q[:, 1]
end
m = dot(∇f, p) + dot(p, H, p) / 2

return (
p = p,
Expand All @@ -257,7 +268,7 @@ function (ms::NWI)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)
λ = safeguard_λ(λ, isg)
for iter = 1:maxiter
λ_previous = λ
H = update_H!(H, h, λ)
H = update_H!(mstyle, H, h, λ)

F =
H isa Diagonal ? cholesky(H; check = false) :
Expand All @@ -271,7 +282,11 @@ function (ms::NWI)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)
continue
end
R = F.U
p .= R \ (R' \ -∇f)
if inplace
p .= R \ (R' \ -∇f)
else
p = R \ (R' \ -∇f)
end
q_l = R' \ p

p_norm = norm(p, 2)
Expand All @@ -289,8 +304,8 @@ function (ms::NWI)(∇f, H, Δ, p, scheme, mstyle; abstol = 1e-10, maxiter = 50)
end
end

H = update_H!(H, h)
m = dot(∇f, p) + dot(p, H * p) / 2
H = update_H!(mstyle, H, h)
m = dot(∇f, p) + dot(p, H, p) / 2
return (
p = p,
mz = m,
Expand Down
Loading

0 comments on commit 6873799

Please sign in to comment.