Type inference on solution output fails for DAE with callback #2594
Turns out that, on master at least, making the callback a `const`

```julia
const constcallback = ContinuousCallback((u, _, __) -> u[1] - 1, terminate!)

# `dae!` and `evalsol` are defined elsewhere (not shown here)
function eval_prm_ccb(pr)
    newprob = ODEProblem{true}(dae!, [2.0, -2.0], (0.0, 1.0), (pr,))
    sol = solve(newprob, Rosenbrock23(); callback = constcallback)
    evalsol(sol)
end
```
while
Oh right, that is expected. Otherwise this is just the usual type instability from a non-`const` global variable.
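For readers unfamiliar with the distinction, here is a minimal, package-free sketch of the global-variable effect, with a number standing in for the callback (all names are hypothetical). Inference cannot assume the type of a non-`const` global, so anything computed from it is inferred loosely, whereas a `const` global has a fixed type:

```julia
# Illustrative names only (not from the issue): a numeric global stands in
# for the callback object.
threshold_global = 1.0            # non-const global: its type may change later
const threshold_const = 1.0       # const global: its type is fixed as Float64

shift_global(x) = x - threshold_global
shift_const(x) = x - threshold_const

# Ask inference for each function's return type when called with a Float64.
rt_global = only(Base.return_types(shift_global, (Float64,)))
rt_const = only(Base.return_types(shift_const, (Float64,)))

println(rt_const)                   # Float64
println(isconcretetype(rt_global))  # false: inference cannot rely on the global's type
```

Besides `const`, passing the callback into the function as an argument also avoids the problem, because Julia then specializes the function on the argument's concrete type.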
In my code base (which has some more intricacies), I'm still seeing something that isn't type-inferrable. I'll open another issue later if I manage to identify where the problem actually is, but in the meantime I can probably get away with a type assertion in my squared-error calculation.
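A minimal sketch of that workaround, assuming the squared error is always a `Float64`; `solution_values` and `SOLUTION` are hypothetical stand-ins for a call whose output type cannot be inferred. The assertion does not remove the instability inside the function, but it stops it from spreading to callers, and it throws a `TypeError` if the assumption is ever wrong:

```julia
# Hypothetical stand-in (not from the issue): `solution_values` returns data
# through a Ref{Any}, so inference only knows its result as `Any`, much like
# a solve call whose output type cannot be inferred.
const SOLUTION = Ref{Any}([1.0, 2.0, 3.0])
solution_values() = SOLUTION[]

function squared_error(target::Vector{Float64})
    vals = solution_values()            # inferred as Any
    err = sum(abs2, vals .- target)     # still Any at this point
    return err::Float64                 # assertion: callers now see Float64
end

println(only(Base.return_types(squared_error, (Vector{Float64},))))  # Float64
println(squared_error([1.0, 1.0, 1.0]))                              # 5.0
```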
Are you using `save_idxs`?
No, not currently (I didn't know that was an option). My problem only has two equations, so I haven't been worried on that front, but I have a parameter struct full of Unitful types, and the end time depends on the system dynamics (hence a callback that terminates integration when one of my two variables approaches zero).
Oh, Unitful 😅 If you don't use Unitful, do you still have issues?
I can replicate my problem without Unitful; here is an example that resembles my code base:

```julia
using OrdinaryDiffEqRosenbrock
using LinearAlgebra: Diagonal
using Accessors
# using OrdinaryDiffEqNonlinearSolve

# Parameter container that can be indexed like a vector, so `p[1]` works in `dae!`
struct Coeffs{T1, T2}
    p1::T1
    p2::T2
end

function Base.getindex(c::Coeffs, i::Int)
    if i == 1
        return c.p1
    elseif i == 2
        return c.p2
    else
        throw(ArgumentError("Index out of bounds"))
    end
end

function dae!(du, u, p, t)
    du[1] = -u[1]*p[1]
    du[2] = u[2] + u[1]   # algebraic row: the mass-matrix entry for u[2] is zero
    nothing
end

# Singular mass matrix: u[1] is differential, u[2] is algebraic (0 = u[2] + u[1])
const dae_fc = ODEFunction(dae!, mass_matrix=Diagonal([1.0, 0.0]))

# Consistent initial condition: u[2] + u[1] == 0
calc_u0(c::Coeffs) = [2.0, -2.0]

function OrdinaryDiffEqRosenbrock.ODEProblem(c::Coeffs)
    u0 = calc_u0(c)
    return ODEProblem{true}(dae_fc, u0, (0.0, 1.0), c; initializealg=CheckInit())
    # return ODEProblem{true}(dae_fc, u0, (0.0, 1.0), c;)
end
```

With this example, `@code_warntype solve(ODEProblem(Coeffs(-1.0, 1.0)), Rosenbrock23())` gives me this, where it's a little hard to tell where the instability happens:
Dropping the kwarg for
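In this MWE both `Coeffs` fields are `Float64`, so `p[1]` is easy to infer. One thing that may be worth checking in the Unitful version (an assumption, not something established in this thread) is that when the two fields have different types, indexing with a runtime `Int` is inferred as a `Union` of both. Inside `dae!` the literal indices may still be resolved by constant propagation, so treat this as a sketch of what to look for, not a diagnosis:

```julia
# Same shape as `Coeffs` in the MWE, repeated so this block runs on its own.
struct Coeffs{T1, T2}
    p1::T1
    p2::T2
end

function Base.getindex(c::Coeffs, i::Int)
    if i == 1
        return c.p1
    elseif i == 2
        return c.p2
    else
        throw(ArgumentError("Index out of bounds"))
    end
end

# With a runtime index, inference has to cover both return branches.
rt_same = only(Base.return_types(getindex, (Coeffs{Float64, Float64}, Int)))
rt_mixed = only(Base.return_types(getindex, (Coeffs{Float64, Int}, Int)))

println(rt_same)   # Float64
println(rt_mixed)  # Union{Float64, Int64}
```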
Using

```julia
using JET  # provides @report_opt
prob = ODEProblem(cbase)
@report_opt solve(prob)
```

And if I load OrdinaryDiffEqNonlinearSolve and drop `CheckInit` as a kwarg, then JET reports many more possible errors (321). If it matters, these solves get put inside a call like the one below, but the type instability already happens at the `solve` step. In this case I can annotate the return type to match the input type, but it would be nice for that to be inferrable:

```julia
function eval_prm_ccb(pr, c::Coeffs)
    # rtype = typeof(pr)
    newc = @set c.p1 = pr          # copy of `c` with p1 replaced (Accessors.jl)
    newprob = ODEProblem(newc)
    sol = solve(newprob, Rosenbrock23(); callback=ccallback)
    evalsol(sol)  # ::rtype
end
```
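The return-type annotation described above can be tied to the input type instead of being hard-coded, so it keeps working if `pr` is not a `Float64`. A minimal sketch, assuming the result should always have the same type as `pr`; `opaque_eval` and `RESULT` are hypothetical stand-ins for `evalsol(solve(...))`:

```julia
# Hypothetical stand-in for `evalsol(solve(...))`: the value passes through a
# Ref{Any}, so inference cannot see that the result has the same type as `pr`.
const RESULT = Ref{Any}(nothing)

function opaque_eval(pr)
    RESULT[] = 2 * pr
    return RESULT[]
end

# Tie the asserted return type to the input type T.
eval_annotated(pr::T) where {T<:Real} = opaque_eval(pr)::T

println(only(Base.return_types(opaque_eval, (Float64,))))     # Any
println(only(Base.return_types(eval_annotated, (Float64,))))  # Float64
println(eval_annotated(1.5))                                  # 3.0
```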
Exploring this MWE with Cthulhu, it looks like this may be related to #2613. Following the calls down (if I am exploring this correctly, which I might not be) eventually leads to a culprit here: OrdinaryDiffEq.jl/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl, lines 74 to 87 in 786689e
Describe the bug 🐞
DAE solutions with callbacks work, but they are apparently type unstable according to `@code_warntype`. ODE solutions with callbacks and DAE solutions without callbacks are both type stable. See also #2530, #2558.
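`@code_warntype` has to be read by eye; the `Test` standard library's `@inferred` turns the same question into a check that throws, which can be handy in a regression test. A minimal sketch with hypothetical functions; wrapping a call such as `eval_prm_ccb(pr)` from this thread in `@inferred` would be the analogous check:

```julia
using Test

# Illustrative functions, not from the issue.
stable(x::Float64) = 2x                  # always returns a Float64
unstable(x::Float64) = x > 0 ? 1 : 1.0   # returns an Int or a Float64

@inferred stable(1.0)                     # passes: inferred and actual types agree

# `@inferred` throws when the inferred type (here Union{Float64, Int64})
# differs from the type actually returned (Int64).
caught = try
    @inferred unstable(1.0)
    false
catch err
    err isa ErrorException
end
println(caught)  # true
```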
Minimal Reproducible Example 👇
Error & Stacktrace ⚠️
Without the callback, it is type stable:
With the callback, it is unstable:
Environment (please complete the following information):
`using Pkg; Pkg.status()`
`using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)`
Updated today. The SciML and OrdinaryDiffEq packages: