Commit ceeadcb

materialize the multi-step scheme
1 parent ff82fd9 commit ceeadcb

8 files changed, +67 -27 lines changed

Diff for: Project.toml

+3 -1
@@ -5,6 +5,7 @@ version = "3.6.0"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
+Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
@@ -55,12 +56,13 @@ NonlinearSolveZygoteExt = "Zygote"

 [compat]
 ADTypes = "0.2.6"
+Accessors = "0.1"
 Aqua = "0.8"
 ArrayInterface = "7.7"
 BandedMatrices = "1.4"
 BenchmarkTools = "1.4"
-ConcreteStructs = "0.2.3"
 CUDA = "5.1"
+ConcreteStructs = "0.2.3"
 DiffEqBase = "6.146.0"
 Enzyme = "0.11.11"
 FastBroadcast = "0.2.8"

Diff for: docs/src/basics/faq.md

+2 -2
@@ -72,7 +72,7 @@ differentiate the function based on the input types. However, this function has
 `xx = [1.0, 2.0, 3.0, 4.0]` followed by a `xx[1] = var[1] - v_true[1]` where `var` might
 be a Dual number. This causes the error. To fix it:

-1. Specify the `autodiff` to be `AutoFiniteDiff`
+1. Specify the `autodiff` to be `AutoFiniteDiff`

 ```@example dual_error_faq
 sol = solve(prob_oop, LevenbergMarquardt(; autodiff = AutoFiniteDiff()); maxiters = 10000,
@@ -81,7 +81,7 @@ sol = solve(prob_oop, LevenbergMarquardt(; autodiff = AutoFiniteDiff()); maxiter

 This worked but, Finite Differencing is not the recommended approach in any scenario.

-2. Rewrite the function to use
+2. Rewrite the function to use
 [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl) or write it as

 ```@example dual_error_faq

Diff for: docs/src/basics/sparsity_detection.md

+2 -2
@@ -34,7 +34,7 @@ prob = NonlinearProblem(
 If the `colorvec` is not provided, then it is computed on demand.

 !!! note
-
+
     One thing to be careful about in this case is that `colorvec` is dependent on the
     autodiff backend used. Forward Mode and Finite Differencing will assume that the
     colorvec is the column colorvec, while Reverse Mode will assume that the colorvec is the
@@ -76,7 +76,7 @@ loaded, we default to using `SymbolicsSparsityDetection()`, else we default to u
 options if those are provided.

 !!! warning
-
+
     If you provide a non-sparse AD, and provide a `sparsity` or `jac_prototype` then
     we will use dense AD. This is because, if you provide a specific AD type, we assume
     that you know what you are doing and want to override the default choice of `nothing`.

Diff for: docs/src/tutorials/large_systems.md

+9 -9
@@ -2,15 +2,15 @@

 This tutorial is for getting into the extra features of using NonlinearSolve.jl. Solving
 ill-conditioned nonlinear systems requires specializing the linear solver on properties of
-the Jacobian in order to cut down on the ``\mathcal{O}(n^3)`` linear solve and the
-``\mathcal{O}(n^2)`` back-solves. This tutorial is designed to explain the advanced usage of
+the Jacobian in order to cut down on the `\mathcal{O}(n^3)` linear solve and the
+`\mathcal{O}(n^2)` back-solves. This tutorial is designed to explain the advanced usage of
 NonlinearSolve.jl by solving the steady state stiff Brusselator partial differential
 equation (BRUSS) using NonlinearSolve.jl.

 ## Definition of the Brusselator Equation

 !!! note
-
+
     Feel free to skip this section: it simply defines the example problem.

 The Brusselator PDE is defined as follows:
@@ -118,11 +118,11 @@ However, if you know the sparsity of your problem, then you can pass a different
 type. For example, a `SparseMatrixCSC` will give a sparse matrix. Other sparse matrix types
 include:

-- Bidiagonal
-- Tridiagonal
-- SymTridiagonal
-- BandedMatrix ([BandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BandedMatrices.jl))
-- BlockBandedMatrix ([BlockBandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BlockBandedMatrices.jl))
+- Bidiagonal
+- Tridiagonal
+- SymTridiagonal
+- BandedMatrix ([BandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BandedMatrices.jl))
+- BlockBandedMatrix ([BlockBandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BlockBandedMatrices.jl))

 ## Approximate Sparsity Detection & Sparse Jacobians

@@ -213,7 +213,7 @@ choices, see the
 `linsolve` choices are any valid [LinearSolve.jl](https://linearsolve.sciml.ai/dev/) solver.

 !!! note
-
+
     Switching to a Krylov linear solver will automatically change the nonlinear problem
     solver into Jacobian-free mode, dramatically reducing the memory required. This can be
     overridden by adding `concrete_jac=true` to the algorithm.
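
The tutorial text touched by this hunk contrasts the O(n^3) factorization with the O(n^2) back-solves and lists structured matrix types. As a rough illustration only (this is not code from the tutorial), the sketch below uses Julia's `LinearAlgebra` standard library on a small made-up tridiagonal matrix: the dense LU factorization is computed once and reused for two right-hand sides, and the same system is then solved through the `Tridiagonal` type. The matrix, its size, and all variable names are assumptions chosen for the example.

```julia
using LinearAlgebra

n = 6
# Made-up 1D Laplacian-like matrix stored in structured (tridiagonal) form.
J_tri = Tridiagonal(fill(-1.0, n - 1), fill(2.0, n), fill(-1.0, n - 1))
J_dense = Matrix(J_tri)   # the same matrix with no structure information
b = ones(n)

# Dense path: factorize once (the cubic-cost step), then reuse the factors.
F = lu(J_dense)
x1 = F \ b                # back-solve only
x2 = F \ (2 .* b)         # second right-hand side, no new factorization

# Structured path: the solver can exploit the tridiagonal layout.
x_tri = J_tri \ b
```

Both paths should agree up to floating-point error; only the amount of work differs, which is the motivation for passing a structured `jac_prototype` in the tutorial.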

Diff for: src/NonlinearSolve.jl

+5 -5
@@ -8,9 +8,9 @@ import Reexport: @reexport
 import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload

 @recompile_invalidations begin
-    using ADTypes, ConcreteStructs, DiffEqBase, FastBroadcast, FastClosures, LazyArrays,
-        LineSearches, LinearAlgebra, LinearSolve, MaybeInplace, Preferences, Printf,
-        SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools
+    using Accessors, ADTypes, ConcreteStructs, DiffEqBase, FastBroadcast, FastClosures,
+        LazyArrays, LineSearches, LinearAlgebra, LinearSolve, MaybeInplace, Preferences,
+        Printf, SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools

     import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing
     import DiffEqBase: AbstractNonlinearTerminationMode,
@@ -142,7 +142,7 @@ end

 # Core Algorithms
 export NewtonRaphson, PseudoTransient, Klement, Broyden, LimitedMemoryBroyden, DFSane,
-    MultiStepNonlinearSolver
+    MultiStepNonlinearSolver
 export GaussNewton, LevenbergMarquardt, TrustRegion
 export NonlinearSolvePolyAlgorithm,
     RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg
@@ -156,7 +156,7 @@ export GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm, Genera

 # Descent Algorithms
 export NewtonDescent, SteepestDescent, Dogleg, DampedNewtonDescent,
-    GeodesicAcceleration, GenericMultiStepDescent
+    GeodesicAcceleration, GenericMultiStepDescent
 ## Multistep Algorithms
 export MultiStepSchemes

Diff for: src/algorithms/multistep.jl

+5 -4
@@ -1,7 +1,8 @@
 function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing,
-        scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing)
-    descent = GenericMultiStepDescent(; scheme, linsolve, precs)
-    # TODO: Use the scheme as the name
-    return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = :MultiStepNonlinearSolver,
+        scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing,
+        vjp_autodiff = nothing)
+    scheme_concrete = apply_patch(scheme, (; autodiff, vjp_autodiff))
+    descent = GenericMultiStepDescent(; scheme = scheme_concrete, linsolve, precs)
+    return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = MSS.display_name(scheme),
         descent, jacobian_ad = autodiff)
 end

Diff for: src/descent/multistep.jl

+22 -4
@@ -7,32 +7,47 @@ typically the last names of the authors of the paper that introduced the method.
 """
 module MultiStepSchemes

+using ConcreteStructs
+
 abstract type AbstractMultiStepScheme end

 function Base.show(io::IO, mss::AbstractMultiStepScheme)
     print(io, "MultiStepSchemes.$(string(nameof(typeof(mss)))[3:end])")
 end

+alg_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = alg_steps(T())
+
 struct __PotraPtak3 <: AbstractMultiStepScheme end
 const PotraPtak3 = __PotraPtak3()

-alg_steps(::__PotraPtak3) = 1
+alg_steps(::__PotraPtak3) = 2

-struct __SinghSharma4 <: AbstractMultiStepScheme end
+@kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme
+    vjp_autodiff = nothing
+end
 const SinghSharma4 = __SinghSharma4()

 alg_steps(::__SinghSharma4) = 3

-struct __SinghSharma5 <: AbstractMultiStepScheme end
+@kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme
+    vjp_autodiff = nothing
+end
 const SinghSharma5 = __SinghSharma5()

 alg_steps(::__SinghSharma5) = 3

-struct __SinghSharma7 <: AbstractMultiStepScheme end
+@kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme
+    vjp_autodiff = nothing
+end
 const SinghSharma7 = __SinghSharma7()

 alg_steps(::__SinghSharma7) = 4

+@generated function display_name(alg::T) where {T <: AbstractMultiStepScheme}
+    res = Symbol(first(split(last(split(string(T), ".")), "{"; limit = 2))[3:end])
+    return :($(Meta.quot(res)))
+end
+
 end

 const MSS = MultiStepSchemes
@@ -43,6 +58,8 @@ const MSS = MultiStepSchemes
     precs = DEFAULT_PRECS
 end

+Base.show(io::IO, alg::GenericMultiStepDescent) = print(io, "$(alg.scheme)()")
+
 supports_line_search(::GenericMultiStepDescent) = false
 supports_trust_region(::GenericMultiStepDescent) = false

@@ -51,6 +68,7 @@ supports_trust_region(::GenericMultiStepDescent) = false
     p
     δu
     δus
+    extras
     scheme::S
     lincache
     timer
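
A minimal sketch of the naming convention this file appears to rely on: scheme types carry a `__` prefix, a public constant holds the default instance, and the display name is the type name with the prefix removed. Everything here is illustrative. `SchemeSketch` and `AbstractScheme` are hypothetical names, `Base.@kwdef` with an explicit type parameter stands in for `@kwdef @concrete`, and `display_name` uses `nameof` instead of the commit's string splitting, which sidesteps any dependence on how type parameters are printed.

```julia
module SchemeSketch

abstract type AbstractScheme end

# Field-free scheme: a singleton type plus a public constant instance.
struct __PotraPtak3 <: AbstractScheme end
const PotraPtak3 = __PotraPtak3()

# Scheme with one configurable field (assumed stand-in for `@concrete`).
Base.@kwdef struct __SinghSharma4{V} <: AbstractScheme
    vjp_autodiff::V = nothing
end
const SinghSharma4 = __SinghSharma4()

# Drop the leading "__" from the bare type name (type parameters are ignored by `nameof`).
display_name(s::AbstractScheme) = Symbol(string(nameof(typeof(s)))[3:end])

Base.show(io::IO, s::AbstractScheme) = print(io, "MultiStepSchemes.", display_name(s))

alg_steps(::__PotraPtak3) = 2
alg_steps(::__SinghSharma4) = 3
# Forward a query on the type to a default-constructed instance.
alg_steps(::Type{T}) where {T <: AbstractScheme} = alg_steps(T())

end # module
```

With this layout, `SchemeSketch.display_name(SchemeSketch.SinghSharma4)` yields the symbol `:SinghSharma4` regardless of what type the `vjp_autodiff` field holds, which is presumably why the commit strips the `{...}` suffix before taking the name.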

Diff for: src/utils.jl

+19 -0
@@ -158,3 +158,22 @@ Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the i
 """
 @inline pickchunksize(x) = pickchunksize(length(x))
 @inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x)
+
+"""
+    apply_patch(scheme, patch::NamedTuple{names})
+
+Applies the patch to the scheme, returning the new scheme. If some of the `names` are not,
+present in the scheme, they are ignored.
+"""
+@generated function apply_patch(scheme, patch::NamedTuple{names}) where {names}
+    exprs = []
+    for name in names
+        hasfield(scheme, name) || continue
+        push!(exprs, quote
+            lens = PropertyLens{$(Meta.quot(name))}()
+            return set(scheme, lens, getfield(patch, $(Meta.quot(name))))
+        end)
+    end
+    push!(exprs, :(return scheme))
+    return Expr(:block, exprs...)
+end
