Skip to content

Commit

Permalink
Merge pull request #28 from kylebeggs/feature/clean-up
Browse files Browse the repository at this point in the history
Feature/clean up
  • Loading branch information
kylebeggs authored Aug 7, 2024
2 parents 1756a08 + a59546d commit d19d5d6
Show file tree
Hide file tree
Showing 16 changed files with 156 additions and 138 deletions.
26 changes: 0 additions & 26 deletions .travis.yml

This file was deleted.

5 changes: 3 additions & 2 deletions src/RadialBasisFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export AbstractPHS, PHS, PHS1, PHS3, PHS5, PHS7
export IMQ
export Gaussian
export MonomialBasis
export degree, dim

include("utils.jl")
export find_neighbors, reorder_points!
Expand All @@ -42,7 +43,7 @@ export ∂virtual

include("operators/monomial/monomial.jl")

include("operators/operator_combinations.jl")
include("operators/operator_algebra.jl")

include("interpolation.jl")
export Interpolator
Expand All @@ -52,7 +53,7 @@ export Regrid, regrid

# Some consts and aliases
const Δ = ∇² # some people like this notation for the Laplacian
const AVOID_NAN = 1e-16
const AVOID_INF = 1e-16

using PrecompileTools
@setup_workload begin
Expand Down
3 changes: 3 additions & 0 deletions src/basis/basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ function (ℒmon::ℒMonomialBasis{Dim,Deg})(x) where {Dim,Deg}
end
(m::ℒMonomialBasis)(b, x) = m.f(b, x)

degree(::ℒMonomialBasis{Dim,Deg}) where {Dim,Deg} = Deg
dim(::ℒMonomialBasis{Dim,Deg}) where {Dim,Deg} = Dim

include("polyharmonic_spline.jl")
include("inverse_multiquadric.jl")
include("gaussian.jl")
Expand Down
10 changes: 5 additions & 5 deletions src/basis/polyharmonic_spline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ end

(phs::PHS1)(x, xᵢ) = euclidean(x, xᵢ)
function (::PHS1, dim::Int)
∂ℒ(x, xᵢ) = (x[dim] - xᵢ[dim]) / (euclidean(x, xᵢ) + AVOID_NAN)
∂ℒ(x, xᵢ) = (x[dim] - xᵢ[dim]) / (euclidean(x, xᵢ) + AVOID_INF)
return ℒRadialBasisFunction(∂ℒ)
end
function (::PHS1)
Expand All @@ -49,14 +49,14 @@ end
function ∂²(::PHS1, dim::Int)
function ∂²ℒ(x, xᵢ)
return (-(x[dim] - xᵢ[dim])^2 + sqeuclidean(x, xᵢ)) /
(euclidean(x, xᵢ)^3 + AVOID_NAN)
(euclidean(x, xᵢ)^3 + AVOID_INF)
end
return ℒRadialBasisFunction(∂²ℒ)
end
function ∇²(::PHS1)
function ∇²ℒ(x, xᵢ)
return sum(
(-(x .- xᵢ) .^ 2 .+ sqeuclidean(x, xᵢ)) / (euclidean(x, xᵢ)^3 + AVOID_NAN)
(-(x .- xᵢ) .^ 2 .+ sqeuclidean(x, xᵢ)) / (euclidean(x, xᵢ)^3 + AVOID_INF)
)
end
return ℒRadialBasisFunction(∇²ℒ)
Expand Down Expand Up @@ -87,14 +87,14 @@ end
function ∂²(::PHS3, dim::Int)
function ∂²ℒ(x, xᵢ)
return 3 * (sqeuclidean(x, xᵢ) + (x[dim] - xᵢ[dim])^2) /
(euclidean(x, xᵢ) + AVOID_NAN)
(euclidean(x, xᵢ) + AVOID_INF)
end
return ℒRadialBasisFunction(∂²ℒ)
end
function ∇²(::PHS3)
function ∇²ℒ(x, xᵢ)
return sum(
3 * (sqeuclidean(x, xᵢ) .+ (x .- xᵢ) .^ 2) / (euclidean(x, xᵢ) + AVOID_NAN)
3 * (sqeuclidean(x, xᵢ) .+ (x .- xᵢ) .^ 2) / (euclidean(x, xᵢ) + AVOID_INF)
)
end
return ℒRadialBasisFunction(∇²ℒ)
Expand Down
34 changes: 22 additions & 12 deletions src/operators/directional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ end
Builds a `RadialBasisOperator` where the operator is the directional derivative, `Directional`.
"""
function directional(
data::AbstractVector{D},
data::AbstractVector,
v::AbstractVector,
basis::B=PHS(3; poly_deg=2);
k::T=autoselect_k(data, basis),
adjl=find_neighbors(data, k),
) where {D<:AbstractArray,B<:AbstractRadialBasis,T<:Int}
) where {B<:AbstractRadialBasis,T<:Int}
f = ntuple(dim -> Base.Fix2(∂, dim), length(first(data)))
= Directional(f, v)
return RadialBasisOperator(ℒ, data, basis; k=k, adjl=adjl)
Expand All @@ -45,14 +45,12 @@ end

function RadialBasisOperator(
::Directional,
data::AbstractVector{D},
data::AbstractVector,
basis::B=PHS(3; poly_deg=2);
k::T=autoselect_k(data, basis),
adjl=find_neighbors(data, k),
) where {D<:AbstractArray,T<:Int,B<:AbstractRadialBasis}
Na = length(adjl)
Nd = length(data)
weights = spzeros(eltype(D), Na, Nd)
) where {T<:Int,B<:AbstractRadialBasis}
weights = _build_weights(ℒ, data, data, adjl, basis)
return RadialBasisOperator(ℒ, weights, data, data, adjl, basis)
end

Expand All @@ -64,12 +62,26 @@ function RadialBasisOperator(
k::T=autoselect_k(data, basis),
adjl=find_neighbors(data, eval_points, k),
) where {TD,TE,T<:Int,B<:AbstractRadialBasis}
Na = length(adjl)
Nd = length(data)
weights = spzeros(eltype(TD), Na, Nd)
weights = _build_weights(ℒ, data, eval_points, adjl, basis)
return RadialBasisOperator(ℒ, weights, data, eval_points, adjl, basis)
end

function _build_weights(ℒ::Directional, data, eval_points, adjl, basis)
v =.v
N = length(first(data))
@assert length(v) == N || length(v) == length(data) "wrong size for v"
if length(v) == N
return mapreduce(+, enumerate(ℒ.ℒ)) do (i, ℒ)
_build_weights(ℒ, data, eval_points, adjl, basis) * v[i]
end
else
vv = ntuple(i -> getindex.(v, i), N)
return mapreduce(+, enumerate(ℒ.ℒ)) do (i, ℒ)
Diagonal(vv[i]) * _build_weights(ℒ, data, eval_points, adjl, basis)
end
end
end

function update_weights!(op::RadialBasisOperator{<:Directional})
v = op..v
N = length(first(op.data))
Expand All @@ -88,7 +100,5 @@ function update_weights!(op::RadialBasisOperator{<:Directional})
return nothing
end

Base.size(op::RadialBasisOperator{<:Directional}) = size(op.weights)

# pretty printing
print_op(op::Directional) = "Directional Gradient (∇f⋅v)"
4 changes: 2 additions & 2 deletions src/operators/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ end
Builds a `RadialBasisOperator` where the operator is the gradient, `Gradient`.
"""
function gradient(
data::AbstractVector{D},
data::AbstractVector,
basis::B=PHS(3; poly_deg=2);
k::T=autoselect_k(data, basis),
adjl=find_neighbors(data, k),
) where {D<:AbstractArray,B<:AbstractRadialBasis,T<:Int}
) where {B<:AbstractRadialBasis,T<:Int}
f = ntuple(dim -> Base.Fix2(∂, dim), length(first(data)))
= Gradient(f)
return RadialBasisOperator(ℒ, data, basis; k=k, adjl=adjl)
Expand Down
4 changes: 2 additions & 2 deletions src/operators/laplacian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ end

# convienience constructors
function laplacian(
data::AbstractVector{D},
data::AbstractVector,
basis::B=PHS(3; poly_deg=2);
k::T=autoselect_k(data, basis),
adjl=find_neighbors(data, k),
) where {D<:AbstractArray,T<:Int,B<:AbstractRadialBasis}
) where {T<:Int,B<:AbstractRadialBasis}
= Laplacian(∇²)
return RadialBasisOperator(ℒ, data, basis; k=k, adjl=adjl)
end
Expand Down
9 changes: 4 additions & 5 deletions src/operators/monomial/monomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ end
function ∇²(m::MonomialBasis{Dim,Deg}) where {Dim,Deg}
∂² = ntuple(dim -> (m, 2, dim), Dim)
function basis!(b, x)
cache = ones(size(b))
b .= 0
cache = ones(eltype(x), size(b))
b .= zero(eltype(x))
for ∂²! in ∂²
# use mapreduce here instead?
∂²!(cache, x)
b .+= cache
end
Expand All @@ -34,8 +33,8 @@ end

function build_monomial_basis(ids::Vector{Vector{Vector{T}}}, c::Vector{T}) where {T<:Int}
function basis!(db::AbstractVector{B}, x::AbstractVector) where {B}
db .= 1
# TODO flatten loop - why does it allocate here
db .= one(eltype(x))
# TODO optimize - allocations
@views @inbounds for i in eachindex(ids), j in eachindex(ids[i])
db[ids[i][j]] *= x[i]
end
Expand Down
45 changes: 45 additions & 0 deletions src/operators/operator_algebra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
for op in (:+, :-)
@eval function Base.$op(a::ℒRadialBasisFunction, b::ℒRadialBasisFunction)
additive_ℒRBF(x, xᵢ) = Base.$op(a(x, xᵢ), b(x, xᵢ))
return ℒRadialBasisFunction(additive_ℒRBF)
end
end

for op in (:+, :-)
@eval function Base.$op(
a::ℒMonomialBasis{Dim,Deg}, b::ℒMonomialBasis{Dim,Deg}
) where {Dim,Deg}
function additive_ℒMon(m, x)
m .= Base.$op.(a(x), b(x))
return nothing
end
return ℒMonomialBasis(Dim, Deg, additive_ℒMon)
end
end

for op in (:+, :-)
@eval function Base.$op(op1::RadialBasisOperator, op2::RadialBasisOperator)
_check_compatible(op1, op2)
k = _update_stencil(op1, op2)
(x) = Base.$op(op1.(x), op2.(x))
return RadialBasisOperator(ℒ, op1.data, op1.basis; k=k, adjl=op1.adjl)
end
end

function _check_compatible(op1::RadialBasisOperator, op2::RadialBasisOperator)
if !all(op1.data .≈ op2.data)
throw(
ArgumentError("Can not add operators that were not built with the same data.")
)
end
if !all(op1.adjl .≈ op2.adjl)
throw(ArgumentError("Can not add operators that do not have the same stencils."))
end
end

function _update_stencil(op1::RadialBasisOperator, op2::RadialBasisOperator)
k1 = length(first((op1.adjl)))
k2 = length(first((op2.adjl)))
k = k1 > k2 ? k1 : k2
return k
end
50 changes: 0 additions & 50 deletions src/operators/operator_combinations.jl

This file was deleted.

Loading

0 comments on commit d19d5d6

Please sign in to comment.