Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create CheckInit and add tagging of initializations to callbacks #783

Merged
merged 4 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,11 @@ $(TYPEDEF)
"""
struct NoInit <: DAEInitializationAlgorithm end

"""
$(TYPEDEF)
"""
struct CheckInit <: DAEInitializationAlgorithm end

# PDE Discretizations

"""
Expand Down
109 changes: 81 additions & 28 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ ContinuousCallback(condition, affect!, affect_neg!;
rootfind = LeftRootFind,
save_positions = (true, true),
interp_points = 10,
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100)
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
initializealg = nothing)
```

```julia
Expand All @@ -34,7 +35,8 @@ ContinuousCallback(condition, affect!;
save_positions = (true, true),
affect_neg! = affect!,
interp_points = 10,
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100)
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
initializealg = nothing)
```

Contains a single callback whose `condition` is a continuous function. The callback is triggered when this function evaluates to 0.
Expand Down Expand Up @@ -91,8 +93,26 @@ Contains a single callback whose `condition` is a continuous function. The callb
- `repeat_nudge = 1//100`: This is used to set the next testing point after a
previously found zero. Defaults to 1//100, which means after a callback, the next
sign check will take place at t + dt*1//100 instead of at t to avoid repeats.
- `initializealg = nothing`: In the context of a DAE, this is the algorithm that is used
to run initialization after the effect. The default of `nothing` defers to the initialization
algorithm provided in the `solve`.

!!! warn

The effect of using a callback with a DAE needs to be done with care because the solution
`u` needs to satisfy the algebraic constraints before taking the next step. For this reason,
a consistent initialization calculation must be run after running the callback. If the
chosen initialization alg is `BrownBasicInit()` (the default for `solve`), then the initialization
will change the algebraic variables to satisfy the conditions. Thus if `x` is an algebraic
variable and the callback performs `x+=1`, the initialization may "revert" the change to
satisfy the constraints. This behavior can be removed by setting `initializealg = CheckInit()`,
which simply checks that the state `u` is consistent, but requires that the result of the
`affect!` satisfies the constraints (or else errors). It is not recommended that `NoInit()` is
used as that will lead to an unstable step following initialization. This warning can be
ignored for non-DAE ODEs.
"""
struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, I, R} <: AbstractContinuousCallback
struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
AbstractContinuousCallback
condition::F1
affect!::F2
affect_neg!::F3
Expand All @@ -106,19 +126,21 @@ struct ContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, I, R} <: AbstractContin
abstol::T
reltol::T2
repeat_nudge::T3
initializealg::T4
function ContinuousCallback(condition::F1, affect!::F2, affect_neg!::F3,
initialize::F4, finalize::F5, idxs::I, rootfind,
interp_points, save_positions, dtrelax::R, abstol::T,
reltol::T2,
repeat_nudge::T3) where {F1, F2, F3, F4, F5, T, T2, T3, I, R
repeat_nudge::T3,
initializealg::T4) where {F1, F2, F3, F4, F5, T, T2, T3, T4, I, R
}
_condition = prepare_function(condition)
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, I, R}(_condition,
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R}(_condition,
affect!, affect_neg!,
initialize, finalize, idxs, rootfind,
interp_points,
BitArray(collect(save_positions)),
dtrelax, abstol, reltol, repeat_nudge)
dtrelax, abstol, reltol, repeat_nudge, initializealg)
end
end

Expand All @@ -131,12 +153,13 @@ function ContinuousCallback(condition, affect!, affect_neg!;
interp_points = 10,
dtrelax = 1,
abstol = 10eps(), reltol = 0,
repeat_nudge = 1 // 100)
repeat_nudge = 1 // 100,
initializealg = nothing)
ContinuousCallback(condition, affect!, affect_neg!, initialize, finalize,
idxs,
rootfind, interp_points,
save_positions,
dtrelax, abstol, reltol, repeat_nudge)
dtrelax, abstol, reltol, repeat_nudge, initializealg)
end

function ContinuousCallback(condition, affect!;
Expand All @@ -148,11 +171,12 @@ function ContinuousCallback(condition, affect!;
affect_neg! = affect!,
interp_points = 10,
dtrelax = 1,
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100)
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
initializealg = nothing)
ContinuousCallback(condition, affect!, affect_neg!, initialize, finalize, idxs,
rootfind, interp_points,
collect(save_positions),
dtrelax, abstol, reltol, repeat_nudge)
dtrelax, abstol, reltol, repeat_nudge, initializealg)
end

"""
Expand All @@ -164,7 +188,8 @@ VectorContinuousCallback(condition, affect!, affect_neg!, len;
rootfind = LeftRootFind,
save_positions = (true, true),
interp_points = 10,
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100)
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
initializealg = nothing)
```

```julia
Expand All @@ -176,7 +201,8 @@ VectorContinuousCallback(condition, affect!, len;
save_positions = (true, true),
affect_neg! = affect!,
interp_points = 10,
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100)
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
initializealg = nothing)
```

This is also a subtype of `AbstractContinuousCallback`. `CallbackSet` is not feasible when you have many callbacks,
Expand All @@ -194,7 +220,7 @@ multiple events.

Rest of the arguments have the same meaning as in [`ContinuousCallback`](@ref).
"""
struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, I, R} <:
struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, T4, I, R} <:
AbstractContinuousCallback
condition::F1
affect!::F2
Expand All @@ -210,20 +236,22 @@ struct VectorContinuousCallback{F1, F2, F3, F4, F5, T, T2, T3, I, R} <:
abstol::T
reltol::T2
repeat_nudge::T3
initializealg::T4
function VectorContinuousCallback(
condition::F1, affect!::F2, affect_neg!::F3, len::Int,
initialize::F4, finalize::F5, idxs::I, rootfind,
interp_points, save_positions, dtrelax::R,
abstol::T, reltol::T2,
repeat_nudge::T3) where {F1, F2, F3, F4, F5, T, T2,
T3, I, R}
repeat_nudge::T3,
initializealg::T4) where {F1, F2, F3, F4, F5, T, T2,
T3, T4, I, R}
_condition = prepare_function(condition)
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, I, R}(_condition,
new{typeof(_condition), F2, F3, F4, F5, T, T2, T3, T4, I, R}(_condition,
affect!, affect_neg!, len,
initialize, finalize, idxs, rootfind,
interp_points,
BitArray(collect(save_positions)),
dtrelax, abstol, reltol, repeat_nudge)
dtrelax, abstol, reltol, repeat_nudge, initializealg)
end
end

Expand All @@ -235,13 +263,14 @@ function VectorContinuousCallback(condition, affect!, affect_neg!, len;
save_positions = (true, true),
interp_points = 10,
dtrelax = 1,
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100)
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
initializealg = nothing)
VectorContinuousCallback(condition, affect!, affect_neg!, len,
initialize, finalize,
idxs,
rootfind, interp_points,
save_positions, dtrelax,
abstol, reltol, repeat_nudge)
abstol, reltol, repeat_nudge, initializealg)
end

function VectorContinuousCallback(condition, affect!, len;
Expand All @@ -253,20 +282,22 @@ function VectorContinuousCallback(condition, affect!, len;
affect_neg! = affect!,
interp_points = 10,
dtrelax = 1,
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100)
abstol = 10eps(), reltol = 0, repeat_nudge = 1 // 100,
initializealg = nothing)
VectorContinuousCallback(condition, affect!, affect_neg!, len, initialize, finalize,
idxs,
rootfind, interp_points,
collect(save_positions),
dtrelax, abstol, reltol, repeat_nudge)
dtrelax, abstol, reltol, repeat_nudge, initializealg)
end

"""
```julia
DiscreteCallback(condition, affect!;
initialize = INITIALIZE_DEFAULT,
finalize = FINALIZE_DEFAULT,
save_positions = (true, true))
save_positions = (true, true),
initializealg = nothing)
```

# Arguments
Expand All @@ -291,26 +322,48 @@ DiscreteCallback(condition, affect!;
- `finalize`: This is a function `(c,u,t,integrator)` which can be used to finalize
the state of the callback `c`. It should can the argument `c` and the return is
ignored.
- `initializealg = nothing`: In the context of a DAE, this is the algorithm that is used
to run initialization after the effect. The default of `nothing` defers to the initialization
algorithm provided in the `solve`.

!!! warn

The effect of using a callback with a DAE needs to be done with care because the solution
`u` needs to satisfy the algebraic constraints before taking the next step. For this reason,
a consistent initialization calculation must be run after running the callback. If the
chosen initialization alg is `BrownBasicInit()` (the default for `solve`), then the initialization
will change the algebraic variables to satisfy the conditions. Thus if `x` is an algebraic
variable and the callback performs `x+=1`, the initialization may "revert" the change to
satisfy the constraints. This behavior can be removed by setting `initializealg = CheckInit()`,
which simply checks that the state `u` is consistent, but requires that the result of the
`affect!` satisfies the constraints (or else errors). It is not recommended that `NoInit()` is
used as that will lead to an unstable step following initialization. This warning can be
ignored for non-DAE ODEs.
"""
struct DiscreteCallback{F1, F2, F3, F4} <: AbstractDiscreteCallback
struct DiscreteCallback{F1, F2, F3, F4, F5} <: AbstractDiscreteCallback
condition::F1
affect!::F2
initialize::F3
finalize::F4
save_positions::BitArray{1}
initializealg::F5
function DiscreteCallback(condition::F1, affect!::F2,
initialize::F3, finalize::F4,
save_positions) where {F1, F2, F3, F4}
save_positions,
initializealg::F5) where {F1, F2, F3, F4, F5}
_condition = prepare_function(condition)
new{typeof(_condition), F2, F3, F4}(_condition,
new{typeof(_condition), F2, F3, F4, F5}(_condition,
affect!, initialize, finalize,
BitArray(collect(save_positions)))
BitArray(collect(save_positions)),
initializealg)
end
end
function DiscreteCallback(condition, affect!;
initialize = INITIALIZE_DEFAULT, finalize = FINALIZE_DEFAULT,
save_positions = (true, true))
DiscreteCallback(condition, affect!, initialize, finalize, save_positions)
save_positions = (true, true),
initializealg = nothing)
DiscreteCallback(
condition, affect!, initialize, finalize, save_positions, initializealg)
end

"""
Expand Down
9 changes: 4 additions & 5 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,12 @@ function calculate_ensemble_errors(u; elapsedTime = 0.0, converged = false,
res = norm(m_final - m_final_analytic)
weak_errors[:weak_final] = res
if weak_timeseries_errors

if analyticvoa
ts_weak_errors = [mean([u[j].u[i] - u[j].u_analytic.u[i] for j in 1:length(u)])
for i in 1:length(u[1])]
for i in 1:length(u[1])]
else
ts_weak_errors = [mean([u[j].u[i] - u[j].u_analytic[i] for j in 1:length(u)])
for i in 1:length(u[1])]
for i in 1:length(u[1])]
end
ts_l2_errors = [sqrt.(sum(abs2, err) / length(err)) for err in ts_weak_errors]
l2_tmp = sqrt(sum(abs2, ts_l2_errors) / length(ts_l2_errors))
Expand All @@ -128,8 +127,8 @@ function calculate_ensemble_errors(u; elapsedTime = 0.0, converged = false,
if weak_dense_errors
densetimes = collect(range(u[1].t[1], stop = u[1].t[end], length = 100))
u_analytic = [[sol.prob.f.analytic(sol.prob.u0, sol.prob.p, densetimes[i],
sol.W(densetimes[i])[1])
for i in eachindex(densetimes)] for sol in u]
sol.W(densetimes[i])[1])
for i in eachindex(densetimes)] for sol in u]

udense = [u[j](densetimes) for j in 1:length(u)]
dense_weak_errors = [mean([udense[j].u[i] - u_analytic[j][i] for j in 1:length(u)])
Expand Down
Loading