Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Halley's method via descent API #404

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why this comes in, I think that got in when I dev the libs. can be removed

SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
Expand Down Expand Up @@ -113,6 +114,7 @@ StaticArrays = "1.9"
StaticArraysCore = "1.4"
Sundials = "4.23.1"
SymbolicIndexingInterface = "0.3.31"
TaylorDiff = "0.3"
Test = "1.10"
Zygote = "0.6.69"
julia = "1.10"
Expand Down Expand Up @@ -146,8 +148,9 @@ SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "TaylorDiff", "Test", "Zygote"]
3 changes: 3 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"

[extensions]
NonlinearSolveBaseBandedMatricesExt = "BandedMatrices"
Expand All @@ -44,6 +45,7 @@ NonlinearSolveBaseLineSearchExt = "LineSearch"
NonlinearSolveBaseLinearSolveExt = "LinearSolve"
NonlinearSolveBaseSparseArraysExt = "SparseArrays"
NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings"
NonlinearSolveBaseTaylorDiffExt = "TaylorDiff"

[compat]
ADTypes = "1.9"
Expand Down Expand Up @@ -77,6 +79,7 @@ SparseArrays = "1.10"
SparseMatrixColorings = "0.4.5"
StaticArraysCore = "1.4"
SymbolicIndexingInterface = "0.3.31"
TaylorDiff = "0.3"
Test = "1.10"
TimerOutputs = "0.5.23"
julia = "1.10"
Expand Down
20 changes: 20 additions & 0 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseTaylorDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module NonlinearSolveBaseTaylorDiffExt
using SciMLBase: NonlinearFunction
using NonlinearSolveBase: HalleyDescentCache
import NonlinearSolveBase: evaluate_hvvp
using TaylorDiff: derivative, derivative!
using FastClosures: @closure

# TaylorDiff-backed method of `evaluate_hvvp` for `HalleyDescentCache`.
# `iip` is the type parameter of the `NonlinearFunction` and selects which of the
# two branches below runs. Returns `hvvp`: the buffer that was passed in on the
# `iip` branch, or the value returned by `derivative` otherwise.
function evaluate_hvvp(
Comment on lines +3 to +8
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
using NonlinearSolveBase: HalleyDescentCache
import NonlinearSolveBase: evaluate_hvvp
using TaylorDiff: derivative, derivative!
using FastClosures: @closure
function evaluate_hvvp(
using NonlinearSolveBase: NonlinearSolveBase, HalleyDescentCache
using TaylorDiff: derivative, derivative!
using FastClosures: @closure
function NonlinearSolveBase.evaluate_hvvp(

style nit

hvvp, cache::HalleyDescentCache, f::NonlinearFunction{iip}, p, u, δu) where {iip}
if iip
# Close over `p` so the wrapped function only takes `(y, x)`.
binary_f = @closure (y, x) -> f(y, x, p)
# Writes the order-2 (`Val(2)`) derivative along `δu` into `hvvp`. `cache.fu` is
# handed over as an extra buffer, presumably the output storage for `binary_f`;
# TODO confirm against the TaylorDiff `derivative!` documentation.
derivative!(hvvp, binary_f, cache.fu, u, δu, Val(2))
else
# Fix `p` as the second argument so the function only takes `u`.
unary_f = Base.Fix2(f, p)
# Rebinds the local `hvvp`; the buffer passed in by the caller is not written to.
hvvp = derivative(unary_f, u, δu, Val(2))
end
hvvp
end

end
1 change: 1 addition & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ include("polyalg.jl")

include("descent/common.jl")
include("descent/newton.jl")
include("descent/halley.jl")
include("descent/steepest.jl")
include("descent/damped_newton.jl")
include("descent/dogleg.jl")
Expand Down
100 changes: 100 additions & 0 deletions lib/NonlinearSolveBase/src/descent/halley.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
    HalleyDescent(; linsolve = nothing)

Improve on [`NewtonDescent`](@ref) by adding a second-order correction term.
The descent direction is computed in three steps:

 1. Solve ``J a = fu`` for `a`.
 2. Compute the Hessian-vector-vector product ``H a a`` and solve ``J b = H a a``
    for the second-order correction term `b`.
 3. Form the descent direction elementwise as ``δu = a * a / (b / 2 - a)``.

`linsolve` is forwarded to `construct_linear_solver` when the cache is initialised.

Note that `import TaylorDiff` is required to use this descent algorithm.

See also [`NewtonDescent`](@ref).
"""
@kwdef @concrete struct HalleyDescent <: AbstractDescentDirection
linsolve = nothing
end

# Trait method: reports that `HalleyDescent` supports line search.
supports_line_search(::HalleyDescent) = true

# Mutable working state for `HalleyDescent`, built by `InternalAPI.init` and
# consumed by `InternalAPI.solve!`.
@concrete mutable struct HalleyDescentCache <: AbstractDescentCache
# Nonlinear function and parameters (`prob.f` / `prob.p`), kept so that
# `evaluate_hvvp` can be called from `solve!`.
f
p
# Descent-direction buffer, allocated as `similar(u)`.
δu
# Extra descent-direction buffers when `shared > 1`; `nothing` otherwise.
δus
# Solution of the second linear solve, allocated as `similar(u)`.
b
# Scratch buffer shaped like the residual, allocated as `similar(fu)`.
fu
# Buffer for the Hessian-vector-vector product, allocated as `similar(fu)`.
hvvp
# Linear-solver cache; `nothing` when `pre_inverted` is `Val(true)`.
lincache
# Timer handed to `@static_timeit` in `solve!`.
timer
# `Val(true)` when `J` passed to `solve!` is applied directly (`J × vec(fu)`)
# instead of going through `lincache`.
preinverted_jacobian <: Union{Val{false}, Val{true}}
end

# Registers `lincache` with `@internal_caches`; presumably this marks it as a nested
# cache for the generic cache utilities. TODO confirm against the macro definition.
@internal_caches HalleyDescentCache :lincache

# Build the `HalleyDescentCache` for `alg`.
#
# Allocates the work buffers (`δu`, `b` shaped like `u`; a scratch residual and `hvvp`
# shaped like `fu`), one extra `δu`-like buffer per additional shared cache when
# `shared > 1`, and a linear solver for `J` unless `pre_inverted` is `Val(true)`.
#
# Keyword arguments:
#   - `stats`, `abstol`, `reltol`, `linsolve_kwargs...`: forwarded to
#     `construct_linear_solver`.
#   - `shared`: `Val(n)`; for `n > 1`, `n - 1` additional descent buffers are created.
#   - `pre_inverted`: `Val(true)` skips linear-solver construction (`lincache = nothing`).
#   - `timer`: stored in the cache for use in `solve!`.
#   - `kwargs...`: accepted and ignored.
function InternalAPI.init(
        prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; stats,
        shared = Val(1), pre_inverted::Val = Val(false),
        linsolve_kwargs = (;), abstol = nothing, reltol = nothing,
        timer = get_timer_output(), kwargs...)
    @bb δu = similar(u)
    @bb b = similar(u)
    # Keep the scratch residual under its own name. Rebinding `fu` here would cause the
    # linear solver below to be constructed from an uninitialised buffer instead of
    # the residual supplied by the caller.
    @bb fu_cache = similar(fu)
    @bb hvvp = similar(fu)
    δus = Utils.unwrap_val(shared) ≤ 1 ? nothing : map(2:Utils.unwrap_val(shared)) do i
        @bb δu_ = similar(u)
    end
    lincache = Utils.unwrap_val(pre_inverted) ? nothing :
               construct_linear_solver(
        alg, alg.linsolve, J, Utils.safe_vec(fu), Utils.safe_vec(u);
        stats, abstol, reltol, linsolve_kwargs...
    )
    return HalleyDescentCache(
        prob.f, prob.p, δu, δus, b, fu_cache, hvvp, lincache, timer, pre_inverted)
end

# Compute the Halley descent direction for the cache slot selected by `idx`.
#
# Steps: (1) obtain `δu` from `J` and `fu` (direct product when the Jacobian is
# pre-inverted, linear solve otherwise); (2) evaluate the Hessian-vector-vector
# product at `u` along `δu`; (3) obtain `b` from `J` and that product in the same
# way; (4) combine `δu` and `b` elementwise into the final direction.
#
# Keyword arguments:
#   - `skip_solve`: when `true`, return the currently stored `δu` without any work.
#   - `new_jacobian`: when `false`, the first linear solve is asked to reuse the
#     existing factorization.
#   - `kwargs...`: forwarded to both linear solves.
#
# Returns a `DescentResult`; when either linear solve reports failure it is built
# with `success = false, linsolve_success = false`.
function InternalAPI.solve!(
cache::HalleyDescentCache, J, fu, u, idx::Val = Val(1);
skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...)
δu = SciMLBase.get_du(cache, idx)
skip_solve && return DescentResult(; δu)
if preinverted_jacobian(cache)
@assert J!==nothing "`J` must be provided when `pre_inverted = Val(true)`."
# `J` is applied directly here; no linear solve is performed on this branch.
@bb δu = J × vec(fu)
else
@static_timeit cache.timer "linear solve 1" begin
# Solve with `A = J`, `b = fu`, writing into `δu`. The factorization is reused
# unless this is a fresh Jacobian being solved for the first slot (`Val(1)`).
linres = cache.lincache(;
A = J, b = Utils.safe_vec(fu),
kwargs..., linu = Utils.safe_vec(δu),
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1)))
δu = Utils.restructure(SciMLBase.get_du(cache, idx), linres.u)
if !linres.success
# Store whatever the solver produced and report the failure to the caller.
set_du!(cache, δu, idx)
return DescentResult(; δu, success = false, linsolve_success = false)
end
end
end
b = cache.b
# Hessian-vector-vector product at `u` along the direction `δu` computed above.
hvvp = evaluate_hvvp(cache.hvvp, cache, cache.f, cache.p, u, δu)
# Second linear solve with the same `J`: always request factorization reuse.
if preinverted_jacobian(cache)
@bb b = J × vec(hvvp)
else
@static_timeit cache.timer "linear solve 2" begin
linres = cache.lincache(;
A = J, b = Utils.safe_vec(hvvp),
kwargs..., linu = Utils.safe_vec(b),
reuse_A_if_factorization = true)
b = Utils.restructure(cache.b, linres.u)
if !linres.success
# `δu` still holds the result of the first solve at this point.
set_du!(cache, δu, idx)
return DescentResult(; δu, success = false, linsolve_success = false)
end
end
end
# Elementwise combination of the first-solve result and the correction `b`.
# NOTE(review): an entry where `b / 2 == δu` divides by zero (0/0 when both are
# zero) — confirm whether this needs guarding.
@bb @. δu = δu * δu / (b / 2 - δu)
set_du!(cache, δu, idx)
# `b` may have been rebound above, so write it back to the cache.
cache.b = b
return DescentResult(; δu)
end

# Generic fallback for the Hessian-vector-vector product used in `solve!`: always
# throws. The error message directs the user to import TaylorDiff.
evaluate_hvvp(hvvp, cache, f, p, u, δu) = error("not implemented. please import TaylorDiff")
7 changes: 4 additions & 3 deletions lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
AbstractTrustRegionMethodCache,
Utils, InternalAPI, get_timer_output, @static_timeit,
update_trace!, L2_NORM,
NewtonDescent, DampedNewtonDescent, GeodesicAcceleration,
Dogleg
NewtonDescent, DampedNewtonDescent, HalleyDescent,
GeodesicAcceleration, Dogleg
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode,
NonlinearFunction,
NonlinearLeastSquaresProblem, NonlinearProblem, NoSpecialize
Expand All @@ -31,6 +31,7 @@ using FiniteDiff: FiniteDiff # Default Finite Difference Method
using ForwardDiff: ForwardDiff # Default Forward Mode AD

include("raphson.jl")
include("halley.jl")
include("gauss_newton.jl")
include("levenberg_marquardt.jl")
include("trust_region.jl")
Expand Down Expand Up @@ -93,7 +94,7 @@ end

@reexport using SciMLBase, NonlinearSolveBase

export NewtonRaphson, PseudoTransient
export NewtonRaphson, Halley, PseudoTransient
export GaussNewton, LevenbergMarquardt, TrustRegion

export RadiusUpdateSchemes
Expand Down
15 changes: 15 additions & 0 deletions lib/NonlinearSolveFirstOrder/src/halley.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
    Halley(; concrete_jac = nothing, linsolve = nothing, linesearch = missing,
        autodiff = nothing)

Experimental implementation of Halley's method. It corrects the Newton descent
direction with second-order derivative information in order to improve the
convergence rate.

The correction terms are currently computed through TaylorDiff.jl; a more general
implementation may follow later.
"""
function Halley(; concrete_jac = nothing, linsolve = nothing,
        linesearch = missing, autodiff = nothing)
    # Only `linsolve` belongs to the descent; everything else configures the
    # enclosing generalized first-order algorithm.
    descent = HalleyDescent(; linsolve)
    return GeneralizedFirstOrderAlgorithm(;
        name = :Halley, descent, concrete_jac, linesearch, autodiff)
end
Comment on lines +1 to +15
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really a First Order method, we might want an additional split cc @ChrisRackauckas

9 changes: 7 additions & 2 deletions test/23_test_problems_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
@testsetup module RobustnessTesting
using NonlinearSolve, LinearAlgebra, LinearSolve, NonlinearProblemLibrary, Test
import TaylorDiff

problems = NonlinearProblemLibrary.problems
dicts = NonlinearProblemLibrary.dicts
Expand Down Expand Up @@ -61,10 +62,14 @@ end
end

@testitem "23 Test Problems: Halley" setup=[RobustnessTesting] tags=[:core] begin
alg_ops = (SimpleHalley(; autodiff = AutoForwardDiff()),)
alg_ops = (
Halley(),
SimpleHalley(; autodiff = AutoForwardDiff())
)

broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [1, 5, 15, 16, 18]
broken_tests[alg_ops[1]] = [1, 5, 15, 16]
broken_tests[alg_ops[2]] = [1, 5, 15, 16, 18]

test_on_library(problems, dicts, alg_ops, broken_tests)
end
Expand Down
Loading