Skip to content

Commit

Permalink
attempt to precompile linsolve (#698)
Browse files Browse the repository at this point in the history
* attempt to precompile linsolve

```julia
using OrdinaryDiffEq, SnoopCompile

function lorenz(du,u,p,t)
 du[1] = 10.0(u[2]-u[1])
 du[2] = u[1]*(28.0-u[3]) - u[2]
 du[3] = u[1]*u[2] - (8/3)*u[3]
end

u0 = [1.0;0.0;0.0]
tspan = (0.0,100.0)
prob = ODEProblem(lorenz,u0,tspan)
alg = Rodas5()
tinf = @snoopi_deep solve(prob,alg)
```

* Update src/DiffEqBase.jl

Co-authored-by: Tim Holy <tim.holy@gmail.com>

* make precompilation a bit safer

* copyto

* fix precompile

* fix current state

* remove invalidation of mapreduce_empty

* Force Matrix{Float64}

* hoist chunk size choice earlier via `prepare_alg`

* bring invalidation back so tests don't fail

* fix copyto! -> fill!

Co-authored-by: Tim Holy <tim.holy@gmail.com>
  • Loading branch information
ChrisRackauckas and timholy authored Aug 12, 2021
1 parent 9d3aa5b commit 0227b8d
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 7 deletions.
2 changes: 2 additions & 0 deletions src/DiffEqBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ include("init.jl")
include("forwarddiff.jl")
include("chainrules.jl")

include("precompile.jl")

"""
$(TYPEDEF)
"""
Expand Down
28 changes: 25 additions & 3 deletions src/linear_nonlinear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ DefaultLinSolve() = DefaultLinSolve(nothing, nothing, nothing)
end

function isopenblas()
@static if VERSION < v"1.7"
@static if VERSION < v"1.7beta"
blas = BLAS.vendor()
blas == :openblas64 || blas == :openblas
else
Expand Down Expand Up @@ -131,7 +131,7 @@ function (p::DefaultLinSolve)(x,A,b,update_matrix=false;reltol=nothing, kwargs..
end

if A isa Union{Matrix,SymTridiagonal,Tridiagonal,Symmetric,Hermitian,ForwardSensitivityJacobian} # No 2-arg form for SparseArrays!
x .= b
copyto!(x,b)
ldiv!(p.A,x)
# Missing a little bit of efficiency in a rare case
#elseif A isa DiffEqArrayOperator
Expand All @@ -144,7 +144,7 @@ function (p::DefaultLinSolve)(x,A,b,update_matrix=false;reltol=nothing, kwargs..
reltol = checkreltol(reltol)
p.iterable = IterativeSolvers.gmres_iterable!(x,A,b;initially_zero=true,restart=5,maxiter=5,abstol=1e-16,reltol=reltol,kwargs...)
end
x .= false
fill!(x,false)
iter = p.iterable
purge_history!(iter, x, b)

Expand All @@ -168,6 +168,28 @@ end
Base.resize!(p::DefaultLinSolve,i) = p.A = nothing
const DEFAULT_LINSOLVE = DefaultLinSolve()

## A much simpler LU for when we know it's Array

mutable struct LUFactorize
A::LU{Float64,Matrix{Float64}}
openblas::Bool
end
LUFactorize() = LUFactorize(lu(rand(1,1)),isopenblas())
function (p::LUFactorize)(x::Vector{Float64},A::Matrix{Float64},b::Vector{Float64},update_matrix::Bool=false;kwargs...)
if update_matrix
if ArrayInterface.can_setindex(x) && (size(A,1) <= 100 || (p.openblas && size(A,1) <= 500))
p.A = RecursiveFactorization.lu!(A)
else
p.A = lu!(A)
end
end
ldiv!(x,p.A,b)
end
function (p::LUFactorize)(::Type{Val{:init}},f,u0_prototype)
LUFactorize(lu(rand(eltype(u0_prototype),1,1)),p.openblas)
end
Base.resize!(p::LUFactorize,i) = p.A = nothing

### Default GMRES

# Easily change to GMRES
Expand Down
19 changes: 19 additions & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
let
while true
_testf(du,u,p,t) = copyto!(du,u)
b = rand(1); x = rand(1)
_linsolve = DEFAULT_LINSOLVE(Val{:init},ODEFunction(_testf),b)
A = rand(1,1)
_linsolve(x,A,b,true)
_linsolve(x,A,b,false)
_linsolve = LUFactorize()(Val{:init},ODEFunction(_testf),b)
_linsolve(x,A,b,true)
_linsolve(x,A,b,false)
Pl = ScaleVector([1.0],true)
Pr = ScaleVector([1.0],false)
reltol = 1.0
_linsolve(x,A,b,true;reltol=reltol,Pl=Pl,Pr=Pr)
_linsolve(x,A,b,false;reltol=reltol,Pl=Pl,Pr=Pr)
break
end
end
11 changes: 7 additions & 4 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,14 @@ function solve_up(prob::DEProblem,sensealg,u0,p,args...;kwargs...)

if haskey(kwargs,:alg) && (isempty(args) || args[1] === nothing)
alg = kwargs[:alg]
_prob = get_concrete_problem(prob,isadaptive(alg);u0=u0,p=p,kwargs...)
solve_call(_prob,alg,args...;kwargs...)
_alg = prepare_alg(alg,u0,p,prob)
_prob = get_concrete_problem(prob,isadaptive(_alg);u0=u0,p=p,kwargs...)
solve_call(_prob,_alg,args...;kwargs...)
elseif !isempty(args) && typeof(args[1]) <: DEAlgorithm
alg = args[1]
_prob = get_concrete_problem(prob,isadaptive(alg);u0=u0,p=p,kwargs...)
solve_call(_prob,args...;kwargs...)
_alg = prepare_alg(alg,u0,p,prob)
_prob = get_concrete_problem(prob,isadaptive(_alg);u0=u0,p=p,kwargs...)
solve_call(_prob,_alg,Base.tail(args)...;kwargs...)
elseif isempty(args) # Default algorithm handling
_prob = get_concrete_problem(prob,!(typeof(prob)<:DiscreteProblem);u0=u0,p=p,kwargs...)
solve_call(_prob,args...;kwargs...)
Expand Down Expand Up @@ -203,6 +205,7 @@ function promote_f(f,u0)
end

promote_f(f::SplitFunction,u0) = typeof(f.cache) === typeof(u0) && isinplace(f) ? f : remake(f,cache=zero(u0))
prepare_alg(alg,u0,p,f) = alg

function get_concrete_tspan(prob, isadapt, kwargs, p)
if prob.tspan isa Function
Expand Down

0 comments on commit 0227b8d

Please sign in to comment.