EnsembleProblem solving time increases by one order of magnitude when modelingtoolkitizing #949

Closed
gabrevaya opened this issue Mar 27, 2023 · 5 comments

Comments

@gabrevaya

I was surprised to find that changing the parameters and initial conditions with a basic for-loop was resulting in much better performance than using EnsembleProblem. Eventually, I realized that it was due to the usage of ModelingToolkit for "optimizing" the ODEProblem, which was actually causing the opposite effect! Here is an MWE:

using DifferentialEquations
using ModelingToolkit
using BenchmarkTools

function ensemble_solve(prob, solver, u0s, ps, t)
    u = similar(prob.u0, length(prob.u0), length(t), size(u0s, 2))
    for i in 1:size(u0s, 2)
        prob = remake(prob, p = ps[:, i], u0 = u0s[:,i])
        sol = solve(prob, solver; saveat = t)
        u[:, :, i] = Array(sol)
    end
    return u
end

function ensemble_solve2(prob, solver, u0s, ps, t)
    prob_func(prob, i, repeat) = remake(prob, u0=u0s[:, i], p = ps[:, i])
    output_func(sol, i) = (Array(sol), false)
    ens_prob = EnsembleProblem(prob, prob_func = prob_func, output_func = output_func)
    sol = solve(ens_prob, solver, EnsembleSerial(); trajectories = size(u0s, 2), saveat = t)
    Array(sol)
end

function f(du,u,p,t)
    du[1] = p[1] * u[1] - p[2] * u[1]*u[2]
    du[2] = -3 * u[2] + u[1]*u[2]
end

p = Float32[1.5,1.0,0.1,0.1]
u0 = Float32[1.0, 1.0]
t = 0f0:0.01f0:1f0
tspan = (t[1], t[end])
prob = ODEProblem(f,u0, tspan, p)
solver = Tsit5()

N = 1000
ps = repeat(p, inner = (1, N));
u0s = repeat(u0, inner = (1, N));

@btime ensemble_solve($prob, $solver, $u0s, $ps, $t);
# 14.017 ms (153002 allocations: 13.69 MiB)

@btime ensemble_solve2($prob, $solver, $u0s, $ps, $t);
# 14.934 ms (165998 allocations: 15.26 MiB)

sys = modelingtoolkitize(prob)
ODEFunc = ODEFunction{true}(sys, tgrad=true, jac = true, sparse = false, simplify = true)
prob = ODEProblem{true}(ODEFunc, u0, tspan, p)

@btime ensemble_solve($prob, $solver, $u0s, $ps, $t);
# 19.180 ms (145002 allocations: 15.34 MiB)

@btime ensemble_solve2($prob, $solver, $u0s, $ps, $t);
# 242.036 ms (1652998 allocations: 125.92 MiB)
(ensemble_modelingtoolkitize) pkg> st
Status `~/Documents/doctorado/issues/ensemble_modelingtoolkitize/Project.toml`
  [6e4b80f9] BenchmarkTools v1.3.2
  [0c46a032] DifferentialEquations v7.7.0
  [961ee093] ModelingToolkit v8.49.0
julia> versioninfo()
Julia Version 1.8.5
Commit 17cfb8e65ea (2023-01-08 06:45 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin21.5.0)
  CPU: 8 × Apple M1 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, apple-m1)
  Threads: 6 on 6 virtual cores
Environment:
  JULIA_EDITOR = code
@gabrevaya
Author

Here I include the profiling of the 4 cases:
profiling.zip

for_loop.html -> ensemble_solve
EnsembleProblem.html -> ensemble_solve2

for_loop_ModelingToolkit.html -> ensemble_solve after using modelingtoolkitize
EnsembleProblem_ModelingToolkit.html -> ensemble_solve2 after using modelingtoolkitize

It seems that what takes most of the time in EnsembleProblem_ModelingToolkit.html is a call to deepcopy, which is not present in the EnsembleProblem version when modelingtoolkitize was not used. For convenience, below I also share some screenshots of the parts where EnsembleProblem_ModelingToolkit.html and EnsembleProblem.html begin to differ:

From EnsembleProblem.html
[Screenshot: diff_EnsembleProblem]

From EnsembleProblem_ModelingToolkit.html
[Screenshot: diff_EnsembleProblem_ModelingToolkit]

@gabrevaya
Author

This is due to line 89 from SciMLBase/src/ensemble/basic_ensemble_solve.jl:

_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob

What is the reason behind prob.safetycopy? Is a deepcopy of prob actually needed in this case or is it a bug?

@ChrisRackauckas
Member

It's needed if there are any caches that are modified. That can happen in users' mutating code, for example when the right-hand side function writes into a cache stored in p. It cannot happen with the MTK-generated code, though, so we could add a flag to automatically disable the safetycopy when we generate the code.
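
To illustrate the hazard the safety copy guards against, here is a minimal sketch in plain Julia (no SciML packages). The `Params` struct, its `cache` field, and `rhs!` are hypothetical names made up for this example; they stand in for user code that mutates a buffer stored in the parameters.

```julia
# Hypothetical parameter container holding a scratch buffer (not from the issue).
struct Params
    k::Float64
    cache::Vector{Float64}
end

# A mutating right-hand side that writes into the buffer stored in p.
function rhs!(du, u, p::Params, t)
    p.cache .= p.k .* u   # side effect: the shared buffer is overwritten
    du .= .-p.cache
    return nothing
end

p = Params(2.0, zeros(2))
shared = p               # alias: what a trajectory would see without a safety copy
copied = deepcopy(p)     # isolated copy: what it would see with one

du = zeros(2)
rhs!(du, [1.0, 3.0], shared, 0.0)

println(p.cache)         # [2.0, 6.0]  (the original was mutated through the alias)
println(copied.cache)    # [0.0, 0.0]  (the deep copy is untouched)
```

If several trajectories share such a buffer, one trajectory can see or overwrite another's intermediate state, which is why the copy is on by default for user-written problems.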

@gabrevaya
Author

Great, thanks a lot!! :)

I've just checked: setting safetycopy=false when building the EnsembleProblem, as indicated in the docs, does bring the EnsembleProblem solve and the for-loop solve to similar performance:

@btime ensemble_solve($prob, $solver, $u0s, $ps, $t);
# 19.065 ms (145002 allocations: 15.34 MiB)

@btime ensemble_solve2($prob, $solver, $u0s, $ps, $t);
# 19.141 ms (147998 allocations: 16.22 MiB)
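
For reference, the change amounts to something like the sketch below. It reuses `f` and the structure of `ensemble_solve2` from the MWE; the function name `ensemble_solve_nocopy` and the small driver at the end are assumptions added for illustration, and the `safetycopy` keyword is used as documented for the SciMLBase version in this thread, so other versions may behave differently.

```julia
using DifferentialEquations

function f(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -3 * u[2] + u[1] * u[2]
end

# Same as ensemble_solve2, but without the per-trajectory deepcopy of the problem.
function ensemble_solve_nocopy(prob, solver, u0s, ps, t)
    prob_func(prob, i, repeat) = remake(prob, u0 = u0s[:, i], p = ps[:, i])
    output_func(sol, i) = (Array(sol), false)
    # safetycopy = false is only safe when nothing inside `prob` is mutated
    # across trajectories (no shared caches in f or p).
    ens_prob = EnsembleProblem(prob; prob_func = prob_func,
                               output_func = output_func, safetycopy = false)
    sol = solve(ens_prob, solver, EnsembleSerial();
                trajectories = size(u0s, 2), saveat = t)
    return Array(sol)
end

# Small driver with 4 trajectories (illustrative sizes only).
p = Float32[1.5, 1.0, 0.1, 0.1]
u0 = Float32[1.0, 1.0]
t = 0f0:0.01f0:1f0
prob = ODEProblem(f, u0, (t[1], t[end]), p)
ps = repeat(p, inner = (1, 4))
u0s = repeat(u0, inner = (1, 4))
out = ensemble_solve_nocopy(prob, Tsit5(), u0s, ps, t)
```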

However, oddly, both are still somewhat slower than the versions that do not use MTK:

@btime ensemble_solve($prob, $solver, $u0s, $ps, $t);
# 13.890 ms (153002 allocations: 13.69 MiB)

@btime ensemble_solve2($prob, $solver, $u0s, $ps, $t);
# 13.899 ms (155998 allocations: 14.57 MiB)

I know it's unrelated to the current issue (which is solved!), but do you have any idea why this could be happening?

@ChrisRackauckas
Member

> I know it's unrelated to the current issue (which is solved!), but do you have any idea why this could be happening?

That could happen if MTK ends up with less SIMD.
