diff --git a/src/eval.jl b/src/eval.jl index c85e35f..9e69988 100644 --- a/src/eval.jl +++ b/src/eval.jl @@ -13,22 +13,22 @@ by `c[i1 + (i-1)*Δi]`, i.e. `i1` is the starting index and product of `size(c)[1:dim]`. The interpolation point `x` should lie within [-1,+1] in each coordinate. """ -function interpolate(x::SVector{N}, c::Array{<:Any,N}, ::Val{dim}, i1, len) where {N,dim} +@fastmath function interpolate(x::SVector{N}, c::Array{<:Any,N}, ::Val{dim}, i1, len) where {N,dim} n = size(c,dim) @inbounds xd = x[dim] if dim == 1 c₁ = c[i1] if n ≤ 2 n == 1 && return c₁ + one(xd) * zero(c₁) - return c₁ + xd*c[i1] + return muladd(xd, c[i1], c₁) end - @inbounds bₖ = c[i1+(n-2)] + 2xd*c[i1+(n-1)] + @inbounds bₖ = muladd(2xd, c[i1+(n-1)], c[i1+(n-2)]) @inbounds bₖ₊₁ = oftype(bₖ, c[i1+(n-1)]) for j = n-3:-1:1 - @inbounds bⱼ = c[i1+j] + 2xd*bₖ - bₖ₊₁ + @inbounds bⱼ = muladd(2xd, bₖ, c[i1+j]) - bₖ₊₁ bₖ, bₖ₊₁ = bⱼ, bₖ end - return c₁ + xd*bₖ - bₖ₊₁ + return muladd(xd, bₖ, c₁) - bₖ₊₁ else Δi = len ÷ n # column-major stride of current dimension @@ -44,14 +44,14 @@ function interpolate(x::SVector{N}, c::Array{<:Any,N}, ::Val{dim}, i1, len) wher end cₙ₋₁ = interpolate(x, c, dim′, i1+(n-2)*Δi, Δi) cₙ = interpolate(x, c, dim′, i1+(n-1)*Δi, Δi) - bₖ = cₙ₋₁ + 2xd*cₙ + bₖ = muladd(2xd, cₙ, cₙ₋₁) bₖ₊₁ = oftype(bₖ, cₙ) for j = n-3:-1:1 cⱼ = interpolate(x, c, dim′, i1+j*Δi, Δi) - bⱼ = cⱼ + 2xd*bₖ - bₖ₊₁ + bⱼ = muladd(2xd, bₖ, cⱼ) - bₖ₊₁ bₖ, bₖ₊₁ = bⱼ, bₖ end - return c₁ + xd*bₖ - bₖ₊₁ + return muladd(xd, bₖ, c₁) - bₖ₊₁ end end @@ -60,7 +60,7 @@ end Evaluate the Chebyshev polynomial given by `interp` at the point `x`. """ -function (interp::ChebPoly{N})(x::SVector{N,<:Real}) where {N} +@fastmath function (interp::ChebPoly{N})(x::SVector{N,<:Real}) where {N} x0 = @. (x - interp.lb) * 2 / (interp.ub - interp.lb) - 1 all(abs.(x0) .≤ 1) || throw(ArgumentError("$x not in domain")) return interpolate(x0, interp.coefs, Val{N}(), 1, length(interp.coefs)) @@ -76,28 +76,29 @@ end Similar to `interpolate` above, but returns a tuple `(v,J)` of the interpolated value `v` and the Jacobian `J` with respect to `x[1:dim]`. """ -function Jinterpolate(x::SVector{N}, c::Array{<:Any,N}, ::Val{dim}, i1, len) where {N,dim} +@fastmath function Jinterpolate(x::SVector{N}, c::Array{<:Any,N}, ::Val{dim}, i1, len) where {N,dim} n = size(c,dim) @inbounds xd = x[dim] if dim == 1 c₁ = c[i1] if n ≤ 2 n == 1 && return c₁ + one(xd) * zero(c₁), hcat(SVector(zero(c₁) / oneunit(xd))) - return c₁ + xd*c[i1], hcat(SVector(c[i1])) + return muladd(xd, c[i1], c₁), hcat(SVector(c[i1])) end @inbounds cₙ₋₁ = c[i1+(n-2)] @inbounds cₙ = c[i1+(n-1)] - bₖ = cₙ₋₁ + xd*(bₖ′ = 2cₙ) + bₖ′ = 2cₙ + bₖ = muladd(xd, bₖ′, cₙ₋₁) bₖ₊₁ = oftype(bₖ, cₙ) bₖ₊₁′ = zero(bₖ₊₁) for j = n-3:-1:1 @inbounds cⱼ = c[i1+j] - bⱼ = cⱼ + xd*(2bₖ) - bₖ₊₁ - bⱼ′ = xd*(2bₖ′) + (2bₖ) - bₖ₊₁′ + bⱼ = muladd(2xd, bₖ, cⱼ) - bₖ₊₁ + bⱼ′ = muladd(2xd, bₖ′, 2bₖ) - bₖ₊₁′ bₖ, bₖ₊₁ = bⱼ, bₖ bₖ′, bₖ₊₁′ = bⱼ′, bₖ′ end - return c₁ + xd*bₖ - bₖ₊₁, hcat(SVector(bₖ + xd*bₖ′ - bₖ₊₁′)) + return muladd(xd, bₖ, c₁) - bₖ₊₁, hcat(SVector(muladd(xd, bₖ′, bₖ) - bₖ₊₁′)) else Δi = len ÷ n # column-major stride of current dimension @@ -109,25 +110,26 @@ function Jinterpolate(x::SVector{N}, c::Array{<:Any,N}, ::Val{dim}, i1, len) whe if n ≤ 2 n == 1 && return c₁ + one(xd) * zero(c₁), hcat(Jc₁, SVector(zero(c₁) / oneunit(xd))) c₂,Jc₂ = Jinterpolate(x, c, dim′, i1+Δi, Δi) - return c₁ + xd*c₂, hcat(Jc₁ + xd*Jc₂, SVector(c[i1])) + return muladd(xd, c₂, c₁), hcat(muladd(xd, Jc₂, Jc₁), SVector(c[i1])) end cₙ₋₁,Jcₙ₋₁ = Jinterpolate(x, c, dim′, i1+(n-2)*Δi, Δi) cₙ,Jcₙ = Jinterpolate(x, c, dim′, i1+(n-1)*Δi, Δi) - bₖ = cₙ₋₁ + xd*(bₖ′ = 2cₙ) - Jbₖ = Jcₙ₋₁ + 2xd*Jcₙ + bₖ′ = 2cₙ + bₖ = muladd(xd, bₖ′, cₙ₋₁) + Jbₖ = muladd(2xd, Jcₙ, Jcₙ₋₁) bₖ₊₁ = oftype(bₖ, cₙ) bₖ₊₁′ = zero(bₖ₊₁) Jbₖ₊₁ = oftype(Jbₖ, Jcₙ) for j = n-3:-1:1 cⱼ,Jcⱼ = Jinterpolate(x, c, dim′, i1+j*Δi, Δi) - bⱼ = cⱼ + xd*(2bₖ) - bₖ₊₁ - bⱼ′ = xd*(2bₖ′) + (2bₖ) - bₖ₊₁′ - Jbⱼ = Jcⱼ + xd*(2Jbₖ) - Jbₖ₊₁ + bⱼ = muladd(2xd, bₖ, cⱼ) - bₖ₊₁ + bⱼ′ = muladd(2xd, bₖ′, 2bₖ) - bₖ₊₁′ + Jbⱼ = muladd(2xd, Jbₖ, Jcⱼ) - Jbₖ₊₁ bₖ, bₖ₊₁ = bⱼ, bₖ bₖ′, bₖ₊₁′ = bⱼ′, bₖ′ Jbₖ, Jbₖ₊₁ = Jbⱼ, Jbₖ end - return c₁ + xd*bₖ - bₖ₊₁, hcat(Jc₁ + xd*Jbₖ - Jbₖ₊₁, SVector(bₖ + xd*bₖ′ - bₖ₊₁′)) + return muladd(xd, bₖ, c₁) - bₖ₊₁, hcat(muladd(xd, Jbₖ, Jc₁) - Jbₖ₊₁, SVector(muladd(xd, bₖ′, bₖ) - bₖ₊₁′)) end end