Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: specialize on interpolations in the different Values structs #339

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 35 additions & 31 deletions src/FEValues/cell_values.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Defines CellScalarValues and CellVectorValues and common methods
"""
CellScalarValues([::Type{T}], quad_rule::QuadratureRule, func_interpol::Interpolation, [geom_interpol::Interpolation])
CellVectorValues([::Type{T}], quad_rule::QuadratureRule, func_interpol::Interpolation, [geom_interpol::Interpolation])
CellScalarValues([::Type{T}], quad_rule::QuadratureRule, func_interp::Interpolation, [geo_interp::Interpolation])
CellVectorValues([::Type{T}], quad_rule::QuadratureRule, func_interp::Interpolation, [geo_interp::Interpolation])

A `CellValues` object facilitates the process of evaluating values of shape functions, gradients of shape functions,
values of nodal functions, gradients and divergences of nodal functions etc. in the finite element cell. There are
Expand All @@ -12,8 +12,8 @@ utilizes scalar shape functions and `CellVectorValues` utilizes vectorial shape
**Arguments:**
* `T`: an optional argument (default to `Float64`) to determine the type the internal data is stored as.
* `quad_rule`: an instance of a [`QuadratureRule`](@ref)
* `func_interpol`: an instance of an [`Interpolation`](@ref) used to interpolate the approximated function
* `geom_interpol`: an optional instance of a [`Interpolation`](@ref) which is used to interpolate the geometry
* `func_interp`: an instance of an [`Interpolation`](@ref) used to interpolate the approximated function
* `geo_interp`: an optional instance of a [`Interpolation`](@ref) which is used to interpolate the geometry

**Common methods:**
* [`reinit!`](@ref)
Expand All @@ -34,90 +34,94 @@ utilizes scalar shape functions and `CellVectorValues` utilizes vectorial shape
CellValues

# CellScalarValues
struct CellScalarValues{dim,T<:Real,refshape<:AbstractRefShape} <: CellValues{dim,T,refshape}
struct CellScalarValues{dim,T<:Real,refshape<:AbstractRefShape,FI,GI} <: CellValues{dim,T,refshape,FI,GI}
N::Matrix{T}
dNdx::Matrix{Vec{dim,T}}
dNdξ::Matrix{Vec{dim,T}}
detJdV::Vector{T}
M::Matrix{T}
dMdξ::Matrix{Vec{dim,T}}
qr_weights::Vector{T}
func_interp::FI
geo_interp::GI
end

function CellScalarValues(quad_rule::QuadratureRule, func_interpol::Interpolation,
geom_interpol::Interpolation=func_interpol)
CellScalarValues(Float64, quad_rule, func_interpol, geom_interpol)
function CellScalarValues(quad_rule::QuadratureRule, func_interp::Interpolation,
geo_interp::Interpolation=func_interp)
CellScalarValues(Float64, quad_rule, func_interp, geo_interp)
end

function CellScalarValues(::Type{T}, quad_rule::QuadratureRule{dim,shape}, func_interpol::Interpolation,
geom_interpol::Interpolation=func_interpol) where {dim,T,shape<:AbstractRefShape}
function CellScalarValues(::Type{T}, quad_rule::QuadratureRule{dim,shape,}, func_interp::Interpolation,
geo_interp::Interpolation=func_interp) where {dim,T,shape<:AbstractRefShape}

@assert getdim(func_interpol) == getdim(geom_interpol)
@assert getrefshape(func_interpol) == getrefshape(geom_interpol) == shape
@assert getdim(func_interp) == getdim(geo_interp)
@assert getrefshape(func_interp) == getrefshape(geo_interp) == shape
n_qpoints = length(getweights(quad_rule))

# Function interpolation
n_func_basefuncs = getnbasefunctions(func_interpol)
n_func_basefuncs = getnbasefunctions(func_interp)
N = fill(zero(T) * T(NaN), n_func_basefuncs, n_qpoints)
dNdx = fill(zero(Vec{dim,T}) * T(NaN), n_func_basefuncs, n_qpoints)
dNdξ = fill(zero(Vec{dim,T}) * T(NaN), n_func_basefuncs, n_qpoints)

# Geometry interpolation
n_geom_basefuncs = getnbasefunctions(geom_interpol)
n_geom_basefuncs = getnbasefunctions(geo_interp)
M = fill(zero(T) * T(NaN), n_geom_basefuncs, n_qpoints)
dMdξ = fill(zero(Vec{dim,T}) * T(NaN), n_geom_basefuncs, n_qpoints)

for (qp, ξ) in enumerate(quad_rule.points)
for i in 1:n_func_basefuncs
dNdξ[i, qp], N[i, qp] = gradient(ξ -> value(func_interpol, i, ξ), ξ, :all)
dNdξ[i, qp], N[i, qp] = gradient(ξ -> value(func_interp, i, ξ), ξ, :all)
end
for i in 1:n_geom_basefuncs
dMdξ[i, qp], M[i, qp] = gradient(ξ -> value(geom_interpol, i, ξ), ξ, :all)
dMdξ[i, qp], M[i, qp] = gradient(ξ -> value(geo_interp, i, ξ), ξ, :all)
end
end

detJdV = fill(T(NaN), n_qpoints)

CellScalarValues{dim,T,shape}(N, dNdx, dNdξ, detJdV, M, dMdξ, quad_rule.weights)
CellScalarValues{dim,T,shape,typeof(func_interp),typeof(geo_interp)}(N, dNdx, dNdξ, detJdV, M, dMdξ, quad_rule.weights, func_interp, geo_interp)
end

# CellVectorValues
struct CellVectorValues{dim,T<:Real,refshape<:AbstractRefShape,M} <: CellValues{dim,T,refshape}
struct CellVectorValues{dim,T<:Real,refshape<:AbstractRefShape,FI,GI,M} <: CellValues{dim,T,refshape,FI,GI}
N::Matrix{Vec{dim,T}}
dNdx::Matrix{Tensor{2,dim,T,M}}
dNdξ::Matrix{Tensor{2,dim,T,M}}
detJdV::Vector{T}
M::Matrix{T}
dMdξ::Matrix{Vec{dim,T}}
qr_weights::Vector{T}
func_interp::FI
geo_interp::GI
end

function CellVectorValues(quad_rule::QuadratureRule, func_interpol::Interpolation, geom_interpol::Interpolation=func_interpol)
CellVectorValues(Float64, quad_rule, func_interpol, geom_interpol)
function CellVectorValues(quad_rule::QuadratureRule, func_interp::Interpolation, geo_interp::Interpolation=func_interp)
CellVectorValues(Float64, quad_rule, func_interp, geo_interp)
end

function CellVectorValues(::Type{T}, quad_rule::QuadratureRule{dim,shape}, func_interpol::Interpolation,
geom_interpol::Interpolation=func_interpol) where {dim,T,shape<:AbstractRefShape}
function CellVectorValues(::Type{T}, quad_rule::QuadratureRule{dim,shape}, func_interp::Interpolation,
geo_interp::Interpolation=func_interp) where {dim,T,shape<:AbstractRefShape}

@assert getdim(func_interpol) == getdim(geom_interpol)
@assert getrefshape(func_interpol) == getrefshape(geom_interpol) == shape
@assert getdim(func_interp) == getdim(geo_interp)
@assert getrefshape(func_interp) == getrefshape(geo_interp) == shape
n_qpoints = length(getweights(quad_rule))

# Function interpolation
n_func_basefuncs = getnbasefunctions(func_interpol) * dim
n_func_basefuncs = getnbasefunctions(func_interp) * dim
N = fill(zero(Vec{dim,T}) * T(NaN), n_func_basefuncs, n_qpoints)
dNdx = fill(zero(Tensor{2,dim,T}) * T(NaN), n_func_basefuncs, n_qpoints)
dNdξ = fill(zero(Tensor{2,dim,T}) * T(NaN), n_func_basefuncs, n_qpoints)

# Geometry interpolation
n_geom_basefuncs = getnbasefunctions(geom_interpol)
n_geom_basefuncs = getnbasefunctions(geo_interp)
M = fill(zero(T) * T(NaN), n_geom_basefuncs, n_qpoints)
dMdξ = fill(zero(Vec{dim,T}) * T(NaN), n_geom_basefuncs, n_qpoints)

for (qp, ξ) in enumerate(quad_rule.points)
basefunc_count = 1
for basefunc in 1:getnbasefunctions(func_interpol)
dNdξ_temp, N_temp = gradient(ξ -> value(func_interpol, basefunc, ξ), ξ, :all)
for basefunc in 1:getnbasefunctions(func_interp)
dNdξ_temp, N_temp = gradient(ξ -> value(func_interp, basefunc, ξ), ξ, :all)
for comp in 1:dim
N_comp = zeros(T, dim)
N_comp[comp] = N_temp
Expand All @@ -130,14 +134,14 @@ function CellVectorValues(::Type{T}, quad_rule::QuadratureRule{dim,shape}, func_
end
end
for basefunc in 1:n_geom_basefuncs
dMdξ[basefunc, qp], M[basefunc, qp] = gradient(ξ -> value(geom_interpol, basefunc, ξ), ξ, :all)
dMdξ[basefunc, qp], M[basefunc, qp] = gradient(ξ -> value(geo_interp, basefunc, ξ), ξ, :all)
end
end

detJdV = fill(T(NaN), n_qpoints)
MM = Tensors.n_components(Tensors.get_base(eltype(dNdx)))

CellVectorValues{dim,T,shape,MM}(N, dNdx, dNdξ, detJdV, M, dMdξ, quad_rule.weights)
CellVectorValues{dim,T,shape,typeof(func_interp),typeof(geo_interp),MM}(N, dNdx, dNdξ, detJdV, M, dMdξ, quad_rule.weights, func_interp, geo_interp)
end

function reinit!(cv::CellValues{dim}, x::AbstractVector{Vec{dim,T}}) where {dim,T}
Expand All @@ -154,7 +158,7 @@ function reinit!(cv::CellValues{dim}, x::AbstractVector{Vec{dim,T}}) where {dim,
fecv_J += x[j] ⊗ cv.dMdξ[j, i]
end
detJ = det(fecv_J)
detJ > 0.0 || throw(ArgumentError("det(J) is not positive: det(J) = $(detJ)"))
detJ > 0.0 || throw_detJ_not_pos(detJ)
cv.detJdV[i] = detJ * w
Jinv = inv(fecv_J)
for j in 1:n_func_basefuncs
Expand Down
20 changes: 10 additions & 10 deletions src/FEValues/common_values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

using Base: @propagate_inbounds

const ScalarValues{dim,T,shape} = Union{CellScalarValues{dim,T,shape},FaceScalarValues{dim,T,shape}}
const VectorValues{dim,T,shape} = Union{CellVectorValues{dim,T,shape},FaceVectorValues{dim,T,shape}}
@noinline throw_detJ_not_pos(detJ) = throw(ArgumentError("det(J) is not positive: det(J) = $(detJ)"))

getnbasefunctions(cv::Values) = size(cv.N, 1)
getngeobasefunctions(cv::Values) = size(cv.M, 1)
const ScalarValues{dim,T,shape,func_interp,geo_interp} = Union{CellScalarValues{dim,T,shape},FaceScalarValues{dim,T,shape,func_interp,geo_interp}}
const VectorValues{dim,T,shape,func_interp,geo_interp} = Union{CellVectorValues{dim,T,shape},FaceVectorValues{dim,T,shape,func_interp,geo_interp}}

getn_scalarbasefunctions(cv::ScalarValues) = size(cv.N, 1)
getn_scalarbasefunctions(cv::VectorValues{dim}) where {dim} = size(cv.N, 1) ÷ dim
getnbasefunctions(fe::Values{dim}) where {dim} = getnbasefunctions(fe.func_interp) * (fe isa VectorValues ? dim : 1)
getngeobasefunctions(fe::Values) = getnbasefunctions(fe.geo_interp)

getn_scalarbasefunctions(cv::ScalarValues) = getnbasefunctions(cv)
getn_scalarbasefunctions(cv::VectorValues{dim}) where {dim} = getnbasefunctions(cv) ÷ dim

"""
reinit!(cv::CellValues, x::Vector)
Expand Down Expand Up @@ -104,8 +106,7 @@ where ``u_i`` are the value of ``u`` in the nodes. For a vector valued function
nodal values of ``\\mathbf{u}``.
"""
function function_value(fe_v::Values{dim}, q_point::Int, u::AbstractVector{T}, dof_range = eachindex(u)) where {dim,T}
n_base_funcs = getn_scalarbasefunctions(fe_v)
isa(fe_v, VectorValues) && (n_base_funcs *= dim)
n_base_funcs = getnbasefunctions(fe_v)
@assert length(dof_range) == n_base_funcs
@boundscheck checkbounds(u, dof_range)
val = zero(_valuetype(fe_v, u))
Expand Down Expand Up @@ -150,8 +151,7 @@ For a vector valued function the gradient is computed as
where ``\\mathbf{u}_i`` are the nodal values of ``\\mathbf{u}``.
"""
function function_gradient(fe_v::Values{dim}, q_point::Int, u::AbstractVector{T}, dof_range = eachindex(u)) where {dim,T}
n_base_funcs = getn_scalarbasefunctions(fe_v)
isa(fe_v, VectorValues) && (n_base_funcs *= dim)
n_base_funcs = getnbasefunctions(fe_v)
@assert length(dof_range) == n_base_funcs
@boundscheck checkbounds(u, dof_range)
grad = zero(_gradienttype(fe_v, u))
Expand Down
Loading