Skip to content

Commit

Permalink
Make ARS343 and SSPKnoth compatible with ITime
Browse files Browse the repository at this point in the history
  • Loading branch information
ph-kev committed Feb 3, 2025
1 parent ed01902 commit 923dd11
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 11 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Main
-------
v0.9.0
- ![][badge-💥breaking] If saveat is a number, then it does not automatically expand to `tspan[1]:saveat:tspan[2]`.
- ARS343 and SSPKnoth are compatible with ITime. See ClimaUtilities for more information about ITime.

v0.7.18
-------
Expand Down
1 change: 1 addition & 0 deletions src/ClimaTimeSteppers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ include("solvers/rosenbrock.jl")

include("Callbacks.jl")

include("arbitrary_number_types.jl")

benchmark_step(integrator, device) =
@warn "Must load CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables to trigger the ClimaTimeSteppersBenchmarkToolsExt extension"
Expand Down
19 changes: 19 additions & 0 deletions src/arbitrary_number_types.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
SciMLBase.allows_arbitrary_number_types(alg::T)
where {T <: ClimaTimeSteppers.RosenbrockAlgorithm}
Return `true`. Enable SSPKnoth to run with `ClimaUtilities.ITime`.
"""
function SciMLBase.allows_arbitrary_number_types(alg::T) where {T <: ClimaTimeSteppers.RosenbrockAlgorithm}
true
end

"""
SciMLBase.allows_arbitrary_number_types(alg::T)
where {T <: ClimaTimeSteppers.IMEXAlgorithm}
Return `true`. Enable ARS343 to run with `ClimaUtilities.ITime`.
"""
function SciMLBase.allows_arbitrary_number_types(alg::T) where {T <: ClimaTimeSteppers.IMEXAlgorithm}
true
end
9 changes: 5 additions & 4 deletions src/integrators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function tstops_and_saveat_heaps(t0, tf, tstops, saveat = [])
return tstops, saveat
end

compute_tdir(ts) = ts[1] > ts[end] ? sign(ts[end] - ts[1]) : eltype(ts)(1)
compute_tdir(ts) = ts[1] > ts[end] ? sign(ts[end] - ts[1]) : oneunit(ts[1])

# called by DiffEqBase.init and DiffEqBase.solve
function DiffEqBase.__init(
Expand All @@ -102,8 +102,9 @@ function DiffEqBase.__init(
)
(; u0, p) = prob
t0, tf = prob.tspan
t0, tf, dt = promote(t0, tf, dt)

dt > zero(dt) || error("dt must be positive")
dt > zero(oneunit(dt)) || error("dt must be positive")
_dt = dt
dt = tf > t0 ? dt : -dt

Expand Down Expand Up @@ -243,8 +244,8 @@ function __step!(integrator)
# is taken from OrdinaryDiffEq.jl
t_plus_dt = integrator.t + integrator.dt
t_unit = oneunit(integrator.t)
max_t_error = 100 * eps(float(integrator.t / t_unit)) * t_unit
integrator.t = !isempty(tstops) && abs(first(tstops) - t_plus_dt) < max_t_error ? first(tstops) : t_plus_dt
max_t_error = 100 * eps(float(integrator.t / t_unit)) * float(t_unit)
integrator.t = !isempty(tstops) && abs(float(first(tstops)) - float(t_plus_dt)) < max_t_error ? first(tstops) : t_plus_dt

# apply callbacks
discrete_callbacks = integrator.callback.discrete_callbacks
Expand Down
6 changes: 5 additions & 1 deletion src/solvers/imex_ark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ end

t_exp = t + dt * c_exp[i]
t_imp = t + dt * c_imp[i]
dtγ = dt * a_imp[i, i]
dtγ = float(dt) * a_imp[i, i]

if has_T_lim(f) # Update based on limited tendencies from previous stages
assign_fused_increment!(U, u, dt, a_exp, T_lim, Val(i))
Expand Down Expand Up @@ -135,6 +135,10 @@ end
T_imp!(residual, U′, p, t_imp)
@. residual = temp + dtγ * residual - U′
end
implicit_equation_jacobian! = (jacobian, Ui) -> begin
T_imp!.Wfact(jacobian, Ui, p, dtγ, t_imp)
end
implicit_equation_cache! = Ui -> cache_imp!(Ui, p, t_imp)
solve_newton!(
newtons_method,
newtons_method_cache,
Expand Down
8 changes: 4 additions & 4 deletions src/solvers/rosenbrock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages}

# TODO: This is only valid when Γ[i, i] is constant, otherwise we have to
# move this in the for loop
@inbounds dtγ = dt * Γ[1, 1]
@inbounds dtγ = float(dt) * Γ[1, 1]

if !isnothing(T_imp!)
Wfact! = int.sol.prob.f.T_imp!.Wfact
Expand Down Expand Up @@ -175,14 +175,14 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages}
end

if !isnothing(tgrad!)
fU .+= γi .* dt .* ∂Y∂t
fU .+= γi .* float(dt) .* ∂Y∂t
end

for j in 1:(i - 1)
fU .+= (C[i, j] / dt) .* k[j]
fU .+= (C[i, j] / float(dt)) .* k[j]
end

fU .*= -dtγ
fU .*= -float(dtγ)

if !isnothing(T_imp!)
if W isa Matrix
Expand Down
4 changes: 2 additions & 2 deletions src/utilities/fused_increment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ In the edge case (coeffs are zero, `j` range is empty),
this lowers to `nothing` (no-op)
"""
@inline function fused_increment!(u, dt, sc, tend, v)
bc = fused_increment(u, dt, sc, tend, v)
bc = fused_increment(u, float(dt), sc, tend, v)
if bc isa Base.Broadcast.Broadcasted # Only material if not trivial assignment
Base.Broadcast.materialize!(u, bc)
end
Expand All @@ -142,7 +142,7 @@ this lowers to
`@. U = u`
"""
@inline function assign_fused_increment!(U, u, dt, sc, tend, v)
bc = fused_increment(u, dt, sc, tend, v)
bc = fused_increment(u, float(dt), sc, tend, v)
Base.Broadcast.materialize!(U, bc)
return nothing
end

0 comments on commit 923dd11

Please sign in to comment.