Skip to content

Commit 173dd01

Browse files
committed
refactor: migrate to LineSearch.jl
1 parent 05aa3db commit 173dd01

12 files changed

+65
-483
lines changed

Diff for: Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1313
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1414
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1515
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
16+
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
1617
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
1718
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1819
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
@@ -78,6 +79,7 @@ Hwloc = "3"
7879
InteractiveUtils = "<0.0.1, 1"
7980
LazyArrays = "1.8.2, 2"
8081
LeastSquaresOptim = "0.8.5"
82+
LineSearch = "0.1"
8183
LineSearches = "7.2"
8284
LinearAlgebra = "1.10"
8385
LinearSolve = "2.30"

Diff for: docs/src/devdocs/internal_interfaces.md

-7
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,6 @@ NonlinearSolve.AbstractDampingFunction
3838
NonlinearSolve.AbstractDampingFunctionCache
3939
```
4040

41-
## Line Search
42-
43-
```@docs
44-
NonlinearSolve.AbstractNonlinearSolveLineSearchAlgorithm
45-
NonlinearSolve.AbstractNonlinearSolveLineSearchCache
46-
```
47-
4841
## Trust Region
4942

5043
```@docs

Diff for: src/NonlinearSolve.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ using LazyArrays: LazyArrays, ApplyArray, cache
3030
using LinearAlgebra: LinearAlgebra, ColumnNorm, Diagonal, I, LowerTriangular, Symmetric,
3131
UpperTriangular, axpy!, cond, diag, diagind, dot, issuccess, istril,
3232
istriu, lu, mul!, norm, pinv, tril!, triu!
33+
using LineSearch: LineSearch, AbstractLineSearchAlgorithm, AbstractLineSearchCache,
34+
NoLineSearch, RobustNonMonotoneLineSearch
3335
using LineSearches: LineSearches
3436
using LinearSolve: LinearSolve, LUFactorization, QRFactorization, ComposePreconditioner,
3537
InvPreconditioner, needs_concrete_A, AbstractFactorization,
@@ -170,8 +172,9 @@ export NewtonDescent, SteepestDescent, Dogleg, DampedNewtonDescent, GeodesicAcce
170172

171173
# Globalization
172174
## Line Search Algorithms
173-
export LineSearchesJL, NoLineSearch, RobustNonMonotoneLineSearch, LiFukushimaLineSearch
174-
export Static, HagerZhang, MoreThuente, StrongWolfe, BackTracking
175+
export LineSearchesJL, LiFukushimaLineSearch # FIXME: deprecated. use LineSearch.jl directly
176+
export Static, HagerZhang, MoreThuente, StrongWolfe, BackTracking # FIXME: deprecated
177+
export NoLineSearch, RobustNonMonotoneLineSearch
175178
## Trust Region Algorithms
176179
export RadiusUpdateSchemes
177180

Diff for: src/abstract_types.jl

+3-19
Original file line numberDiff line numberDiff line change
@@ -106,22 +106,6 @@ function last_step_accepted(cache::AbstractDescentCache)
106106
return true
107107
end
108108

109-
"""
110-
AbstractNonlinearSolveLineSearchAlgorithm
111-
112-
Abstract Type for all Line Search Algorithms used in NonlinearSolve.jl.
113-
114-
### `__internal_init` specification
115-
116-
```julia
117-
__internal_init(
118-
prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveLineSearchAlgorithm, f::F,
119-
fu, u, p, args...; internalnorm::IN = DEFAULT_NORM, kwargs...) where {F, IN} -->
120-
AbstractNonlinearSolveLineSearchCache
121-
```
122-
"""
123-
abstract type AbstractNonlinearSolveLineSearchAlgorithm end
124-
125109
"""
126110
AbstractNonlinearSolveLineSearchCache
127111
@@ -512,9 +496,9 @@ SciMLBase.isinplace(::AbstractNonlinearSolveJacobianCache{iip}) where {iip} = ii
512496
abstract type AbstractNonlinearSolveTraceLevel end
513497

514498
# Default Printing
515-
for aType in (AbstractTrustRegionMethod, AbstractNonlinearSolveLineSearchAlgorithm,
516-
AbstractResetCondition, AbstractApproximateJacobianUpdateRule,
517-
AbstractDampingFunction, AbstractNonlinearSolveExtensionAlgorithm)
499+
for aType in (AbstractTrustRegionMethod, AbstractResetCondition,
500+
AbstractApproximateJacobianUpdateRule, AbstractDampingFunction,
501+
AbstractNonlinearSolveExtensionAlgorithm)
518502
@eval function Base.show(io::IO, alg::$(aType))
519503
print(io, "$(nameof(typeof(alg)))()")
520504
end

Diff for: src/algorithms/klement.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ over this.
2727
function Klement(; max_resets::Int = 100, linsolve = nothing, alpha = nothing,
2828
linesearch = NoLineSearch(), precs = DEFAULT_PRECS,
2929
autodiff = nothing, init_jacobian::Val{IJ} = Val(:identity)) where {IJ}
30-
if !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
30+
if !(linesearch isa AbstractLineSearchAlgorithm)
3131
Base.depwarn(
3232
"Passing in a `LineSearches.jl` algorithm directly is deprecated. \
3333
Please use `LineSearchesJL` instead.", :Klement)

Diff for: src/algorithms/pseudo_transient.jl

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
3-
linesearch::AbstractNonlinearSolveLineSearchAlgorithm = NoLineSearch(),
4-
precs = DEFAULT_PRECS, autodiff = nothing)
3+
linesearch = NoLineSearch(), precs = DEFAULT_PRECS, autodiff = nothing)
54
65
An implementation of PseudoTransient Method [coffey2003pseudotransient](@cite) that is used
76
to solve steady state problems in an accelerated manner. It uses an adaptive time-stepping
@@ -16,8 +15,8 @@ This implementation specifically uses "switched evolution relaxation"
1615
you are going to need more iterations to converge but it can be more stable.
1716
"""
1817
function PseudoTransient(; concrete_jac = nothing, linsolve = nothing,
19-
linesearch::AbstractNonlinearSolveLineSearchAlgorithm = NoLineSearch(),
20-
precs = DEFAULT_PRECS, autodiff = nothing, alpha_initial = 1e-3)
18+
linesearch = NoLineSearch(), precs = DEFAULT_PRECS, autodiff = nothing,
19+
alpha_initial = 1e-3)
2120
descent = DampedNewtonDescent(; linsolve, precs, initial_damping = alpha_initial,
2221
damping_fn = SwitchedEvolutionRelaxation())
2322
return GeneralizedFirstOrderAlgorithm(;

Diff for: src/core/approximate_jacobian.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function ApproximateJacobianSolveAlgorithm{concrete_jac, name}(;
5959
linesearch = missing, trustregion = missing, descent, update_rule,
6060
reinit_rule, initialization, max_resets::Int = typemax(Int),
6161
max_shrink_times::Int = typemax(Int)) where {concrete_jac, name}
62-
if linesearch !== missing && !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
62+
if linesearch !== missing && !(linesearch isa AbstractLineSearchAlgorithm)
6363
Base.depwarn("Passing in a `LineSearches.jl` algorithm directly is deprecated. \
6464
Please use `LineSearchesJL` instead.",
6565
:GeneralizedFirstOrderAlgorithm)
@@ -199,8 +199,8 @@ function SciMLBase.__init(
199199
if alg.linesearch !== missing
200200
supports_line_search(alg.descent) || error("Line Search not supported by \
201201
$(alg.descent).")
202-
linesearch_cache = __internal_init(
203-
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
202+
linesearch_cache = init(
203+
prob, alg.linesearch, fu, u; stats, internalnorm, kwargs...)
204204
GB = :LineSearch
205205
end
206206

@@ -317,7 +317,9 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip};
317317
if descent_result.success
318318
if GB === :LineSearch
319319
@static_timeit cache.timer "linesearch" begin
320-
needs_reset, α = __internal_solve!(cache.linesearch_cache, cache.u, δu)
320+
linesearch_sol = solve!(cache.linesearch_cache, cache.u, δu)
321+
needs_reset = !SciMLBase.successful_retcode(linesearch_sol.retcode)
322+
α = linesearch_sol.step_size
321323
end
322324
if needs_reset && cache.steps_since_last_reset > 5 # Reset after a burn-in period
323325
cache.force_reinit = true

Diff for: src/core/generalized_first_order.jl

+11-5
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function GeneralizedFirstOrderAlgorithm{concrete_jac, name}(;
6666
jacobian_ad !== nothing && ADTypes.mode(jacobian_ad) isa ADTypes.ReverseMode,
6767
jacobian_ad, nothing))
6868

69-
if linesearch !== missing && !(linesearch isa AbstractNonlinearSolveLineSearchAlgorithm)
69+
if linesearch !== missing && !(linesearch isa AbstractLineSearchAlgorithm)
7070
Base.depwarn("Passing in a `LineSearches.jl` algorithm directly is deprecated. \
7171
Please use `LineSearchesJL` instead.",
7272
:GeneralizedFirstOrderAlgorithm)
@@ -199,8 +199,13 @@ function SciMLBase.__init(
199199
if alg.linesearch !== missing
200200
supports_line_search(alg.descent) || error("Line Search not supported by \
201201
$(alg.descent).")
202-
linesearch_cache = __internal_init(
203-
prob, alg.linesearch, f, fu, u, p; stats, internalnorm, kwargs...)
202+
linesearch_ad = alg.forward_ad === nothing ?
203+
(alg.reverse_ad === nothing ? alg.jacobian_ad :
204+
alg.reverse_ad) : alg.forward_ad
205+
linesearch_ad = get_concrete_forward_ad(
206+
linesearch_ad, prob, False; check_forward_mode = false)
207+
linesearch_cache = init(
208+
prob, alg.linesearch, fu, u; stats, autodiff = linesearch_ad, kwargs...)
204209
GB = :LineSearch
205210
end
206211

@@ -264,8 +269,9 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB};
264269
cache.make_new_jacobian = true
265270
if GB === :LineSearch
266271
@static_timeit cache.timer "linesearch" begin
267-
linesearch_failed, α = __internal_solve!(
268-
cache.linesearch_cache, cache.u, δu)
272+
linesearch_sol = solve!(cache.linesearch_cache, cache.u, δu)
273+
linesearch_failed = !SciMLBase.successful_retcode(linesearch_sol.retcode)
274+
α = linesearch_sol.step_size
269275
end
270276
if linesearch_failed
271277
cache.retcode = ReturnCode.InternalLineSearchFailed

Diff for: src/core/spectral_methods.jl

+7-7
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@ Method.
99
1010
### Arguments
1111
12-
- `linesearch`: Globalization using a Line Search Method. This needs to follow the
13-
[`NonlinearSolve.AbstractNonlinearSolveLineSearchAlgorithm`](@ref) interface. This
14-
is not optional currently, but that restriction might be lifted in the future.
12+
- `linesearch`: Globalization using a Line Search Method. This is not optional currently,
13+
but that restriction might be lifted in the future.
1514
- `σ_min`: The minimum spectral parameter allowed. This is used to ensure that the
1615
spectral parameter is not too small.
1716
- `σ_max`: The maximum spectral parameter allowed. This is used to ensure that the
@@ -119,7 +118,7 @@ end
119118
function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane, args...;
120119
stats = empty_nlstats(), alias_u0 = false, maxiters = 1000,
121120
abstol = nothing, reltol = nothing, termination_condition = nothing,
122-
internalnorm::F = DEFAULT_NORM, maxtime = nothing, kwargs...) where {F}
121+
maxtime = nothing, kwargs...)
123122
timer = get_timer_output()
124123
@static_timeit timer "cache construction" begin
125124
u = __maybe_unaliased(prob.u0, alias_u0)
@@ -130,8 +129,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane
130129
fu = evaluate_f(prob, u)
131130
@bb fu_cache = copy(fu)
132131

133-
linesearch_cache = __internal_init(prob, alg.linesearch, prob.f, fu, u, prob.p;
134-
stats, maxiters, internalnorm, kwargs...)
132+
linesearch_cache = init(prob, alg.linesearch, fu, u; stats, kwargs...)
135133

136134
abstol, reltol, tc_cache = init_termination_cache(
137135
prob, abstol, reltol, fu, u_cache, termination_condition)
@@ -167,7 +165,9 @@ function __step!(cache::GeneralizedDFSaneCache{iip};
167165
end
168166

169167
@static_timeit cache.timer "linesearch" begin
170-
linesearch_failed, α = __internal_solve!(cache.linesearch_cache, cache.u, cache.du)
168+
linesearch_sol = solve!(cache.linesearch_cache, cache.u, cache.du)
169+
linesearch_failed = !SciMLBase.successful_retcode(linesearch_sol.retcode)
170+
α = linesearch_sol.step_size
171171
end
172172

173173
if linesearch_failed

Diff for: src/default.jl

+6-3
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,8 @@ function FastShortcutNonlinearPolyalg(
405405
else
406406
algs = (NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
407407
NewtonRaphson(; concrete_jac, linsolve, precs,
408-
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
408+
linesearch = LineSearchesJL(; method = LineSearches.BackTracking()),
409+
autodiff),
409410
TrustRegion(; concrete_jac, linsolve, precs, autodiff),
410411
TrustRegion(; concrete_jac, linsolve, precs,
411412
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff))
@@ -426,7 +427,8 @@ function FastShortcutNonlinearPolyalg(
426427
SimpleKlement(),
427428
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
428429
NewtonRaphson(; concrete_jac, linsolve, precs,
429-
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
430+
linesearch = LineSearchesJL(; method = LineSearches.BackTracking()),
431+
autodiff),
430432
TrustRegion(; concrete_jac, linsolve, precs,
431433
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff))
432434
end
@@ -444,7 +446,8 @@ function FastShortcutNonlinearPolyalg(
444446
Klement(; linsolve, precs, autodiff),
445447
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
446448
NewtonRaphson(; concrete_jac, linsolve, precs,
447-
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
449+
linesearch = LineSearchesJL(; method = LineSearches.BackTracking()),
450+
autodiff),
448451
TrustRegion(; concrete_jac, linsolve, precs, autodiff),
449452
TrustRegion(; concrete_jac, linsolve, precs,
450453
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff))

0 commit comments

Comments
 (0)