From ce721be9a7c2886c9daa58f23699dd33bbe8ad9f Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Tue, 3 Dec 2024 00:46:35 +0800 Subject: [PATCH 1/9] refactor: Move dual nonlinear solving to NonlinearSolveBase --- .../ext/NonlinearSolveBaseForwardDiffExt.jl | 111 +++++++++++++++++- lib/NonlinearSolveFirstOrder/Project.toml | 3 +- .../test/misc_tests.jl | 10 ++ src/NonlinearSolve.jl | 2 - src/forward_diff.jl | 99 ---------------- 5 files changed, 121 insertions(+), 104 deletions(-) delete mode 100644 src/forward_diff.jl diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index bb3165396..6357549ec 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -2,17 +2,36 @@ module NonlinearSolveBaseForwardDiffExt using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff using ArrayInterface: ArrayInterface -using CommonSolve: solve +using CommonSolve: CommonSolve, solve +using ConcreteStructs: @concrete using DifferentiationInterface: DifferentiationInterface using FastClosures: @closure using ForwardDiff: ForwardDiff, Dual using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem, remake -using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils +using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, + AbstractNonlinearSolveAlgorithm, Utils, InternalAPI, + AbstractNonlinearSolveCache const DI = DifferentiationInterface +const ALL_SOLVER_TYPES = [ + Nothing, AbstractNonlinearSolveAlgorithm +] + +const DualNonlinearProblem = NonlinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem +} + function NonlinearSolveBase.additional_incompatible_backend_check( prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff}) return !ForwardDiff.can_dual(eltype(prob.u0)) @@ -102,4 +121,92 @@ function NonlinearSolveBase.nonlinearsolve_dual_solution( return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials))) end +for algType in ALL_SOLVER_TYPES + @eval function SciMLBase.__solve( + prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... + ) + sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) + end +end + +@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache + cache + prob + alg + p + values_p + partials_p +end + +function InternalAPI.reinit!( + cache::NonlinearSolveForwardDiffCache, args...; + p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs... +) + InternalAPI.reinit!( + cache.cache; p = nodual_value(p), u0 = nodual_value(u0), kwargs... + ) + cache.p = p + cache.values_p = nodual_value(p) + cache.partials_p = ForwardDiff.partials(p) + return cache +end + +for algType in ALL_SOLVER_TYPES + @eval function SciMLBase.__init( + prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... + ) + p = nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) + ) + end +end + +function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache) + sol = solve!(cache.cache) + prob = cache.prob + uu = sol.u + + fn = prob isa NonlinearLeastSquaresProblem ? + NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f + + Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p) + Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p) + + z_arr = -Jᵤ \ Jₚ + + sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z) + if cache.p isa Number + partials = sumfun((z_arr, cache.p)) + else + partials = sum(sumfun, zip(eachcol(z_arr), cache.p)) + end + + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, cache.p) + return SciMLBase.build_solution( + prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) +end + +nodual_value(x) = x +nodual_value(x::Dual) = ForwardDiff.value(x) +nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) + +""" + pickchunksize(x) = pickchunksize(length(x)) + pickchunksize(x::Int) + +Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. +""" +@inline pickchunksize(x) = pickchunksize(length(x)) +@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) + end diff --git a/lib/NonlinearSolveFirstOrder/Project.toml b/lib/NonlinearSolveFirstOrder/Project.toml index ee2d2c9de..c299b6dc1 100644 --- a/lib/NonlinearSolveFirstOrder/Project.toml +++ b/lib/NonlinearSolveFirstOrder/Project.toml @@ -67,6 +67,7 @@ julia = "1.10" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" @@ -86,4 +87,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "BandedMatrices", "BenchmarkTools", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"] +test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ForwardDiff", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"] diff --git a/lib/NonlinearSolveFirstOrder/test/misc_tests.jl b/lib/NonlinearSolveFirstOrder/test/misc_tests.jl index 40fcb2c55..79c63f37c 100644 --- a/lib/NonlinearSolveFirstOrder/test/misc_tests.jl +++ b/lib/NonlinearSolveFirstOrder/test/misc_tests.jl @@ -20,3 +20,13 @@ @test sol.retcode == ReturnCode.Success @test jac_calls == 0 end + +@testitem "Dual of BigFloat: Issue #512" tags=[:core] begin + using NonlinearSolveFirstOrder, ForwardDiff + fn_iip = NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p) + u2 = [ForwardDiff.Dual(BigFloat(1.0), 5.0), ForwardDiff.Dual(BigFloat(1.0), 5.0), + ForwardDiff.Dual(BigFloat(1.0), 5.0)] + prob_iip_bf = NonlinearProblem{true}(fn_iip, u2, ForwardDiff.Dual(BigFloat(2.0), 5.0)) + sol = solve(prob_iip_bf, NewtonRaphson()) + @test sol.retcode == ReturnCode.Success +end diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 4c44cc972..c6fcc1f12 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -62,8 +62,6 @@ const ALL_SOLVER_TYPES = [ NonlinearSolvePolyAlgorithm ] -include("forward_diff.jl") - @setup_workload begin nonlinear_functions = ( (NonlinearFunction{false, NoSpecialize}((u, p) -> u .* u .- p), 0.1), diff --git a/src/forward_diff.jl b/src/forward_diff.jl deleted file mode 100644 index 5bb98561c..000000000 --- a/src/forward_diff.jl +++ /dev/null @@ -1,99 +0,0 @@ -const DualNonlinearProblem = NonlinearProblem{ - <:Union{Number, <:AbstractArray}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} -} where {iip, T, V, P} -const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ - <:Union{Number, <:AbstractArray}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} -} where {iip, T, V, P} -const DualAbstractNonlinearProblem = Union{ - DualNonlinearProblem, DualNonlinearLeastSquaresProblem -} - -for algType in ALL_SOLVER_TYPES - @eval function SciMLBase.__solve( - prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... - ) - sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( - prob, alg, args...; kwargs... - ) - dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) - return SciMLBase.build_solution( - prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original - ) - end -end - -@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache - cache - prob - alg - p - values_p - partials_p -end - -function InternalAPI.reinit!( - cache::NonlinearSolveForwardDiffCache, args...; - p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs... -) - InternalAPI.reinit!( - cache.cache; p = nodual_value(p), u0 = nodual_value(u0), kwargs... - ) - cache.p = p - cache.values_p = nodual_value(p) - cache.partials_p = ForwardDiff.partials(p) - return cache -end - -for algType in ALL_SOLVER_TYPES - @eval function SciMLBase.__init( - prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... - ) - p = nodual_value(prob.p) - newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p) - cache = init(newprob, alg, args...; kwargs...) - return NonlinearSolveForwardDiffCache( - cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) - ) - end -end - -function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache) - sol = solve!(cache.cache) - prob = cache.prob - uu = sol.u - - fn = prob isa NonlinearLeastSquaresProblem ? - NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f - - Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p) - Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p) - - z_arr = -Jᵤ \ Jₚ - - sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z) - if cache.p isa Number - partials = sumfun((z_arr, cache.p)) - else - partials = sum(sumfun, zip(eachcol(z_arr), cache.p)) - end - - dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, cache.p) - return SciMLBase.build_solution( - prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original - ) -end - -nodual_value(x) = x -nodual_value(x::Dual) = ForwardDiff.value(x) -nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) - -""" - pickchunksize(x) = pickchunksize(length(x)) - pickchunksize(x::Int) - -Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. -""" -@inline pickchunksize(x) = pickchunksize(length(x)) -@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) From de6eb96517d6d584fc3d4e5f4024f1c33f714a0a Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 4 Dec 2024 00:12:35 +0800 Subject: [PATCH 2/9] Put DualAbstractNonlinearProblem solving in subpackages --- .../ext/NonlinearSolveBaseForwardDiffExt.jl | 19 ++------ .../src/NonlinearSolveBase.jl | 2 + lib/NonlinearSolveBase/src/common_defaults.jl | 9 ++++ .../src/NonlinearSolveFirstOrder.jl | 4 +- .../src/forward_diff.jl | 34 ++++++++++++++ lib/NonlinearSolveQuasiNewton/Project.toml | 6 +++ ...NonlinearSolveQuasiNewtonForwardDiffExt.jl | 47 +++++++++++++++++++ .../Project.toml | 6 +++ ...inearSolveSpectralMethodsForwardDiffExt.jl | 47 +++++++++++++++++++ src/NonlinearSolve.jl | 11 +---- src/forward_diff.jl | 44 +++++++++++++++++ 11 files changed, 205 insertions(+), 24 deletions(-) create mode 100644 lib/NonlinearSolveFirstOrder/src/forward_diff.jl create mode 100644 lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl create mode 100644 lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl create mode 100644 src/forward_diff.jl diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index 6357549ec..203d06f14 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -12,12 +12,12 @@ using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, AbstractNonlinearSolveAlgorithm, Utils, InternalAPI, - AbstractNonlinearSolveCache + AbstractNonlinearSolveCache, NonlinearSolvePolyAlgorithm const DI = DifferentiationInterface -const ALL_SOLVER_TYPES = [ - Nothing, AbstractNonlinearSolveAlgorithm +const GENERAL_SOLVER_TYPES = [ + Nothing, AbstractNonlinearSolveAlgorithm, NonlinearSolvePolyAlgorithm ] const DualNonlinearProblem = NonlinearProblem{ @@ -121,7 +121,7 @@ function NonlinearSolveBase.nonlinearsolve_dual_solution( return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials))) end -for algType in ALL_SOLVER_TYPES +for algType in GENERAL_SOLVER_TYPES @eval function SciMLBase.__solve( prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... ) @@ -157,7 +157,7 @@ function InternalAPI.reinit!( return cache end -for algType in ALL_SOLVER_TYPES +for algType in GENERAL_SOLVER_TYPES @eval function SciMLBase.__init( prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... ) @@ -200,13 +200,4 @@ nodual_value(x) = x nodual_value(x::Dual) = ForwardDiff.value(x) nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) -""" - pickchunksize(x) = pickchunksize(length(x)) - pickchunksize(x::Int) - -Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. -""" -@inline pickchunksize(x) = pickchunksize(length(x)) -@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) - end diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 649ac79d2..8fd4b1947 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -84,4 +84,6 @@ export DescentResult, SteepestDescent, NewtonDescent, DampedNewtonDescent, Dogle export NonlinearSolvePolyAlgorithm +export pickchunksize + end diff --git a/lib/NonlinearSolveBase/src/common_defaults.jl b/lib/NonlinearSolveBase/src/common_defaults.jl index 4518063a5..5a5433ee3 100644 --- a/lib/NonlinearSolveBase/src/common_defaults.jl +++ b/lib/NonlinearSolveBase/src/common_defaults.jl @@ -45,3 +45,12 @@ function get_tolerance(::Union{StaticArray, Number}, ::Nothing, ::Type{T}) where # Rational numbers can throw an error if used inside GPU Kernels return T(real(oneunit(T)) * (eps(real(one(T)))^(real(T)(0.8)))) end + +""" + pickchunksize(x) = pickchunksize(length(x)) + pickchunksize(x::Int) + +Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. +""" +@inline pickchunksize(x) = pickchunksize(length(x)) +@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) diff --git a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl index 1f480fb4b..15b99c5d1 100644 --- a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl +++ b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl @@ -29,7 +29,7 @@ using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode, using SciMLJacobianOperators: VecJacOperator, JacVecOperator, StatefulJacobianOperator using FiniteDiff: FiniteDiff # Default Finite Difference Method -using ForwardDiff: ForwardDiff # Default Forward Mode AD +using ForwardDiff: ForwardDiff, Dual # Default Forward Mode AD include("raphson.jl") include("gauss_newton.jl") @@ -41,6 +41,8 @@ include("poly_algs.jl") include("solve.jl") +include("forward_diff.jl") + @setup_workload begin nonlinear_functions = ( (NonlinearFunction{false, NoSpecialize}((u, p) -> u .* u .- p), 0.1), diff --git a/lib/NonlinearSolveFirstOrder/src/forward_diff.jl b/lib/NonlinearSolveFirstOrder/src/forward_diff.jl new file mode 100644 index 000000000..86f4b072a --- /dev/null +++ b/lib/NonlinearSolveFirstOrder/src/forward_diff.jl @@ -0,0 +1,34 @@ +const DualNonlinearProblem = NonlinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem +} + +function SciMLBase.__init( + prob::DualAbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; kwargs... +) + p = NonlinearSolveBase.nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) + ) +end + +function SciMLBase.__solve( + prob::DualAbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; kwargs... +) + sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) +end diff --git a/lib/NonlinearSolveQuasiNewton/Project.toml b/lib/NonlinearSolveQuasiNewton/Project.toml index 2f00863d8..4912e9070 100644 --- a/lib/NonlinearSolveQuasiNewton/Project.toml +++ b/lib/NonlinearSolveQuasiNewton/Project.toml @@ -18,6 +18,12 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +[weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + +[extensions] +NonlinearSolveQuasiNewtonForwardDiffExt = "ForwardDiff" + [compat] ADTypes = "1.9.0" Aqua = "0.8" diff --git a/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl b/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl new file mode 100644 index 000000000..afba60d43 --- /dev/null +++ b/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl @@ -0,0 +1,47 @@ +module NonlinearSolveQuasiNewtonForwardDiffExt + +using CommonSolve: CommonSolve, solve +using ForwardDiff: ForwardDiff, Dual +using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, + NonlinearProblem, NonlinearLeastSquaresProblem, remake + +using NonlinearSolveBase: NonlinearSolveBase + +using NonlinearSolveQuasiNewton: QuasiNewtonAlgorithm + +const DualNonlinearProblem = NonlinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem +} + +function SciMLBase.__solve( + prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs... +) + sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) +end + +function SciMLBase.__init( + prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs... +) + p = nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) + ) +end + +end diff --git a/lib/NonlinearSolveSpectralMethods/Project.toml b/lib/NonlinearSolveSpectralMethods/Project.toml index bb9367554..7175c5ea9 100644 --- a/lib/NonlinearSolveSpectralMethods/Project.toml +++ b/lib/NonlinearSolveSpectralMethods/Project.toml @@ -14,6 +14,12 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +[weakdeps] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + +[extensions] +NonlinearSolveSpectralMethodsForwardDiffExt = "ForwardDiff" + [compat] Aqua = "0.8" BenchmarkTools = "1.5.0" diff --git a/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl b/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl new file mode 100644 index 000000000..86604d7e2 --- /dev/null +++ b/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl @@ -0,0 +1,47 @@ +module NonlinearSolveSpectralMethodsForwardDiffExt + +using CommonSolve: CommonSolve, solve +using ForwardDiff: ForwardDiff, Dual +using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, + NonlinearProblem, NonlinearLeastSquaresProblem, remake + +using NonlinearSolveBase: NonlinearSolveBase + +using NonlinearSolveSpectralMethods: GeneralizedDFSane + +const DualNonlinearProblem = NonlinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem +} + +function SciMLBase.__solve( + prob::DualAbstractNonlinearProblem, alg::GeneralizedDFSane, args...; kwargs... +) + sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) +end + +function SciMLBase.__init( + prob::DualAbstractNonlinearProblem, alg::GeneralizedDFSane, args...; kwargs... +) + p = nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) + ) +end + +end diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index c6fcc1f12..a1b759011 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -14,7 +14,7 @@ using LineSearch: BackTracking using NonlinearSolveBase: NonlinearSolveBase, InternalAPI, AbstractNonlinearSolveAlgorithm, AbstractNonlinearSolveCache, Utils, L2_NORM, enable_timer_outputs, disable_timer_outputs, - NonlinearSolvePolyAlgorithm + NonlinearSolvePolyAlgorithm, pickchunksize using Preferences: set_preferences! using SciMLBase: SciMLBase, NLStats, ReturnCode, AbstractNonlinearProblem, @@ -53,14 +53,7 @@ include("extension_algs.jl") include("default.jl") -const ALL_SOLVER_TYPES = [ - Nothing, AbstractNonlinearSolveAlgorithm, - GeneralizedDFSane, GeneralizedFirstOrderAlgorithm, QuasiNewtonAlgorithm, - LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL, - SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL, - CMINPACK, PETScSNES, - NonlinearSolvePolyAlgorithm -] +include("forward_diff.jl") @setup_workload begin nonlinear_functions = ( diff --git a/src/forward_diff.jl b/src/forward_diff.jl new file mode 100644 index 000000000..76fdf6f52 --- /dev/null +++ b/src/forward_diff.jl @@ -0,0 +1,44 @@ +const EXTENSION_SOLVER_TYPES = [ + LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL, + SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL, + CMINPACK, PETScSNES +] + +const DualNonlinearProblem = NonlinearProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} +} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem +} + +for algType in EXTENSION_SOLVER_TYPES + @eval function SciMLBase.__init( + prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... + ) + p = NonlinearSolveBase.nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) + ) + end +end + +for algType in EXTENSION_SOLVER_TYPES + @eval function SciMLBase.__solve( + prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... + ) + sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob, alg, args...; kwargs... + ) + dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original + ) + end +end From 8bacad137bf078fe2058e3b2affce9cd66f3fe82 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 4 Dec 2024 01:40:30 +0800 Subject: [PATCH 3/9] Fix Aqua error --- .../ext/NonlinearSolveBaseForwardDiffExt.jl | 43 ++++++++++--------- .../src/NonlinearSolveBase.jl | 2 + lib/NonlinearSolveBase/src/common_defaults.jl | 9 ---- lib/NonlinearSolveBase/src/forward_diff.jl | 8 ++++ lib/NonlinearSolveBase/src/public.jl | 2 + .../src/NonlinearSolveFirstOrder.jl | 2 +- ...NonlinearSolveQuasiNewtonForwardDiffExt.jl | 5 +-- .../Project.toml | 1 + ...inearSolveSpectralMethodsForwardDiffExt.jl | 5 +-- 9 files changed, 40 insertions(+), 37 deletions(-) create mode 100644 lib/NonlinearSolveBase/src/forward_diff.jl diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index 203d06f14..95d077614 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -2,7 +2,7 @@ module NonlinearSolveBaseForwardDiffExt using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff using ArrayInterface: ArrayInterface -using CommonSolve: CommonSolve, solve +using CommonSolve: CommonSolve, solve, solve!, init using ConcreteStructs: @concrete using DifferentiationInterface: DifferentiationInterface using FastClosures: @closure @@ -10,14 +10,14 @@ using ForwardDiff: ForwardDiff, Dual using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem, remake -using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, - AbstractNonlinearSolveAlgorithm, Utils, InternalAPI, - AbstractNonlinearSolveCache, NonlinearSolvePolyAlgorithm +using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils, InternalAPI, + AbstractNonlinearSolveCache, NonlinearSolvePolyAlgorithm, + NonlinearSolveForwardDiffCache const DI = DifferentiationInterface const GENERAL_SOLVER_TYPES = [ - Nothing, AbstractNonlinearSolveAlgorithm, NonlinearSolvePolyAlgorithm + Nothing, NonlinearSolvePolyAlgorithm ] const DualNonlinearProblem = NonlinearProblem{ @@ -135,24 +135,16 @@ for algType in GENERAL_SOLVER_TYPES end end -@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache - cache - prob - alg - p - values_p - partials_p -end - function InternalAPI.reinit!( cache::NonlinearSolveForwardDiffCache, args...; p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs... ) InternalAPI.reinit!( - cache.cache; p = nodual_value(p), u0 = nodual_value(u0), kwargs... + cache.cache; p = NonlinearSolveBase.nodual_value(p), + u0 = NonlinearSolveBase.nodual_value(u0), kwargs... ) cache.p = p - cache.values_p = nodual_value(p) + cache.values_p = NonlinearSolveBase.nodual_value(p) cache.partials_p = ForwardDiff.partials(p) return cache end @@ -161,8 +153,8 @@ for algType in GENERAL_SOLVER_TYPES @eval function SciMLBase.__init( prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... ) - p = nodual_value(prob.p) - newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p) + p = NonlinearSolveBase.nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p) cache = init(newprob, alg, args...; kwargs...) return NonlinearSolveForwardDiffCache( cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) @@ -196,8 +188,17 @@ function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache) ) end -nodual_value(x) = x -nodual_value(x::Dual) = ForwardDiff.value(x) -nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) +NonlinearSolveBase.nodual_value(x) = x +NonlinearSolveBase.nodual_value(x::Dual) = ForwardDiff.value(x) +NonlinearSolveBase.nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) + +""" + pickchunksize(x) = pickchunksize(length(x)) + pickchunksize(x::Int) + +Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. +""" +@inline NonlinearSolveBase.pickchunksize(x) = pickchunksize(length(x)) +@inline NonlinearSolveBase.pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) end diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 8fd4b1947..df65e1fed 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -58,6 +58,8 @@ include("descent/geodesic_acceleration.jl") include("solve.jl") +include("forward_diff.jl") + # Unexported Public API @compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance)) @compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution)) diff --git a/lib/NonlinearSolveBase/src/common_defaults.jl b/lib/NonlinearSolveBase/src/common_defaults.jl index 5a5433ee3..4518063a5 100644 --- a/lib/NonlinearSolveBase/src/common_defaults.jl +++ b/lib/NonlinearSolveBase/src/common_defaults.jl @@ -45,12 +45,3 @@ function get_tolerance(::Union{StaticArray, Number}, ::Nothing, ::Type{T}) where # Rational numbers can throw an error if used inside GPU Kernels return T(real(oneunit(T)) * (eps(real(one(T)))^(real(T)(0.8)))) end - -""" - pickchunksize(x) = pickchunksize(length(x)) - pickchunksize(x::Int) - -Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. -""" -@inline pickchunksize(x) = pickchunksize(length(x)) -@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) diff --git a/lib/NonlinearSolveBase/src/forward_diff.jl b/lib/NonlinearSolveBase/src/forward_diff.jl new file mode 100644 index 000000000..a588aa52d --- /dev/null +++ b/lib/NonlinearSolveBase/src/forward_diff.jl @@ -0,0 +1,8 @@ +@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache + cache + prob + alg + p + values_p + partials_p +end diff --git a/lib/NonlinearSolveBase/src/public.jl b/lib/NonlinearSolveBase/src/public.jl index d076f7873..b68e3806f 100644 --- a/lib/NonlinearSolveBase/src/public.jl +++ b/lib/NonlinearSolveBase/src/public.jl @@ -11,6 +11,8 @@ function nonlinearsolve_dual_solution end function nonlinearsolve_∂f_∂p end function nonlinearsolve_∂f_∂u end function nlls_generate_vjp_function end +function nodual_value end +function pickchunksize end # Nonlinear Solve Termination Conditions abstract type AbstractNonlinearTerminationMode end diff --git a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl index 15b99c5d1..666cc7435 100644 --- a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl +++ b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl @@ -22,7 +22,7 @@ using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, Utils, InternalAPI, get_timer_output, @static_timeit, update_trace!, L2_NORM, NonlinearSolvePolyAlgorithm, NewtonDescent, DampedNewtonDescent, GeodesicAcceleration, - Dogleg + Dogleg, NonlinearSolveForwardDiffCache using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode, NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem, NoSpecialize diff --git a/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl b/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl index afba60d43..ca4e7bb94 100644 --- a/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl +++ b/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl @@ -1,9 +1,8 @@ module NonlinearSolveQuasiNewtonForwardDiffExt -using CommonSolve: CommonSolve, solve +using CommonSolve: CommonSolve, init using ForwardDiff: ForwardDiff, Dual -using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, - NonlinearProblem, NonlinearLeastSquaresProblem, remake +using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem using NonlinearSolveBase: NonlinearSolveBase diff --git a/lib/NonlinearSolveSpectralMethods/Project.toml b/lib/NonlinearSolveSpectralMethods/Project.toml index 7175c5ea9..a248be107 100644 --- a/lib/NonlinearSolveSpectralMethods/Project.toml +++ b/lib/NonlinearSolveSpectralMethods/Project.toml @@ -27,6 +27,7 @@ CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" DiffEqBase = "6.158.3" ExplicitImports = "1.5" +ForwardDiff = "0.10.36" Hwloc = "3" InteractiveUtils = "<0.0.1, 1" LineSearch = "0.1.4" diff --git a/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl b/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl index 86604d7e2..5dfc559f6 100644 --- a/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl +++ b/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl @@ -1,9 +1,8 @@ module NonlinearSolveSpectralMethodsForwardDiffExt -using CommonSolve: CommonSolve, solve +using CommonSolve: CommonSolve, init using ForwardDiff: ForwardDiff, Dual -using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, - NonlinearProblem, NonlinearLeastSquaresProblem, remake +using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem using NonlinearSolveBase: NonlinearSolveBase From b6a2406478c0eebfef845352069a28d5f6ced3b2 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 4 Dec 2024 12:39:11 +0800 Subject: [PATCH 4/9] refactor: Fix nodual_value --- lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl | 3 +-- .../ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl | 2 +- .../ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index 95d077614..4ffafc974 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -11,8 +11,7 @@ using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem, remake using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils, InternalAPI, - AbstractNonlinearSolveCache, NonlinearSolvePolyAlgorithm, - NonlinearSolveForwardDiffCache + NonlinearSolvePolyAlgorithm, NonlinearSolveForwardDiffCache const DI = DifferentiationInterface diff --git a/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl b/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl index ca4e7bb94..7f8de5b2a 100644 --- a/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl +++ b/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl @@ -4,7 +4,7 @@ using CommonSolve: CommonSolve, init using ForwardDiff: ForwardDiff, Dual using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem -using NonlinearSolveBase: NonlinearSolveBase +using NonlinearSolveBase: NonlinearSolveBase, nondual_value using NonlinearSolveQuasiNewton: QuasiNewtonAlgorithm diff --git a/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl b/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl index 5dfc559f6..3b04751c8 100644 --- a/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl +++ b/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl @@ -4,7 +4,7 @@ using CommonSolve: CommonSolve, init using ForwardDiff: ForwardDiff, Dual using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem -using NonlinearSolveBase: NonlinearSolveBase +using NonlinearSolveBase: NonlinearSolveBase, nodual_value using NonlinearSolveSpectralMethods: GeneralizedDFSane From 10433ca51f907944988daf57a8cc5d56684541c6 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 4 Dec 2024 14:56:07 +0800 Subject: [PATCH 5/9] fix: Correct import in extensions --- .../ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl | 2 +- .../ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl b/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl index 7f8de5b2a..74ec64031 100644 --- a/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl +++ b/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl @@ -4,7 +4,7 @@ using CommonSolve: CommonSolve, init using ForwardDiff: ForwardDiff, Dual using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem -using NonlinearSolveBase: NonlinearSolveBase, nondual_value +using NonlinearSolveBase: NonlinearSolveBase, NonlinearSolveForwardDiffCache, nodual_value using NonlinearSolveQuasiNewton: QuasiNewtonAlgorithm diff --git a/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl b/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl index 3b04751c8..930c4861c 100644 --- a/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl +++ b/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl @@ -4,7 +4,7 @@ using CommonSolve: CommonSolve, init using ForwardDiff: ForwardDiff, Dual using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem -using NonlinearSolveBase: NonlinearSolveBase, nodual_value +using NonlinearSolveBase: NonlinearSolveBase, NonlinearSolveForwardDiffCache, nodual_value using NonlinearSolveSpectralMethods: GeneralizedDFSane From dececf8749a7fe47c7a6cb5ab2c6bb55b85c1309 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 4 Dec 2024 15:46:23 +0800 Subject: [PATCH 6/9] fix: Fix pickchunksize usage --- lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index 4ffafc974..2ae7dbbaa 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -6,7 +6,7 @@ using CommonSolve: CommonSolve, solve, solve!, init using ConcreteStructs: @concrete using DifferentiationInterface: DifferentiationInterface using FastClosures: @closure -using ForwardDiff: ForwardDiff, Dual +using ForwardDiff: ForwardDiff, Dual, pickchunksize using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem, remake From d982f98f20a945ef6b66e56d4937ccb012f39e46 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 4 Dec 2024 20:02:38 +0800 Subject: [PATCH 7/9] docs: pickchunksize is now in Base --- docs/src/basics/faq.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/basics/faq.md b/docs/src/basics/faq.md index 4d428250a..5a1ca43d4 100644 --- a/docs/src/basics/faq.md +++ b/docs/src/basics/faq.md @@ -152,7 +152,7 @@ nothing # hide ``` And boom! Type stable again. We always recommend picking the chunksize via -[`NonlinearSolve.pickchunksize`](@ref), however, if you manually specify the chunksize, it +[`NonlinearSolveBase.pickchunksize`](@ref), however, if you manually specify the chunksize, it must be `≤ length of input`. However, a very large chunksize can lead to excessive compilation times and slowdown. From 85a3f80a348b70603b7b303f1fba6ce187b80ae4 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 4 Dec 2024 20:50:23 +0800 Subject: [PATCH 8/9] docs: Fix pickchunksize docs --- docs/src/basics/faq.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/basics/faq.md b/docs/src/basics/faq.md index 5a1ca43d4..9aabc203b 100644 --- a/docs/src/basics/faq.md +++ b/docs/src/basics/faq.md @@ -157,5 +157,5 @@ must be `≤ length of input`. However, a very large chunksize can lead to exces compilation times and slowdown. ```@docs -NonlinearSolve.pickchunksize +NonlinearSolveBase.pickchunksize ``` From 4056565f5aee9a431cac83558d5212eca00dfd09 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 4 Dec 2024 21:25:12 +0800 Subject: [PATCH 9/9] docs: Move docstrings from ext to public --- .../ext/NonlinearSolveBaseForwardDiffExt.jl | 6 ------ lib/NonlinearSolveBase/src/public.jl | 7 +++++++ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index 2ae7dbbaa..717daa8e4 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -191,12 +191,6 @@ NonlinearSolveBase.nodual_value(x) = x NonlinearSolveBase.nodual_value(x::Dual) = ForwardDiff.value(x) NonlinearSolveBase.nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) -""" - pickchunksize(x) = pickchunksize(length(x)) - pickchunksize(x::Int) - -Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. -""" @inline NonlinearSolveBase.pickchunksize(x) = pickchunksize(length(x)) @inline NonlinearSolveBase.pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) diff --git a/lib/NonlinearSolveBase/src/public.jl b/lib/NonlinearSolveBase/src/public.jl index b68e3806f..a9bae2a5e 100644 --- a/lib/NonlinearSolveBase/src/public.jl +++ b/lib/NonlinearSolveBase/src/public.jl @@ -12,6 +12,13 @@ function nonlinearsolve_∂f_∂p end function nonlinearsolve_∂f_∂u end function nlls_generate_vjp_function end function nodual_value end + +""" + pickchunksize(x) = pickchunksize(length(x)) + pickchunksize(x::Int) + +Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. +""" function pickchunksize end # Nonlinear Solve Termination Conditions