Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ImplictiEulerExtrapolation parallel #872

Merged
merged 2 commits into from
Aug 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,15 @@ struct ImplicitEulerExtrapolation{CS,AD,F,F2} <: OrdinaryDiffEqImplicitExtrapola
max_order::Int
min_order::Int
init_order::Int
threading::Bool
end


ImplicitEulerExtrapolation(;chunk_size=0,autodiff=true,diff_type=Val{:forward},
linsolve=DEFAULT_LINSOLVE,
max_order=10,min_order=1,init_order=5) =
max_order=10,min_order=1,init_order=5,threading=true) =
ImplicitEulerExtrapolation{chunk_size,autodiff,
typeof(linsolve),typeof(diff_type)}(linsolve,max_order,min_order,init_order)
typeof(linsolve),typeof(diff_type)}(linsolve,max_order,min_order,init_order,threading)

struct ExtrapolationMidpointDeuflhard <: OrdinaryDiffEqExtrapolationVarOrderVarStepAlgorithm
n_min::Int # Minimal extrapolation order
Expand Down
42 changes: 35 additions & 7 deletions src/caches/extrapolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,12 @@ end
@cache mutable struct ImplicitEulerExtrapolationCache{uType,rateType,arrayType,dtType,JType,WType,F,JCType,GCType,uNoUnitsType,TFType,UFType} <: OrdinaryDiffEqMutableCache
uprev::uType
u_tmp::uType
saurabhkgp21 marked this conversation as resolved.
Show resolved Hide resolved
u_tmps::Array{uType,1}
utilde::uType
tmp::uType
atmp::uNoUnitsType
k_tmp::rateType
saurabhkgp21 marked this conversation as resolved.
Show resolved Hide resolved
k_tmps::Array{rateType,1}
dtpropose::dtType
T::arrayType
cur_order::Int
Expand All @@ -89,7 +91,8 @@ end
tf::TFType
uf::UFType
linsolve_tmp::rateType
saurabhkgp21 marked this conversation as resolved.
Show resolved Hide resolved
linsolve::F
linsolve_tmps::Array{rateType,1}
linsolve::Array{F,1}
jac_config::JCType
grad_config::GCType
end
Expand Down Expand Up @@ -121,10 +124,21 @@ end

function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
u_tmp = similar(u)
u_tmps = Array{typeof(u_tmp),1}(undef, Threads.nthreads())

for i=1:Threads.nthreads()
u_tmps[i] = zero(u_tmp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one improvement that here is just to make the first ones in the arrays equal, i.e. u_tmps[1] = u_tmp. Then the algorithm would be fine. Right now you have an extra thing for each one, so you have an extra Jacobian, and that's really bad memory-wise.

end

utilde = similar(u)
tmp = similar(u)
k = zero(rate_prototype)
k_tmp = zero(rate_prototype)
k_tmps = Array{typeof(k_tmp),1}(undef, Threads.nthreads())

for i=1:Threads.nthreads()
k_tmps[i] = zero(rate_prototype)
end

cur_order = max(alg.init_order, alg.min_order)
dtpropose = zero(dt)
T = Array{typeof(u),2}(undef, alg.max_order, alg.max_order)
Expand All @@ -143,22 +157,36 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni
du2 = zero(rate_prototype)

if DiffEqBase.has_jac(f) && !DiffEqBase.has_Wfact(f) && f.jac_prototype !== nothing
W = WOperator(f, dt, true)
W_el = WOperator(f, dt, true)
J = nothing # is J = W.J better?
else
J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype?
W = similar(J)
W_el = similar(J)
end
W = Array{typeof(W_el),1}(undef, Threads.nthreads())
for i=1:Threads.nthreads()
W[i] = zero(W_el)
end
tf = DiffEqDiffTools.TimeGradientWrapper(f,uprev,p)
uf = DiffEqDiffTools.UJacobianWrapper(f,t,p)
linsolve_tmp = zero(rate_prototype)
linsolve = alg.linsolve(Val{:init},uf,u)
linsolve_tmps = Array{typeof(linsolve_tmp),1}(undef, Threads.nthreads())

for i=1:Threads.nthreads()
linsolve_tmps[i] = zero(rate_prototype)
end

linsolve_el = alg.linsolve(Val{:init},uf,u)
linsolve = Array{typeof(linsolve_el),1}(undef, Threads.nthreads())
for i=1:Threads.nthreads()
linsolve[i] = alg.linsolve(Val{:init},uf,u)
end
grad_config = build_grad_config(alg,f,tf,du1,t)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,du1,du2)


ImplicitEulerExtrapolationCache(uprev,u_tmp,utilde,tmp,atmp,k_tmp,dtpropose,T,cur_order,work,A,step_no,
du1,du2,J,W,tf,uf,linsolve_tmp,linsolve,jac_config,grad_config)
ImplicitEulerExtrapolationCache(uprev,u_tmp,u_tmps,utilde,tmp,atmp,k_tmp,k_tmps,dtpropose,T,cur_order,work,A,step_no,
du1,du2,J,W,tf,uf,linsolve_tmp,linsolve_tmps,linsolve,jac_config,grad_config)
end


Expand Down
51 changes: 51 additions & 0 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,57 @@ function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_
return nothing
end

function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_step, W_index::Int, W_transform=false)
@unpack t,dt,uprev,u,f,p = integrator
@unpack J,W = cache
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is just silly. Why not just pass in W into these routines? It would fix these problems.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kanav99 @huanglangwen this is something we should follow up on. It's because our original calc_W! was never intended for thread-safety, and so to make it thread-safe there is this code duplication. We should assume this case can happen in the next refactor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it mean calc_W! should have an additional parameter of readonly W ?

Copy link
Member

@ChrisRackauckas ChrisRackauckas Aug 16, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should have J and W I think, where it writes into W for a given J, instead of pulling them from the cache (since there may be more than 1)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it won't work on OOP.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For OOP it will need to return W.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should have a pseudo-nlsolver kind of thing for algorithms which don't have an actual nlsolver like Rosenbrocks, the struct should have aliases to the J and W and similar stuff, should be <: NLSolver so that we have a common implementation of the functions. This way it won't make us change the original derivative utilities everytime we make a new algorithm just for an overload.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that algorithms with nlsolver struct are already well handled for parallel applications.

alg = unwrap_alg(integrator, true)
mass_matrix = integrator.f.mass_matrix
is_compos = integrator.alg isa CompositeAlgorithm
isnewton = alg isa NewtonAlgorithm

if W_transform && DiffEqBase.has_Wfact_t(f)
f.Wfact_t(W[W_index], u, p, dtgamma, t)
is_compos && (integrator.eigen_est = opnorm(LowerTriangular(W[W_index]), Inf) + inv(dtgamma)) # TODO: better estimate
return nothing
elseif !W_transform && DiffEqBase.has_Wfact(f)
f.Wfact(W[W_index], u, p, dtgamma, t)
if is_compos
opn = opnorm(LowerTriangular(W[W_index]), Inf)
integrator.eigen_est = (opn + one(opn)) / dtgamma # TODO: better estimate
end
return nothing
end

# fast pass
# we only want to factorize the linear operator once
new_jac = true
new_W = true
if (f isa ODEFunction && islinear(f.f)) || (integrator.alg isa SplitAlgorithms && f isa SplitFunction && islinear(f.f1.f))
new_jac = false
@goto J2W # Jump to W calculation directly, because we already have J
end

# check if we need to update J or W
W_dt = isnewton ? cache.nlsolver.cache.W_dt : dt # TODO: RosW
new_jac = isnewton ? do_newJ(integrator, alg, cache, repeat_step) : true
new_W = isnewton ? do_newW(integrator, cache.nlsolver, new_jac, W_dt) : true

# calculate W
if DiffEqBase.has_jac(f) && f.jac_prototype !== nothing && !ArrayInterface.isstructured(f.jac_prototype)
isnewton || DiffEqBase.update_coefficients!(W[W_index],uprev,p,t) # we will call `update_coefficients!` in NLNewton
@label J2W
W[W_index].transform = W_transform; set_gamma!(W[W_index], dtgamma)
else # concrete W using jacobian from `calc_J!`
new_jac && calc_J!(integrator, cache, is_compos)
new_W && jacobian2W!(W[W_index], mass_matrix, dtgamma, J, W_transform)
end
if isnewton
set_new_W!(cache.nlsolver, new_W) && DiffEqBase.set_W_dt!(cache.nlsolver, dt)
end
new_W && (integrator.destats.nw += 1)
return nothing
end

function calc_W!(nlsolver, integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_step, W_transform=false)
@unpack t,dt,uprev,u,f,p = integrator
@unpack J,W = nlsolver.cache
Expand Down
84 changes: 53 additions & 31 deletions src/perform_step/extrapolation_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,27 +249,39 @@ function perform_step!(integrator,cache::ImplicitEulerExtrapolationCache,repeat_
@unpack t,dt,uprev,u,f,p = integrator
@unpack u_tmp,k_tmp,T,utilde,atmp,dtpropose,cur_order,A = cache
@unpack J,W,uf,tf,linsolve_tmp,jac_config = cache
@unpack u_tmps, k_tmps, linsolve_tmps = cache

max_order = min(size(T)[1],cur_order+1)

for i in 1:max_order
dt_temp = dt/(2^(i-1)) # Romberg sequence
calc_W!(integrator, cache, dt_temp, repeat_step)
k_tmp = copy(integrator.fsalfirst)
u_tmp = copy(uprev)
for j in 1:2^(i-1)
linsolve_tmp = dt_temp*k_tmp
cache.linsolve(vec(k_tmp), W, vec(linsolve_tmp), !repeat_step)
@.. k_tmp = -k_tmp
@.. u_tmp = u_tmp + k_tmp
f(k_tmp, u_tmp,p,t+j*dt_temp)
end
let max_order=max_order, uprev=uprev, dt=dt, p=p, t=t, T=T, W=W,
integrator=integrator, cache=cache, repeat_step = repeat_step,
k_tmps=k_tmps, u_tmps=u_tmps
Threads.@threads for i in 1:2
startIndex = (i == 1) ? 1 : max_order
endIndex = (i == 1) ? max_order - 1 : max_order
for index in startIndex:endIndex
dt_temp = dt/(2^(index-1)) # Romberg sequence
calc_W!(integrator, cache, dt_temp, repeat_step, Threads.threadid())
k_tmps[Threads.threadid()] = copy(integrator.fsalfirst)
u_tmps[Threads.threadid()] = copy(uprev)
for j in 1:2^(index-1)
@.. linsolve_tmps[Threads.threadid()] = dt_temp*k_tmps[Threads.threadid()]
cache.linsolve[Threads.threadid()](vec(k_tmps[Threads.threadid()]), W[Threads.threadid()], vec(linsolve_tmps[Threads.threadid()]), !repeat_step)
@.. k_tmps[Threads.threadid()] = -k_tmps[Threads.threadid()]
@.. u_tmps[Threads.threadid()] = u_tmps[Threads.threadid()] + k_tmps[Threads.threadid()]
f(k_tmps[Threads.threadid()], u_tmps[Threads.threadid()],p,t+j*dt_temp)
end

@.. T[i,1] = u_tmp
for j in 2:i
@.. T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
@.. T[index,1] = u_tmps[Threads.threadid()]
end
end
for i in 2:max_order
for j in 2:i
@.. T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
end
end
end

integrator.dt = dt

if integrator.opts.adaptive
Expand Down Expand Up @@ -332,23 +344,33 @@ function perform_step!(integrator,cache::ImplicitEulerExtrapolationConstantCache

max_order = min(size(T)[1], cur_order+1)

for i in 1:max_order
dt_temp = dt/(2^(i-1)) # Romberg sequence
W = calc_W!(integrator, cache, dt_temp, repeat_step)
k_copy = integrator.fsalfirst
u_tmp = uprev
for j in 1:2^(i-1)
k = _reshape(W\-_vec(dt_temp*k_copy), axes(uprev))
integrator.destats.nsolve += 1
u_tmp = u_tmp + k
k_copy = f(u_tmp, p, t+j*dt_temp)
let max_order=max_order, dt=dt, integrator=integrator, cache=cache, repeat_step=repeat_step,
uprev=uprev, T=T
Threads.@threads for i in 1:2
startIndex = (i==1) ? 1 : max_order
endIndex = (i==1) ? max_order-1 : max_order
for index in startIndex:endIndex
dt_temp = dt/(2^(index-1)) # Romberg sequence
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why Romberg?

W = calc_W!(integrator, cache, dt_temp, repeat_step)
k_copy = integrator.fsalfirst
u_tmp = uprev
for j in 1:2^(index-1)
k = _reshape(W\-_vec(dt_temp*k_copy), axes(uprev))
integrator.destats.nsolve += 1
u_tmp = u_tmp + k
k_copy = f(u_tmp, p, t+j*dt_temp)
end
T[index,1] = u_tmp
end
end
T[i,1] = u_tmp
# Richardson Extrapolation
for j in 2:i
T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)

for i=2:max_order
for j=2:i
T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
end
end
end

integrator.destats.nf += 2^(max_order) - 1
integrator.dt = dt

Expand Down Expand Up @@ -391,9 +413,9 @@ function perform_step!(integrator,cache::ImplicitEulerExtrapolationConstantCache

# Use extrapolated value of u
integrator.u = T[cache.cur_order, cache.cur_order]
k = f(integrator.u, p, t+dt)
k_temp = f(integrator.u, p, t+dt)
integrator.destats.nf += 1
integrator.fsallast = k
integrator.fsallast = k_temp
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
end
Expand Down