Skip to content

Commit

Permalink
speedups
Browse files Browse the repository at this point in the history
  • Loading branch information
stevengj committed Sep 10, 2021
1 parent fab7e9d commit 3340207
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions src/eval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,22 @@ by `c[i1 + (i-1)*Δi]`, i.e. `i1` is the starting index and
product of `size(c)[1:dim]`. The interpolation point `x`
should lie within [-1,+1] in each coordinate.
"""
function interpolate(x::SVector{N}, c::Array{<:Any,N}, ::Val{dim}, i1, len) where {N,dim}
@fastmath function interpolate(x::SVector{N}, c::Array{<:Any,N}, ::Val{dim}, i1, len) where {N,dim}
n = size(c,dim)
@inbounds xd = x[dim]
if dim == 1
c₁ = c[i1]
if n 2
n == 1 && return c₁ + one(xd) * zero(c₁)
return c₁ + xd*c[i1]
return muladd(xd, c[i1], c₁)
end
@inbounds bₖ = c[i1+(n-2)] + 2xd*c[i1+(n-1)]
@inbounds bₖ = muladd(2xd, c[i1+(n-1)], c[i1+(n-2)])
@inbounds bₖ₊₁ = oftype(bₖ, c[i1+(n-1)])
for j = n-3:-1:1
@inbounds bⱼ = c[i1+j] + 2xd*bₖ - bₖ₊₁
@inbounds bⱼ = muladd(2xd, bₖ, c[i1+j]) - bₖ₊₁
bₖ, bₖ₊₁ = bⱼ, bₖ
end
return c₁ + xd*bₖ - bₖ₊₁
return muladd(xd, bₖ, c₁) - bₖ₊₁
else
Δi = len ÷ n # column-major stride of current dimension

Expand All @@ -44,14 +44,14 @@ function interpolate(x::SVector{N}, c::Array{<:Any,N}, ::Val{dim}, i1, len) wher
end
cₙ₋₁ = interpolate(x, c, dim′, i1+(n-2)*Δi, Δi)
cₙ = interpolate(x, c, dim′, i1+(n-1)*Δi, Δi)
bₖ = cₙ₋₁ + 2xd*cₙ
bₖ = muladd(2xd, cₙ, cₙ₋₁)
bₖ₊₁ = oftype(bₖ, cₙ)
for j = n-3:-1:1
cⱼ = interpolate(x, c, dim′, i1+j*Δi, Δi)
bⱼ = cⱼ + 2xd*bₖ - bₖ₊₁
bⱼ = muladd(2xd, bₖ, cⱼ) - bₖ₊₁
bₖ, bₖ₊₁ = bⱼ, bₖ
end
return c₁ + xd*bₖ - bₖ₊₁
return muladd(xd, bₖ, c₁) - bₖ₊₁
end
end

Expand All @@ -60,7 +60,7 @@ end
Evaluate the Chebyshev polynomial given by `interp` at the point `x`.
"""
function (interp::ChebPoly{N})(x::SVector{N,<:Real}) where {N}
@fastmath function (interp::ChebPoly{N})(x::SVector{N,<:Real}) where {N}
x0 = @. (x - interp.lb) * 2 / (interp.ub - interp.lb) - 1
all(abs.(x0) .≤ 1) || throw(ArgumentError("$x not in domain"))
return interpolate(x0, interp.coefs, Val{N}(), 1, length(interp.coefs))
Expand All @@ -76,28 +76,29 @@ end
Similar to `interpolate` above, but returns a tuple `(v,J)` of the
interpolated value `v` and the Jacobian `J` with respect to `x[1:dim]`.
"""
function Jinterpolate(x::SVector{N}, c::Array{<:Any,N}, ::Val{dim}, i1, len) where {N,dim}
@fastmath function Jinterpolate(x::SVector{N}, c::Array{<:Any,N}, ::Val{dim}, i1, len) where {N,dim}
n = size(c,dim)
@inbounds xd = x[dim]
if dim == 1
c₁ = c[i1]
if n 2
n == 1 && return c₁ + one(xd) * zero(c₁), hcat(SVector(zero(c₁) / oneunit(xd)))
return c₁ + xd*c[i1], hcat(SVector(c[i1]))
return muladd(xd, c[i1], c₁), hcat(SVector(c[i1]))
end
@inbounds cₙ₋₁ = c[i1+(n-2)]
@inbounds cₙ = c[i1+(n-1)]
bₖ = cₙ₋₁ + xd*(bₖ′ = 2cₙ)
bₖ′ = 2cₙ
bₖ = muladd(xd, bₖ′, cₙ₋₁)
bₖ₊₁ = oftype(bₖ, cₙ)
bₖ₊₁′ = zero(bₖ₊₁)
for j = n-3:-1:1
@inbounds cⱼ = c[i1+j]
bⱼ = cⱼ + xd*(2bₖ) - bₖ₊₁
bⱼ′ = xd*(2bₖ′) + (2bₖ) - bₖ₊₁′
bⱼ = muladd(2xd, bₖ, cⱼ) - bₖ₊₁
bⱼ′ = muladd(2xd, bₖ′, 2bₖ) - bₖ₊₁′
bₖ, bₖ₊₁ = bⱼ, bₖ
bₖ′, bₖ₊₁′ = bⱼ′, bₖ′
end
return c₁ + xd*bₖ - bₖ₊₁, hcat(SVector(bₖ + xd*bₖ′ - bₖ₊₁′))
return muladd(xd, bₖ, c₁) - bₖ₊₁, hcat(SVector(muladd(xd, bₖ′, bₖ) - bₖ₊₁′))
else
Δi = len ÷ n # column-major stride of current dimension

Expand All @@ -109,25 +110,26 @@ function Jinterpolate(x::SVector{N}, c::Array{<:Any,N}, ::Val{dim}, i1, len) whe
if n 2
n == 1 && return c₁ + one(xd) * zero(c₁), hcat(Jc₁, SVector(zero(c₁) / oneunit(xd)))
c₂,Jc₂ = Jinterpolate(x, c, dim′, i1+Δi, Δi)
return c₁ + xd*c₂, hcat(Jc₁ + xd*Jc₂, SVector(c[i1]))
return muladd(xd, c₂, c₁), hcat(muladd(xd, Jc₂, Jc₁), SVector(c[i1]))
end
cₙ₋₁,Jcₙ₋₁ = Jinterpolate(x, c, dim′, i1+(n-2)*Δi, Δi)
cₙ,Jcₙ = Jinterpolate(x, c, dim′, i1+(n-1)*Δi, Δi)
bₖ = cₙ₋₁ + xd*(bₖ′ = 2cₙ)
Jbₖ = Jcₙ₋₁ + 2xd*Jcₙ
bₖ′ = 2cₙ
bₖ = muladd(xd, bₖ′, cₙ₋₁)
Jbₖ = muladd(2xd, Jcₙ, Jcₙ₋₁)
bₖ₊₁ = oftype(bₖ, cₙ)
bₖ₊₁′ = zero(bₖ₊₁)
Jbₖ₊₁ = oftype(Jbₖ, Jcₙ)
for j = n-3:-1:1
cⱼ,Jcⱼ = Jinterpolate(x, c, dim′, i1+j*Δi, Δi)
bⱼ = cⱼ + xd*(2bₖ) - bₖ₊₁
bⱼ′ = xd*(2bₖ′) + (2bₖ) - bₖ₊₁′
Jbⱼ = Jcⱼ + xd*(2Jbₖ) - Jbₖ₊₁
bⱼ = muladd(2xd, bₖ, cⱼ) - bₖ₊₁
bⱼ′ = muladd(2xd, bₖ′, 2bₖ) - bₖ₊₁′
Jbⱼ = muladd(2xd, Jbₖ, Jcⱼ) - Jbₖ₊₁
bₖ, bₖ₊₁ = bⱼ, bₖ
bₖ′, bₖ₊₁′ = bⱼ′, bₖ′
Jbₖ, Jbₖ₊₁ = Jbⱼ, Jbₖ
end
return c₁ + xd*bₖ - bₖ₊₁, hcat(Jc₁ + xd*Jbₖ - Jbₖ₊₁, SVector(bₖ + xd*bₖ′ - bₖ₊₁′))
return muladd(xd, bₖ, c₁) - bₖ₊₁, hcat(muladd(xd, Jbₖ, Jc₁) - Jbₖ₊₁, SVector(muladd(xd, bₖ′, bₖ) - bₖ₊₁′))
end
end

Expand Down

0 comments on commit 3340207

Please sign in to comment.