Skip to content

Commit

Permalink
Merge pull request #360 from CliMA/kp/itime
Browse files Browse the repository at this point in the history
Add support for ITime in IMEXAlgorithms and SSPKnoth
  • Loading branch information
ph-kev authored Feb 4, 2025
2 parents 207704b + 0794890 commit ae9a03a
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 13 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ v0.8.2
- ![][badge-💥breaking] If saveat is a number, then it does not automatically expand to `tspan[1]:saveat:tspan[2]`. To fix this, update
`saveat`, which is a keyword in the integrator, to be an array. For example, if `saveat` is a scalar, replace it with
`[tspan[1]:saveat:tspan[2]..., tspan[2]]` to achieve the same behavior as before.
- IMEXAlgorithms and SSPKnoth are compatible with ITime. See ClimaUtilities for more information about ITime.

v0.7.18
-------
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ClimaTimeSteppers"
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
authors = ["Climate Modeling Alliance"]
version = "0.8.1"
version = "0.8.2"

[deps]
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
Expand Down
1 change: 1 addition & 0 deletions src/ClimaTimeSteppers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ include("solvers/rosenbrock.jl")

include("Callbacks.jl")

include("arbitrary_number_types.jl")

benchmark_step(integrator, device) =
@warn "Must load CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables to trigger the ClimaTimeSteppersBenchmarkToolsExt extension"
Expand Down
19 changes: 19 additions & 0 deletions src/arbitrary_number_types.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
SciMLBase.allows_arbitrary_number_types(alg::T)
where {T <: ClimaTimeSteppers.RosenbrockAlgorithm}
Return `true`. Enable RosenbrockAlgorithms to run with `ClimaUtilities.ITime`.
"""
function SciMLBase.allows_arbitrary_number_types(alg::T) where {T <: ClimaTimeSteppers.RosenbrockAlgorithm}
true
end

"""
SciMLBase.allows_arbitrary_number_types(alg::T)
where {T <: ClimaTimeSteppers.IMEXAlgorithm}
Return `true`. Enable IMEXAlgorithms to run with `ClimaUtilities.ITime`.
"""
function SciMLBase.allows_arbitrary_number_types(alg::T) where {T <: ClimaTimeSteppers.IMEXAlgorithm}
true
end
14 changes: 9 additions & 5 deletions src/integrators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ end

# helper function for setting up min/max heaps for tstops and saveat
function tstops_and_saveat_heaps(t0, tf, tstops, saveat = [])
FT = typeof(tf)
# We promote to a common type to ensure that t0 and tf have the same type
FT = typeof(first(promote(t0, tf)))
ordering = tf > t0 ? DataStructures.FasterForward : DataStructures.FasterReverse

# ensure that tstops includes tf and only has values ahead of t0
Expand All @@ -81,7 +82,7 @@ function tstops_and_saveat_heaps(t0, tf, tstops, saveat = [])
return tstops, saveat
end

compute_tdir(ts) = ts[1] > ts[end] ? sign(ts[end] - ts[1]) : eltype(ts)(1)
compute_tdir(ts) = ts[1] > ts[end] ? sign(ts[end] - ts[1]) : oneunit(ts[1])

# called by DiffEqBase.init and DiffEqBase.solve
function DiffEqBase.__init(
Expand All @@ -102,8 +103,10 @@ function DiffEqBase.__init(
)
(; u0, p) = prob
t0, tf = prob.tspan
t0, tf, dt = promote(t0, tf, dt)

dt > zero(dt) || error("dt must be positive")
# We need zero(oneunit()) because there's no zerounit
dt > zero(oneunit(dt)) || error("dt must be positive")
_dt = dt
dt = tf > t0 ? dt : -dt

Expand Down Expand Up @@ -243,8 +246,9 @@ function __step!(integrator)
# is taken from OrdinaryDiffEq.jl
t_plus_dt = integrator.t + integrator.dt
t_unit = oneunit(integrator.t)
max_t_error = 100 * eps(float(integrator.t / t_unit)) * t_unit
integrator.t = !isempty(tstops) && abs(first(tstops) - t_plus_dt) < max_t_error ? first(tstops) : t_plus_dt
max_t_error = 100 * eps(float(integrator.t / t_unit)) * float(t_unit)
integrator.t =
!isempty(tstops) && abs(float(first(tstops)) - float(t_plus_dt)) < max_t_error ? first(tstops) : t_plus_dt

# apply callbacks
discrete_callbacks = integrator.callback.discrete_callbacks
Expand Down
2 changes: 1 addition & 1 deletion src/solvers/imex_ark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ end

t_exp = t + dt * c_exp[i]
t_imp = t + dt * c_imp[i]
dtγ = dt * a_imp[i, i]
dtγ = float(dt) * a_imp[i, i]

if has_T_lim(f) # Update based on limited tendencies from previous stages
assign_fused_increment!(U, u, dt, a_exp, T_lim, Val(i))
Expand Down
8 changes: 4 additions & 4 deletions src/solvers/rosenbrock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages}

# TODO: This is only valid when Γ[i, i] is constant, otherwise we have to
# move this in the for loop
@inbounds dtγ = dt * Γ[1, 1]
@inbounds dtγ = float(dt) * Γ[1, 1]

if !isnothing(T_imp!)
Wfact! = int.sol.prob.f.T_imp!.Wfact
Expand Down Expand Up @@ -175,14 +175,14 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages}
end

if !isnothing(tgrad!)
fU .+= γi .* dt .* ∂Y∂t
fU .+= γi .* float(dt) .* ∂Y∂t
end

for j in 1:(i - 1)
fU .+= (C[i, j] / dt) .* k[j]
fU .+= (C[i, j] / float(dt)) .* k[j]
end

fU .*= -dtγ
fU .*= -float(dtγ)

if !isnothing(T_imp!)
if W isa Matrix
Expand Down
4 changes: 2 additions & 2 deletions src/utilities/fused_increment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ In the edge case (coeffs are zero, `j` range is empty),
this lowers to `nothing` (no-op)
"""
@inline function fused_increment!(u, dt, sc, tend, v)
bc = fused_increment(u, dt, sc, tend, v)
bc = fused_increment(u, float(dt), sc, tend, v)
if bc isa Base.Broadcast.Broadcasted # Only material if not trivial assignment
Base.Broadcast.materialize!(u, bc)
end
Expand All @@ -142,7 +142,7 @@ this lowers to
`@. U = u`
"""
@inline function assign_fused_increment!(U, u, dt, sc, tend, v)
bc = fused_increment(u, dt, sc, tend, v)
bc = fused_increment(u, float(dt), sc, tend, v)
Base.Broadcast.materialize!(U, bc)
return nothing
end

2 comments on commit ae9a03a

@ph-kev
Copy link
Member Author

@ph-kev ph-kev commented on ae9a03a Feb 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

Make IMEXAlgorithms and SSPKnoth compatible with ITime.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/124281

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.2 -m "<description of version>" ae9a03a9208405636320ba2762c6941c45b3697b
git push origin v0.8.2

Please sign in to comment.