Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: change to new @muladd #89

Merged
merged 2 commits into from
Jul 31, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/initdt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function ode_determine_initdt{tType,uType}(u0,t::tType,tdir,dtmax,abstol,reltol,internalnorm,prob::AbstractODEProblem{uType,tType,true},order)
@muladd function ode_determine_initdt{tType,uType}(u0,t::tType,tdir,dtmax,abstol,reltol,internalnorm,prob::AbstractODEProblem{uType,tType,true},order)
f = prob.f
f₀ = zeros(u0./t); f₁ = zeros(u0./t); u₁ = zeros(u0); sk = zeros(u0);
# Hack to make a generic u0 with no units, https://github.com/JuliaLang/julia/issues/22216
Expand Down Expand Up @@ -35,14 +35,14 @@ function ode_determine_initdt{tType,uType}(u0,t::tType,tdir,dtmax,abstol,reltol,

#@. u₁ = @muladd u0 + tdir*dt₀*f₀
@tight_loop_macros for i in uidx
@inbounds u₁[i] = u0[i] + tdir*dt₀*f₀[i]
@inbounds u₁[i] = u0[i] + (tdir*dt₀)*f₀[i]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably best to just hoist that out of the loop. Does @muladd work on this now? I think it had a problem with this before.

Copy link
Member Author

@devmotion devmotion Jul 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's probably best. @muladd works on this now; it also handles products with more than two factors (I should write a test for that...) - but since this is not a dot call, all factors should be scalars. It does not fail for vectors of vectors since tdir*dt₀ is a scalar, but it also does not lead to a call of @llvm.fmuladd; see SciML/DiffEqBase.jl#57 (comment) and:

julia> @code_lowered muladd(1., 1., 1.)
CodeInfo(:(begin 
        nothing
        return (Base.muladd_float)(x, y, z)
    end))

julia> @code_lowered muladd(1., [1.], [1.])
CodeInfo(:(begin 
        nothing
        return x * y + z
    end))

julia> @code_lowered muladd([1.], [1.], [1.])
CodeInfo(:(begin 
        nothing
        return x * y + z
    end))

In the current implementation, @muladd a*b*c+d is transformed to muladd(a, b*c, d) - so the first factor always ends up as the first argument to muladd, and the product of all other factors forms the second argument. I don't know whether this is better/more natural than muladd(a*b, c, d).

In this case @muladd produces

julia> macroexpand(:(@muladd @tight_loop_macros for i in uidx
           @inbounds u₁[i] = u0[i] + (tdir*dt₀)*f₀[i]
       end))
:(for i = uidx # REPL[6], line 2:
        begin 
            $(Expr(:inbounds, true))
            u₁[i] = (muladd)(tdir * dt₀, f₀[i], u0[i])
            $(Expr(:inbounds, :pop))
        end
    end)

and without brackets:

julia> macroexpand(:(@muladd @tight_loop_macros for i in uidx
           @inbounds u₁[i] = u0[i] + tdir*dt₀*f₀[i]
       end))
:(for i = uidx # REPL[7], line 2:
        begin 
            $(Expr(:inbounds, true))
            u₁[i] = (muladd)(tdir, dt₀ * f₀[i], u0[i])
            $(Expr(:inbounds, :pop))
        end
    end)

But of course the best option is to move the multiplication of the first two factors completely out of the loop.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, good to see that works though!

end

f(t+tdir*dt₀,u₁,f₁)

#tmp = (f₁.-f₀)./(abstol+abs.(u0).*reltol)*tType(1)
@tight_loop_macros for (i,atol,rtol) in zip(uidx,Iterators.cycle(abstol),Iterators.cycle(reltol))
tmp[i] = (f₁[i]-f₀[i])./(atol+abs(u0[i])*rtol)*tType(1)
tmp[i] = (f₁[i]-f₀[i])/(atol+abs(u0[i])*rtol)*tType(1)
end

d₂ = internalnorm(tmp)/dt₀
Expand All @@ -56,9 +56,9 @@ function ode_determine_initdt{tType,uType}(u0,t::tType,tdir,dtmax,abstol,reltol,
dt = tdir*min(100dt₀,dt₁,tdir*dtmax)
end

function ode_determine_initdt{uType,tType}(u0::uType,t,tdir,dtmax,abstol,reltol,internalnorm,prob::AbstractODEProblem{uType,tType,false},order)
@muladd function ode_determine_initdt{uType,tType}(u0::uType,t,tdir,dtmax,abstol,reltol,internalnorm,prob::AbstractODEProblem{uType,tType,false},order)
f = prob.f
sk = abstol+abs.(u0).*reltol
sk = @. abstol+abs(u0)*reltol
d₀ = internalnorm(u0./sk)
f₀ = f(t,u0)
if any((isnan(x) for x in f₀))
Expand All @@ -73,9 +73,9 @@ function ode_determine_initdt{uType,tType}(u0::uType,t,tdir,dtmax,abstol,reltol,
dt₀ = tType((d₀/d₁)/100)
end
dt₀ = min(dt₀,tdir*dtmax)
u₁ = u0 + tdir*dt₀*f₀
u₁ = @. u0 + (tdir*dt₀)*f₀
f₁ = f(t+tdir*dt₀,u₁)
d₂ = internalnorm((f₁-f₀)./(abstol+abs.(u0).*reltol))/dt₀*tType(1)
d₂ = internalnorm(@. (f₁-f₀)/(abstol+abs(u0)*reltol))/dt₀*tType(1)
if max(d₁,d₂) <= T1(1//Int64(10)^(15))
dt₁ = max(tType(1//10^(6)),dt₀*1//10^(3))
else
Expand Down
107 changes: 31 additions & 76 deletions src/integrators/explicit_rk_integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
integrator.k[2] = integrator.fsallast
end

@inline function perform_step!(integrator,cache::ExplicitRKConstantCache,f=integrator.f)
@inline @muladd function perform_step!(integrator,cache::ExplicitRKConstantCache,f=integrator.f)
@unpack t,dt,uprev,u = integrator
@unpack A,c,α,αEEst,stages = cache
@unpack kk = cache
Expand All @@ -19,72 +19,28 @@ end
for i = 2:stages-1
utilde = zero(kk[1])
for j = 1:i-1
utilde = @muladd utilde + A[j,i]*kk[j]
utilde = @. utilde + A[j,i]*kk[j]
end
kk[i] = f(@muladd(t+c[i]*dt),@muladd(uprev+dt*utilde));
kk[i] = f(t+c[i]*dt, @. uprev + dt*utilde);
end
#Calc Last
utilde = zero(kk[1])
for j = 1:stages-1
utilde = @muladd utilde + A[j,end]*kk[j]
utilde = @. utilde + A[j,end]*kk[j]
end
kk[end] = f(@muladd(t+c[end]*dt),@muladd(uprev+dt*utilde)); integrator.fsallast = kk[end] # Uses fsallast as temp even if not fsal
kk[end] = f(t+c[end]*dt, @. uprev + dt*utilde); integrator.fsallast = kk[end] # Uses fsallast as temp even if not fsal
# Accumulate Result
utilde = α[1]*kk[1]
for i = 2:stages
utilde = @muladd utilde + α[i]*kk[i]
utilde = @. utilde + α[i]*kk[i]
end
u = @muladd uprev + dt*utilde
u = @. uprev + dt*utilde
if integrator.opts.adaptive
uEEst = αEEst[1]*kk[1]
for i = 2:stages
uEEst = @muladd uEEst + αEEst[i]*kk[i]
uEEst = @. uEEst + αEEst[i]*kk[i]
end
integrator.EEst = integrator.opts.internalnorm( dt*(utilde-uEEst)./@muladd(integrator.opts.abstol+max.(abs.(uprev),abs.(u)).*integrator.opts.reltol))
end
if isfsal(integrator.alg.tableau)
integrator.fsallast = kk[end]
else
integrator.fsallast = f(t+dt,u)
end
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
@pack integrator = t,dt,u
end

#=
@inline function perform_step!(integrator,cache::ExplicitRKConstantCache,f=integrator.f)
@unpack t,dt,uprev,u = integrator
@unpack A,c,α,αEEst,stages = cache
@unpack kk = cache
# Calc First
kk[1] = integrator.fsalfirst
# Calc Middle
for i = 2:stages-1
utilde = zero(kk[1])
for j = 1:i-1
utilde = @. @muladd utilde + A[j,i]*kk[j]
end
kk[i] = f(@muladd(t+c[i]*dt),@muladd(uprev+dt*utilde));
end
#Calc Last
utilde = zero(kk[1])
for j = 1:stages-1
utilde = @. @muladd utilde + A[j,end]*kk[j]
end
kk[end] = f(@muladd(t+c[end]*dt),@muladd(uprev+dt*utilde)); integrator.fsallast = kk[end] # Uses fsallast as temp even if not fsal
# Accumulate Result
utilde = α[1]*kk[1]
for i = 2:stages
utilde = @. @muladd utilde + α[i]*kk[i]
end
u = @. @muladd uprev + dt*utilde
if integrator.opts.adaptive
uEEst = αEEst[1]*kk[1]
for i = 2:stages
uEEst = @. @muladd uEEst + αEEst[i]*kk[i]
end
tmp = @. dt*(utilde-uEEst)./@muladd(integrator.opts.abstol+max.(abs.(uprev),abs.(u)).*integrator.opts.reltol)
tmp = @. dt*(utilde-uEEst)/(integrator.opts.abstol+max(abs(uprev),abs(u))*integrator.opts.reltol)
integrator.EEst = integrator.opts.internalnorm(tmp)
end
if isfsal(integrator.alg.tableau)
Expand All @@ -96,7 +52,6 @@ end
integrator.k[2] = integrator.fsallast
@pack integrator = t,dt,u
end
=#

@inline function initialize!(integrator,cache::ExplicitRKCache,f=integrator.f)
integrator.kshortsize = 2
Expand All @@ -109,40 +64,40 @@ end
end

#=
@inline function perform_step!(integrator,cache::ExplicitRKCache,f=integrator.f)
@inline @muladd function perform_step!(integrator,cache::ExplicitRKCache,f=integrator.f)
@unpack t,dt,uprev,u,k = integrator
@unpack A,c,α,αEEst,stages = cache.tab
@unpack kk,utilde,tmp,atmp,uEEst = cache
# Middle
for i = 2:stages-1
@. utilde = zero(kk[1][1])
for j = 1:i-1
@. utilde = @muladd utilde + A[j,i]*kk[j]
@. utilde = utilde + A[j,i]*kk[j]
end
@. tmp = @muladd uprev+dt*utilde
f(@muladd(t+c[i]*dt),tmp,kk[i])
@. tmp = uprev+dt*utilde
f(t+c[i]*dt,tmp,kk[i])
end
#Last
@. utilde = zero(kk[1][1])
for j = 1:stages-1
@. utilde = @muladd utilde + A[j,end]*kk[j]
@. utilde = utilde + A[j,end]*kk[j]
end
@. u = @muladd uprev+dt*utilde
f(@muladd(t+c[end]*dt),u,kk[end]) #fsallast is tmp even if not fsal
@. u = uprev+dt*utilde
f(t+c[end]*dt),u,kk[end]) #fsallast is tmp even if not fsal
#Accumulate
if !isfsal(integrator.alg.tableau)
@. utilde = α[1]*kk[1]
for i = 2:stages
@. utilde = @muladd utilde + α[i]*kk[i]
@. utilde = utilde + α[i]*kk[i]
end
@. u = @muladd uprev + dt*utilde
@. u = uprev + dt*utilde
end
if integrator.opts.adaptive
@. uEEst = αEEst[1]*kk[1]
for i = 2:stages
@. uEEst = @muladd uEEst + αEEst[i]*kk[i]
@. uEEst = uEEst + αEEst[i]*kk[i]
end
@. atmp = (dt*(utilde-uEEst)/@muladd(integrator.opts.abstol+max(abs(uprev),abs(u))*integrator.opts.reltol))
@. atmp = (dt*(utilde-uEEst)/(integrator.opts.abstol+max(abs(uprev),abs(u))*integrator.opts.reltol))
integrator.EEst = integrator.opts.internalnorm(atmp)
end
if !isfsal(integrator.alg.tableau)
Expand All @@ -152,7 +107,7 @@ end
end
=#

@inline function perform_step!(integrator,cache::ExplicitRKCache,f=integrator.f)
@inline @muladd function perform_step!(integrator,cache::ExplicitRKCache,f=integrator.f)
@unpack t,dt,uprev,u,k = integrator
uidx = eachindex(integrator.uprev)
@unpack A,c,α,αEEst,stages = cache.tab
Expand All @@ -164,39 +119,39 @@ end
end
for j = 1:i-1
@tight_loop_macros for l in uidx
@inbounds utilde[l] = @muladd utilde[l] + A[j,i]*kk[j][l]
@inbounds utilde[l] = utilde[l] + A[j,i]*kk[j][l]
end
end
@tight_loop_macros for l in uidx
@inbounds tmp[l] = @muladd uprev[l]+dt*utilde[l]
@inbounds tmp[l] = uprev[l]+dt*utilde[l]
end
f(@muladd(t+c[i]*dt),tmp,kk[i])
f(t+c[i]*dt,tmp,kk[i])
end
#Last
@tight_loop_macros for l in uidx
@inbounds utilde[l] = zero(kk[1][1])
end
for j = 1:stages-1
@tight_loop_macros for l in uidx
@inbounds utilde[l] = @muladd utilde[l] + A[j,end]*kk[j][l]
@inbounds utilde[l] = utilde[l] + A[j,end]*kk[j][l]
end
end
@tight_loop_macros for l in uidx
@inbounds u[l] = @muladd uprev[l]+dt*utilde[l]
@inbounds u[l] = uprev[l]+dt*utilde[l]
end
f(@muladd(t+c[end]*dt),u,kk[end]) #fsallast is tmp even if not fsal
f(t+c[end]*dt,u,kk[end]) #fsallast is tmp even if not fsal
#Accumulate
if !isfsal(integrator.alg.tableau)
@tight_loop_macros for i in uidx
@inbounds utilde[i] = α[1]*kk[1][i]
end
for i = 2:stages
@tight_loop_macros for l in uidx
@inbounds utilde[l] = @muladd utilde[l] + α[i]*kk[i][l]
@inbounds utilde[l] = utilde[l] + α[i]*kk[i][l]
end
end
@tight_loop_macros for i in uidx
@inbounds u[i] = @muladd uprev[i] + dt*utilde[i]
@inbounds u[i] = uprev[i] + dt*utilde[i]
end
end
if integrator.opts.adaptive
Expand All @@ -205,11 +160,11 @@ end
end
for i = 2:stages
@tight_loop_macros for j in uidx
@inbounds uEEst[j] = @muladd uEEst[j] + αEEst[i]*kk[i][j]
@inbounds uEEst[j] = uEEst[j] + αEEst[i]*kk[i][j]
end
end
@tight_loop_macros for (i,atol,rtol) in zip(uidx,Iterators.cycle(integrator.opts.abstol),Iterators.cycle(integrator.opts.reltol))
@inbounds atmp[i] = (dt*(utilde[i]-uEEst[i])./@muladd(atol+max(abs(uprev[i]),abs(u[i])).*rtol))
@inbounds atmp[i] = dt*(utilde[i]-uEEst[i])/(atol+max(abs(uprev[i]),abs(u[i]))*rtol)
end
integrator.EEst = integrator.opts.internalnorm(atmp)
end
Expand Down
24 changes: 12 additions & 12 deletions src/integrators/exponential_rk_integrators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
integrator.k[2] = zero(integrator.fsalfirst)
end

@inline function perform_step!(integrator,cache::LawsonEulerConstantCache,f=integrator.f)
@inline @muladd function perform_step!(integrator,cache::LawsonEulerConstantCache,f=integrator.f)
@unpack t,dt,uprev,u,k = integrator
rtmp = integrator.fsalfirst
A = f[1]
u = expm(dt*A)*(uprev + dt*rtmp)
u = expm(dt*A)*(@. uprev + dt*rtmp)
rtmp = f[2](t+dt,u)
k = A*u + rtmp # For the interpolation, needs k at the updated point
k = A*u .+ rtmp # For the interpolation, needs k at the updated point
integrator.fsallast = rtmp
integrator.k[1] = integrator.fsalfirst # this is wrong, since it's just rtmp. Should fsal this value though
integrator.k[2] = k
Expand All @@ -34,19 +34,19 @@ end
A = f[1]
A_mul_B!(cache.k,A,integrator.u)
f[2](integrator.t,integrator.uprev,rtmp) # For the interpolation, needs k at the updated point
integrator.fsalfirst .= cache.k .+ rtmp
@. integrator.fsalfirst = cache.k + rtmp
end

@inline function perform_step!(integrator,cache::LawsonEulerCache,f=integrator.f)
@inline @muladd function perform_step!(integrator,cache::LawsonEulerCache,f=integrator.f)
@unpack t,dt,uprev,u = integrator
@unpack k,rtmp,tmp = cache
A = f[1]
M = expm(dt*A)
tmp .= uprev .+ dt.*integrator.fsalfirst
@. tmp = uprev + dt*integrator.fsalfirst
A_mul_B!(u,M,tmp)
A_mul_B!(tmp,A,u)
f[2](t+dt,u,rtmp)
k = tmp .+ rtmp
@. k = tmp + rtmp
@pack integrator = t,dt,u
end

Expand All @@ -66,9 +66,9 @@ end
@unpack t,dt,uprev,u,k = integrator
rtmp = integrator.fsalfirst
A = f[1]
u = uprev + ((expm(dt*A)-I)/A)*(A*uprev + rtmp)
u = uprev .+ ((expm(dt*A)-I)/A)*(A*uprev .+ rtmp)
rtmp = f[2](t+dt,u)
k = A*u + rtmp # For the interpolation, needs k at the updated point
k = A*u .+ rtmp # For the interpolation, needs k at the updated point
integrator.fsallast = rtmp
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = k
Expand All @@ -86,7 +86,7 @@ end
A = f[1](integrator.t,integrator.u,rtmp)
A_mul_B!(cache.k,A,integrator.u)
f[2](integrator.t,integrator.uprev,rtmp) # For the interpolation, needs k at the updated point
integrator.fsalfirst .= cache.k .+ rtmp
@. integrator.fsalfirst = cache.k + rtmp
end

@inline function perform_step!(integrator,cache::NorsettEulerCache,f=integrator.f)
Expand All @@ -97,9 +97,9 @@ end
A_mul_B!(tmp,A,uprev)
tmp .+= rtmp
A_mul_B!(rtmp,M,tmp)
u .= uprev .+ rtmp
@. u = uprev + rtmp
A_mul_B!(tmp,A,u)
f[2](t+dt,u,rtmp)
k .= tmp .+ rtmp
@. k = tmp + rtmp
@pack integrator = t,dt,u
end
Loading