Skip to content

Commit

Permalink
Merge pull request #872 from saurabhkgp21/implicitEulerParaller
Browse files Browse the repository at this point in the history
ImplictiEulerExtrapolation parallel
  • Loading branch information
ChrisRackauckas authored Aug 16, 2019
2 parents b4de315 + 62195d2 commit 280bddd
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 40 deletions.
5 changes: 3 additions & 2 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,15 @@ struct ImplicitEulerExtrapolation{CS,AD,F,F2} <: OrdinaryDiffEqImplicitExtrapola
max_order::Int
min_order::Int
init_order::Int
threading::Bool
end


ImplicitEulerExtrapolation(;chunk_size=0,autodiff=true,diff_type=Val{:forward},
linsolve=DEFAULT_LINSOLVE,
max_order=10,min_order=1,init_order=5) =
max_order=10,min_order=1,init_order=5,threading=true) =
ImplicitEulerExtrapolation{chunk_size,autodiff,
typeof(linsolve),typeof(diff_type)}(linsolve,max_order,min_order,init_order)
typeof(linsolve),typeof(diff_type)}(linsolve,max_order,min_order,init_order,threading)

struct ExtrapolationMidpointDeuflhard <: OrdinaryDiffEqExtrapolationVarOrderVarStepAlgorithm
n_min::Int # Minimal extrapolation order
Expand Down
42 changes: 35 additions & 7 deletions src/caches/extrapolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,12 @@ end
@cache mutable struct ImplicitEulerExtrapolationCache{uType,rateType,arrayType,dtType,JType,WType,F,JCType,GCType,uNoUnitsType,TFType,UFType} <: OrdinaryDiffEqMutableCache
uprev::uType
u_tmp::uType
u_tmps::Array{uType,1}
utilde::uType
tmp::uType
atmp::uNoUnitsType
k_tmp::rateType
k_tmps::Array{rateType,1}
dtpropose::dtType
T::arrayType
cur_order::Int
Expand All @@ -89,7 +91,8 @@ end
tf::TFType
uf::UFType
linsolve_tmp::rateType
linsolve::F
linsolve_tmps::Array{rateType,1}
linsolve::Array{F,1}
jac_config::JCType
grad_config::GCType
end
Expand Down Expand Up @@ -121,10 +124,21 @@ end

function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
u_tmp = similar(u)
u_tmps = Array{typeof(u_tmp),1}(undef, Threads.nthreads())

for i=1:Threads.nthreads()
u_tmps[i] = zero(u_tmp)
end

utilde = similar(u)
tmp = similar(u)
k = zero(rate_prototype)
k_tmp = zero(rate_prototype)
k_tmps = Array{typeof(k_tmp),1}(undef, Threads.nthreads())

for i=1:Threads.nthreads()
k_tmps[i] = zero(rate_prototype)
end

cur_order = max(alg.init_order, alg.min_order)
dtpropose = zero(dt)
T = Array{typeof(u),2}(undef, alg.max_order, alg.max_order)
Expand All @@ -143,22 +157,36 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni
du2 = zero(rate_prototype)

if DiffEqBase.has_jac(f) && !DiffEqBase.has_Wfact(f) && f.jac_prototype !== nothing
W = WOperator(f, dt, true)
W_el = WOperator(f, dt, true)
J = nothing # is J = W.J better?
else
J = false .* vec(rate_prototype) .* vec(rate_prototype)' # uEltype?
W = similar(J)
W_el = similar(J)
end
W = Array{typeof(W_el),1}(undef, Threads.nthreads())
for i=1:Threads.nthreads()
W[i] = zero(W_el)
end
tf = DiffEqDiffTools.TimeGradientWrapper(f,uprev,p)
uf = DiffEqDiffTools.UJacobianWrapper(f,t,p)
linsolve_tmp = zero(rate_prototype)
linsolve = alg.linsolve(Val{:init},uf,u)
linsolve_tmps = Array{typeof(linsolve_tmp),1}(undef, Threads.nthreads())

for i=1:Threads.nthreads()
linsolve_tmps[i] = zero(rate_prototype)
end

linsolve_el = alg.linsolve(Val{:init},uf,u)
linsolve = Array{typeof(linsolve_el),1}(undef, Threads.nthreads())
for i=1:Threads.nthreads()
linsolve[i] = alg.linsolve(Val{:init},uf,u)
end
grad_config = build_grad_config(alg,f,tf,du1,t)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,du1,du2)


ImplicitEulerExtrapolationCache(uprev,u_tmp,utilde,tmp,atmp,k_tmp,dtpropose,T,cur_order,work,A,step_no,
du1,du2,J,W,tf,uf,linsolve_tmp,linsolve,jac_config,grad_config)
ImplicitEulerExtrapolationCache(uprev,u_tmp,u_tmps,utilde,tmp,atmp,k_tmp,k_tmps,dtpropose,T,cur_order,work,A,step_no,
du1,du2,J,W,tf,uf,linsolve_tmp,linsolve_tmps,linsolve,jac_config,grad_config)
end


Expand Down
51 changes: 51 additions & 0 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,57 @@ function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_
return nothing
end

function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_step, W_index::Int, W_transform=false)
@unpack t,dt,uprev,u,f,p = integrator
@unpack J,W = cache
alg = unwrap_alg(integrator, true)
mass_matrix = integrator.f.mass_matrix
is_compos = integrator.alg isa CompositeAlgorithm
isnewton = alg isa NewtonAlgorithm

if W_transform && DiffEqBase.has_Wfact_t(f)
f.Wfact_t(W[W_index], u, p, dtgamma, t)
is_compos && (integrator.eigen_est = opnorm(LowerTriangular(W[W_index]), Inf) + inv(dtgamma)) # TODO: better estimate
return nothing
elseif !W_transform && DiffEqBase.has_Wfact(f)
f.Wfact(W[W_index], u, p, dtgamma, t)
if is_compos
opn = opnorm(LowerTriangular(W[W_index]), Inf)
integrator.eigen_est = (opn + one(opn)) / dtgamma # TODO: better estimate
end
return nothing
end

# fast pass
# we only want to factorize the linear operator once
new_jac = true
new_W = true
if (f isa ODEFunction && islinear(f.f)) || (integrator.alg isa SplitAlgorithms && f isa SplitFunction && islinear(f.f1.f))
new_jac = false
@goto J2W # Jump to W calculation directly, because we already have J
end

# check if we need to update J or W
W_dt = isnewton ? cache.nlsolver.cache.W_dt : dt # TODO: RosW
new_jac = isnewton ? do_newJ(integrator, alg, cache, repeat_step) : true
new_W = isnewton ? do_newW(integrator, cache.nlsolver, new_jac, W_dt) : true

# calculate W
if DiffEqBase.has_jac(f) && f.jac_prototype !== nothing && !ArrayInterface.isstructured(f.jac_prototype)
isnewton || DiffEqBase.update_coefficients!(W[W_index],uprev,p,t) # we will call `update_coefficients!` in NLNewton
@label J2W
W[W_index].transform = W_transform; set_gamma!(W[W_index], dtgamma)
else # concrete W using jacobian from `calc_J!`
new_jac && calc_J!(integrator, cache, is_compos)
new_W && jacobian2W!(W[W_index], mass_matrix, dtgamma, J, W_transform)
end
if isnewton
set_new_W!(cache.nlsolver, new_W) && DiffEqBase.set_W_dt!(cache.nlsolver, dt)
end
new_W && (integrator.destats.nw += 1)
return nothing
end

function calc_W!(nlsolver, integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_step, W_transform=false)
@unpack t,dt,uprev,u,f,p = integrator
@unpack J,W = nlsolver.cache
Expand Down
84 changes: 53 additions & 31 deletions src/perform_step/extrapolation_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,27 +249,39 @@ function perform_step!(integrator,cache::ImplicitEulerExtrapolationCache,repeat_
@unpack t,dt,uprev,u,f,p = integrator
@unpack u_tmp,k_tmp,T,utilde,atmp,dtpropose,cur_order,A = cache
@unpack J,W,uf,tf,linsolve_tmp,jac_config = cache
@unpack u_tmps, k_tmps, linsolve_tmps = cache

max_order = min(size(T)[1],cur_order+1)

for i in 1:max_order
dt_temp = dt/(2^(i-1)) # Romberg sequence
calc_W!(integrator, cache, dt_temp, repeat_step)
k_tmp = copy(integrator.fsalfirst)
u_tmp = copy(uprev)
for j in 1:2^(i-1)
linsolve_tmp = dt_temp*k_tmp
cache.linsolve(vec(k_tmp), W, vec(linsolve_tmp), !repeat_step)
@.. k_tmp = -k_tmp
@.. u_tmp = u_tmp + k_tmp
f(k_tmp, u_tmp,p,t+j*dt_temp)
end
let max_order=max_order, uprev=uprev, dt=dt, p=p, t=t, T=T, W=W,
integrator=integrator, cache=cache, repeat_step = repeat_step,
k_tmps=k_tmps, u_tmps=u_tmps
Threads.@threads for i in 1:2
startIndex = (i == 1) ? 1 : max_order
endIndex = (i == 1) ? max_order - 1 : max_order
for index in startIndex:endIndex
dt_temp = dt/(2^(index-1)) # Romberg sequence
calc_W!(integrator, cache, dt_temp, repeat_step, Threads.threadid())
k_tmps[Threads.threadid()] = copy(integrator.fsalfirst)
u_tmps[Threads.threadid()] = copy(uprev)
for j in 1:2^(index-1)
@.. linsolve_tmps[Threads.threadid()] = dt_temp*k_tmps[Threads.threadid()]
cache.linsolve[Threads.threadid()](vec(k_tmps[Threads.threadid()]), W[Threads.threadid()], vec(linsolve_tmps[Threads.threadid()]), !repeat_step)
@.. k_tmps[Threads.threadid()] = -k_tmps[Threads.threadid()]
@.. u_tmps[Threads.threadid()] = u_tmps[Threads.threadid()] + k_tmps[Threads.threadid()]
f(k_tmps[Threads.threadid()], u_tmps[Threads.threadid()],p,t+j*dt_temp)
end

@.. T[i,1] = u_tmp
for j in 2:i
@.. T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
@.. T[index,1] = u_tmps[Threads.threadid()]
end
end
for i in 2:max_order
for j in 2:i
@.. T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
end
end
end

integrator.dt = dt

if integrator.opts.adaptive
Expand Down Expand Up @@ -332,23 +344,33 @@ function perform_step!(integrator,cache::ImplicitEulerExtrapolationConstantCache

max_order = min(size(T)[1], cur_order+1)

for i in 1:max_order
dt_temp = dt/(2^(i-1)) # Romberg sequence
W = calc_W!(integrator, cache, dt_temp, repeat_step)
k_copy = integrator.fsalfirst
u_tmp = uprev
for j in 1:2^(i-1)
k = _reshape(W\-_vec(dt_temp*k_copy), axes(uprev))
integrator.destats.nsolve += 1
u_tmp = u_tmp + k
k_copy = f(u_tmp, p, t+j*dt_temp)
let max_order=max_order, dt=dt, integrator=integrator, cache=cache, repeat_step=repeat_step,
uprev=uprev, T=T
Threads.@threads for i in 1:2
startIndex = (i==1) ? 1 : max_order
endIndex = (i==1) ? max_order-1 : max_order
for index in startIndex:endIndex
dt_temp = dt/(2^(index-1)) # Romberg sequence
W = calc_W!(integrator, cache, dt_temp, repeat_step)
k_copy = integrator.fsalfirst
u_tmp = uprev
for j in 1:2^(index-1)
k = _reshape(W\-_vec(dt_temp*k_copy), axes(uprev))
integrator.destats.nsolve += 1
u_tmp = u_tmp + k
k_copy = f(u_tmp, p, t+j*dt_temp)
end
T[index,1] = u_tmp
end
end
T[i,1] = u_tmp
# Richardson Extrapolation
for j in 2:i
T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)

for i=2:max_order
for j=2:i
T[i,j] = ((2^(j-1))*T[i,j-1] - T[i-1,j-1])/((2^(j-1)) - 1)
end
end
end

integrator.destats.nf += 2^(max_order) - 1
integrator.dt = dt

Expand Down Expand Up @@ -391,9 +413,9 @@ function perform_step!(integrator,cache::ImplicitEulerExtrapolationConstantCache

# Use extrapolated value of u
integrator.u = T[cache.cur_order, cache.cur_order]
k = f(integrator.u, p, t+dt)
k_temp = f(integrator.u, p, t+dt)
integrator.destats.nf += 1
integrator.fsallast = k
integrator.fsallast = k_temp
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
end
Expand Down

0 comments on commit 280bddd

Please sign in to comment.