Skip to content

Commit

Permalink
refactor: Move NonlinearSolvePolyAlgorithm to Base (#494)
Browse files Browse the repository at this point in the history
* refactor: Move NonlinearSolvePolyAlgorithm to Base

* test: Make NonlinearSolve use 1.3 Base

* refactor: Remove unnecessary snippet

* refactor: Don't use duplicate solve

* refactor: Test Base export NonlinearSolvePolyAlgorithm
  • Loading branch information
ErikQQY authored Nov 6, 2024
1 parent 037a07c commit 6f043bf
Show file tree
Hide file tree
Showing 8 changed files with 552 additions and 549 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ NLSolvers = "0.5"
NLsolve = "4.5"
NaNMath = "1"
NonlinearProblemLibrary = "0.1.2"
NonlinearSolveBase = "1.2"
NonlinearSolveBase = "1.3"
NonlinearSolveFirstOrder = "1"
NonlinearSolveQuasiNewton = "1"
NonlinearSolveSpectralMethods = "1"
Expand Down
2 changes: 1 addition & 1 deletion lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolveBase"
uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "1.2.0"
version = "1.3.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 3 additions & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ include("linear_solve.jl")
include("timer_outputs.jl")
include("tracing.jl")
include("wrappers.jl")
include("polyalg.jl")

include("descent/common.jl")
include("descent/newton.jl")
Expand Down Expand Up @@ -81,4 +82,6 @@ export RelTerminationMode, AbsTerminationMode,
export DescentResult, SteepestDescent, NewtonDescent, DampedNewtonDescent, Dogleg,
GeodesicAcceleration

export NonlinearSolvePolyAlgorithm

end
202 changes: 202 additions & 0 deletions lib/NonlinearSolveBase/src/polyalg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
"""
NonlinearSolvePolyAlgorithm(algs; start_index::Int = 1)
A general way to define PolyAlgorithms for `NonlinearProblem` and
`NonlinearLeastSquaresProblem`. This is a container for a tuple of algorithms that will be
tried in order until one succeeds. If none succeed, then the algorithm with the lowest
residual is returned.
### Arguments
- `algs`: a tuple of algorithms to try in-order! (If this is not a Tuple, then the
returned algorithm is not type-stable).
### Keyword Arguments
- `start_index`: the index to start at. Defaults to `1`.
### Example
```julia
using NonlinearSolve
alg = NonlinearSolvePolyAlgorithm((NewtonRaphson(), Broyden()))
```
"""
@concrete struct NonlinearSolvePolyAlgorithm <: AbstractNonlinearSolveAlgorithm
static_length <: Val
algs <: Tuple
start_index::Int
end

function NonlinearSolvePolyAlgorithm(algs; start_index::Int = 1)
@assert 0 < start_index length(algs)
algs = Tuple(algs)
return NonlinearSolvePolyAlgorithm(Val(length(algs)), algs, start_index)
end

@concrete mutable struct NonlinearSolvePolyAlgorithmCache <: AbstractNonlinearSolveCache
static_length <: Val
prob <: AbstractNonlinearProblem

caches <: Tuple
alg <: NonlinearSolvePolyAlgorithm

best::Int
current::Int
nsteps::Int

stats::NLStats
total_time::Float64
maxtime

retcode::ReturnCode.T
force_stop::Bool

maxiters::Int
internalnorm

u0
u0_aliased
alias_u0::Bool
end

function SII.symbolic_container(cache::NonlinearSolvePolyAlgorithmCache)
return cache.caches[cache.current]
end
SII.state_values(cache::NonlinearSolvePolyAlgorithmCache) = cache.u0

function Base.show(io::IO, ::MIME"text/plain", cache::NonlinearSolvePolyAlgorithmCache)
println(io, "NonlinearSolvePolyAlgorithmCache with \
$(Utils.unwrap_val(cache.static_length)) algorithms:")
best_alg = ifelse(cache.best == -1, "nothing", cache.best)
println(io, " Best Algorithm: $(best_alg)")
println(
io, " Current Algorithm: [$(cache.current) / $(Utils.unwrap_val(cache.static_length))]"
)
println(io, " nsteps: $(cache.nsteps)")
println(io, " retcode: $(cache.retcode)")
print(io, " Current Cache: ")
NonlinearSolveBase.show_nonlinearsolve_cache(io, cache.caches[cache.current], 4)
end

function InternalAPI.reinit!(
cache::NonlinearSolvePolyAlgorithmCache, args...; p = cache.p, u0 = cache.u0
)
foreach(cache.caches) do cache
InternalAPI.reinit!(cache, args...; p, u0)
end
cache.current = cache.alg.start_index
InternalAPI.reinit!(cache.stats)
cache.nsteps = 0
cache.total_time = 0.0
end

function SciMLBase.__init(
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm, args...;
stats = NLStats(0, 0, 0, 0, 0), maxtime = nothing, maxiters = 1000,
internalnorm = L2_NORM, alias_u0 = false, verbose = true, kwargs...
)
if alias_u0 && !ArrayInterface.ismutable(prob.u0)
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
immutable (checked using `ArrayInterface.ismutable`)."
alias_u0 = false # If immutable don't care about aliasing
end

u0 = prob.u0
u0_aliased = alias_u0 ? copy(u0) : u0
alias_u0 && (prob = SciMLBase.remake(prob; u0 = u0_aliased))

return NonlinearSolvePolyAlgorithmCache(
alg.static_length, prob,
map(alg.algs) do solver
SciMLBase.__init(
prob, solver, args...;
stats, maxtime, internalnorm, alias_u0, verbose, kwargs...
)
end,
alg, -1, alg.start_index, 0, stats, 0.0, maxtime,
ReturnCode.Default, false, maxiters, internalnorm,
u0, u0_aliased, alias_u0
)
end

@generated function InternalAPI.step!(
cache::NonlinearSolvePolyAlgorithmCache{Val{N}}, args...; kwargs...
) where {N}
calls = []
cache_syms = [gensym("cache") for i in 1:N]
for i in 1:N
push!(calls,
quote
$(cache_syms[i]) = cache.caches[$(i)]
if $(i) == cache.current
InternalAPI.step!($(cache_syms[i]), args...; kwargs...)
$(cache_syms[i]).nsteps += 1
if !NonlinearSolveBase.not_terminated($(cache_syms[i]))
if SciMLBase.successful_retcode($(cache_syms[i]).retcode)
cache.best = $(i)
cache.force_stop = true
cache.retcode = $(cache_syms[i]).retcode
else
cache.current = $(i + 1)
end
end
return
end
end)
end

push!(calls, quote
if !(1 cache.current length(cache.caches))
minfu, idx = findmin_caches(cache.prob, cache.caches)
cache.best = idx
cache.retcode = cache.caches[idx].retcode
cache.force_stop = true
return
end
end)

return Expr(:block, calls...)
end

# Original is often determined on runtime information especially for PolyAlgorithms so it
# is best to never specialize on that
function build_solution_less_specialize(
prob::AbstractNonlinearProblem, alg, u, resid;
retcode = ReturnCode.Default, original = nothing, left = nothing,
right = nothing, stats = nothing, trace = nothing, kwargs...
)
return SciMLBase.NonlinearSolution{
eltype(eltype(u)), ndims(u), typeof(u), typeof(resid), typeof(prob),
typeof(alg), Any, typeof(left), typeof(stats), typeof(trace)
}(
u, resid, prob, alg, retcode, original, left, right, stats, trace
)
end

function findmin_caches(prob::AbstractNonlinearProblem, caches)
resids = map(caches) do cache
cache === nothing && return nothing
return NonlinearSolveBase.get_fu(cache)
end
return findmin_resids(prob, resids)
end

@views function findmin_resids(prob::AbstractNonlinearProblem, caches)
norm_fn = prob isa NonlinearLeastSquaresProblem ? Base.Fix2(norm, 2) :
Base.Fix2(norm, Inf)
idx = findfirst(Base.Fix2(!==, nothing), caches)
# This is an internal function so we assume that inputs are consistent and there is
# atleast one non-`nothing` value
fx_idx = norm_fn(caches[idx])
idx == length(caches) && return fx_idx, idx
fmin = @closure xᵢ -> begin
xᵢ === nothing && return oftype(fx_idx, Inf)
fx = norm_fn(xᵢ)
return ifelse(isnan(fx), oftype(fx, Inf), fx)
end
x_min, x_min_idx = findmin(fmin, caches[(idx + 1):length(caches)])
x_min < fx_idx && return x_min, x_min_idx + idx
return fx_idx, idx
end
Loading

2 comments on commit 6f043bf

@avik-pal
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir=lib/NonlinearSolveBase

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/118813

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a NonlinearSolveBase-v1.3.0 -m "<description of version>" 6f043bfd5e342a36b43678b8e10bd4ea81a52a95
git push origin NonlinearSolveBase-v1.3.0

Please sign in to comment.