Skip to content

Commit 18a2da4

Browse files
committed
Reuse NewtonDescent for MultiStepSchemes
1 parent ceeadcb commit 18a2da4

File tree

4 files changed

+107
-61
lines changed

4 files changed

+107
-61
lines changed

Diff for: Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ NonlinearSolveZygoteExt = "Zygote"
5656

5757
[compat]
5858
ADTypes = "0.2.6"
59-
Accessors = "0.1"
59+
Accessors = "0.1.32"
6060
Aqua = "0.8"
6161
ArrayInterface = "7.7"
6262
BandedMatrices = "1.4"

Diff for: src/abstract_types.jl

+28
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ Returns a result of type [`DescentResult`](@ref).
8787
- `get_du(cache, ::Val{N})`: get the `N`th descent direction.
8888
- `set_du!(cache, δu)`: set the descent direction.
8989
- `set_du!(cache, δu, ::Val{N})`: set the `N`th descent direction.
90+
- `get_internal_cache(cache, ::Val{field})`: get the internal cache field.
91+
- `get_internal_cache(cache, field::Val, ::Val{N})`: get the `N`th internal cache field.
92+
- `set_internal_cache!(cache, value, ::Val{field})`: set the internal cache field.
93+
- `set_internal_cache!(cache, value, field::Val, ::Val{N})`: set the `N`th internal cache
94+
field.
9095
- `last_step_accepted(cache)`: whether or not the last step was accepted. Checks if the
9196
cache has a `last_step_accepted` field and returns it if it does, else returns `true`.
9297
"""
@@ -98,6 +103,29 @@ SciMLBase.get_du(cache::AbstractDescentCache, ::Val{N}) where {N} = cache.δus[N
98103
set_du!(cache::AbstractDescentCache, δu) = (cache.δu = δu)
99104
set_du!(cache::AbstractDescentCache, δu, ::Val{1}) = set_du!(cache, δu)
100105
set_du!(cache::AbstractDescentCache, δu, ::Val{N}) where {N} = (cache.δus[N - 1] = δu)
106+
function get_internal_cache(cache::AbstractDescentCache, ::Val{field}) where {field}
107+
return getproperty(cache, field)
108+
end
109+
function get_internal_cache(cache::AbstractDescentCache, field::Val, ::Val{1})
110+
return get_internal_cache(cache, field)
111+
end
112+
function get_internal_cache(
113+
cache::AbstractDescentCache, ::Val{field}, ::Val{N}) where {field, N}
114+
true_field = Symbol(string(field), "s") # Julia 1.10 compiles this away
115+
return getproperty(cache, true_field)[N]
116+
end
117+
function set_internal_cache!(cache::AbstractDescentCache, value, ::Val{field}) where {field}
118+
return setproperty!(cache, field, value)
119+
end
120+
function set_internal_cache!(
121+
cache::AbstractDescentCache, value, field::Val, ::Val{1})
122+
return set_internal_cache!(cache, value, field)
123+
end
124+
function set_internal_cache!(
125+
cache::AbstractDescentCache, value, ::Val{field}, ::Val{N}) where {field, N}
126+
true_field = Symbol(string(field), "s") # Julia 1.10 compiles this away
127+
return setproperty!(cache, true_field, value, N)
128+
end
101129

102130
function last_step_accepted(cache::AbstractDescentCache)
103131
hasfield(typeof(cache), :last_step_accepted) && return cache.last_step_accepted

Diff for: src/algorithms/multistep.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing,
22
scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing,
3-
vjp_autodiff = nothing)
3+
vjp_autodiff = nothing, linesearch = NoLineSearch())
44
scheme_concrete = apply_patch(scheme, (; autodiff, vjp_autodiff))
55
descent = GenericMultiStepDescent(; scheme = scheme_concrete, linsolve, precs)
66
return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = MSS.display_name(scheme),
7-
descent, jacobian_ad = autodiff)
7+
descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff)
88
end

Diff for: src/descent/multistep.jl

+76-58
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,24 @@ struct __PotraPtak3 <: AbstractMultiStepScheme end
2121
const PotraPtak3 = __PotraPtak3()
2222

2323
alg_steps(::__PotraPtak3) = 2
24+
nintermediates(::__PotraPtak3) = 1
2425

2526
@kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme
26-
vjp_autodiff = nothing
27+
jvp_autodiff = nothing
2728
end
2829
const SinghSharma4 = __SinghSharma4()
2930

3031
alg_steps(::__SinghSharma4) = 3
3132

3233
@kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme
33-
vjp_autodiff = nothing
34+
jvp_autodiff = nothing
3435
end
3536
const SinghSharma5 = __SinghSharma5()
3637

3738
alg_steps(::__SinghSharma5) = 3
3839

3940
@kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme
40-
vjp_autodiff = nothing
41+
jvp_autodiff = nothing
4142
end
4243
const SinghSharma7 = __SinghSharma7()
4344

@@ -60,93 +61,110 @@ end
6061

6162
Base.show(io::IO, alg::GenericMultiStepDescent) = print(io, "$(alg.scheme)()")
6263

63-
supports_line_search(::GenericMultiStepDescent) = false
64+
supports_line_search(::GenericMultiStepDescent) = true
6465
supports_trust_region(::GenericMultiStepDescent) = false
6566

66-
@concrete mutable struct GenericMultiStepDescentCache{S, INV} <: AbstractDescentCache
67+
@concrete mutable struct GenericMultiStepDescentCache{S} <: AbstractDescentCache
6768
f
6869
p
6970
δu
7071
δus
71-
extras
72+
u
73+
us
74+
fu
75+
fus
76+
internal_cache
77+
internal_caches
7278
scheme::S
73-
lincache
7479
timer
7580
nf::Int
7681
end
7782

78-
@internal_caches GenericMultiStepDescentCache :lincache
83+
# FIXME: @internal_caches needs to be updated to support tuples and namedtuples
84+
# @internal_caches GenericMultiStepDescentCache :internal_caches
7985

8086
function __reinit_internal!(cache::GenericMultiStepDescentCache, args...; p = cache.p,
8187
kwargs...)
8288
cache.nf = 0
8389
cache.p = p
90+
reset_timer!(cache.timer)
8491
end
8592

86-
function __δu_caches(scheme::MSS.__PotraPtak3, fu, u, ::Val{N}) where {N}
87-
caches = ntuple(N) do i
88-
@bb δu = similar(u)
89-
@bb y = similar(u)
90-
@bb fy = similar(fu)
91-
@bb δy = similar(u)
92-
@bb u_new = similar(u)
93-
(δu, δy, fy, y, u_new)
93+
function __internal_multistep_caches(
94+
scheme::MSS.__PotraPtak3, alg::GenericMultiStepDescent,
95+
prob, args...; shared::Val{N} = Val(1), kwargs...) where {N}
96+
internal_descent = NewtonDescent(; alg.linsolve, alg.precs)
97+
internal_cache = __internal_init(
98+
prob, internal_descent, args...; kwargs..., shared = Val(2))
99+
internal_caches = N 1 ? nothing :
100+
map(2:N) do i
101+
__internal_init(prob, internal_descent, args...; kwargs..., shared = Val(2))
94102
end
95-
return first(caches), (N 1 ? nothing : caches[2:end])
103+
return internal_cache, internal_caches
96104
end
97105

98-
function __internal_init(prob::NonlinearProblem, alg::GenericMultiStepDescent, J, fu, u;
99-
shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, linsolve_kwargs = (;),
106+
function __internal_init(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
107+
alg::GenericMultiStepDescent, J, fu, u; shared::Val{N} = Val(1),
108+
pre_inverted::Val{INV} = False, linsolve_kwargs = (;),
100109
abstol = nothing, reltol = nothing, timer = get_timer_output(),
101110
kwargs...) where {INV, N}
102-
δu, δus = __δu_caches(alg.scheme, fu, u, shared)
103-
INV && return GenericMultiStepDescentCache{true}(prob.f, prob.p, δu, δus,
104-
alg.scheme, nothing, timer, 0)
105-
lincache = LinearSolverCache(alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol,
106-
linsolve_kwargs...)
107-
return GenericMultiStepDescentCache{false}(prob.f, prob.p, δu, δus, alg.scheme,
108-
lincache, timer, 0)
109-
end
110-
111-
function __internal_init(prob::NonlinearLeastSquaresProblem, alg::GenericMultiStepDescent,
112-
J, fu, u; kwargs...)
113-
error("Multi-Step Descent Algorithms for NLLS are not implemented yet.")
111+
@bb δu = similar(u)
112+
δus = N 1 ? nothing : map(2:N) do i
113+
@bb δu_ = similar(u)
114+
end
115+
fu_cache = ntuple(MSS.nintermediates(alg.scheme)) do i
116+
@bb xx = similar(fu)
117+
end
118+
fus_cache = N 1 ? nothing : map(2:N) do i
119+
ntuple(MSS.nintermediates(alg.scheme)) do j
120+
@bb xx = similar(fu)
121+
end
122+
end
123+
u_cache = ntuple(MSS.nintermediates(alg.scheme)) do i
124+
@bb xx = similar(u)
125+
end
126+
us_cache = N 1 ? nothing : map(2:N) do i
127+
ntuple(MSS.nintermediates(alg.scheme)) do j
128+
@bb xx = similar(u)
129+
end
130+
end
131+
internal_cache, internal_caches = __internal_multistep_caches(
132+
alg.scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
133+
abstol, reltol, timer, kwargs...)
134+
return GenericMultiStepDescentCache(
135+
prob.f, prob.p, δu, δus, u_cache, us_cache, fu_cache, fus_cache,
136+
internal_cache, internal_caches, alg.scheme, timer, 0)
114137
end
115138

116139
function __internal_solve!(cache::GenericMultiStepDescentCache{MSS.__PotraPtak3, INV}, J,
117140
fu, u, idx::Val = Val(1); skip_solve::Bool = false, new_jacobian::Bool = true,
118141
kwargs...) where {INV}
119-
(u_new, δy, fy, y, δu) = get_du(cache, idx)
120-
skip_solve && return DescentResult(; u = u_new)
121-
122-
@static_timeit cache.timer "linear solve" begin
123-
@static_timeit cache.timer "solve and step 1" begin
124-
if INV
125-
J !== nothing && @bb(δu=J × _vec(fu))
126-
else
127-
δu = cache.lincache(; A = J, b = _vec(fu), kwargs..., linu = _vec(δu),
128-
du = _vec(δu),
129-
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1)))
130-
δu = _restructure(u, δu)
131-
end
132-
@bb @. y = u - δu
133-
end
142+
δu = get_du(cache, idx)
143+
skip_solve && return DescentResult(; δu)
144+
145+
(y,) = get_internal_cache(cache, Val(:u), idx)
146+
(fy,) = get_internal_cache(cache, Val(:fu), idx)
147+
internal_cache = get_internal_cache(cache, Val(:internal_cache), idx)
134148

149+
@static_timeit cache.timer "descent step" begin
150+
result_1 = __internal_solve!(
151+
internal_cache, J, fu, u, Val(1); new_jacobian, kwargs...)
152+
δx = result_1.δu
153+
154+
@bb @. y = u + δx
135155
fy = evaluate_f!!(cache.f, fy, y, cache.p)
136156
cache.nf += 1
137157

138-
@static_timeit cache.timer "solve and step 2" begin
139-
if INV
140-
J !== nothing && @bb(δy=J × _vec(fy))
141-
else
142-
δy = cache.lincache(; A = J, b = _vec(fy), kwargs..., linu = _vec(δy),
143-
du = _vec(δy), reuse_A_if_factorization = true)
144-
δy = _restructure(u, δy)
145-
end
146-
@bb @. u_new = y - δy
147-
end
158+
result_2 = __internal_solve!(
159+
internal_cache, J, fy, y, Val(2); kwargs...)
160+
δy = result_2.δu
161+
162+
@bb @. δu = δx + δy
148163
end
149164

150-
set_du!(cache, (u_new, δy, fy, y, δu), idx)
151-
return DescentResult(; u = u_new)
165+
set_du!(cache, δu, idx)
166+
set_internal_cache!(cache, (y,), Val(:u), idx)
167+
set_internal_cache!(cache, (fy,), Val(:fu), idx)
168+
set_internal_cache!(cache, internal_cache, Val(:internal_cache), idx)
169+
return DescentResult(; δu)
152170
end

0 commit comments

Comments
 (0)