Skip to content

Commit

Permalink
Merge pull request #45 from tpapp/tp/remove-unpack
Browse files Browse the repository at this point in the history
don't use unpack
  • Loading branch information
tpapp authored Nov 2, 2023
2 parents 1013e3b + 4ba2b23 commit 5835827
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 42 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@ version = "0.13.0"
[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
ArgCheck = "1, 2"
DocStringExtensions = "0.8, 0.9"
SimpleUnPack = "1"
StaticArrays = "1"
julia = "1.9"

Expand Down
1 change: 0 additions & 1 deletion src/SpectralKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module SpectralKit
using ArgCheck: @argcheck
using DocStringExtensions: FUNCTIONNAME, SIGNATURES, TYPEDEF
using StaticArrays: MVector, SVector, sacollect
using SimpleUnPack: @unpack

include("utilities.jl")
include("derivatives.jl")
Expand Down
12 changes: 6 additions & 6 deletions src/chebyshev.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ end
@inline domain(::Chebyshev) = PM1()

function Base.show(io::IO, chebyshev::Chebyshev)
@unpack grid_kind, N = chebyshev
(; grid_kind, N) = chebyshev
print(io, "Chebyshev polynomials (1st kind), ", grid_kind, ", dimension: ", N)
end

Expand All @@ -63,7 +63,7 @@ function Base.iterate(itr::ChebyshevIterator{T}) where T
end

function Base.iterate(itr::ChebyshevIterator{T}, (i, fp, fpp)) where T
@unpack x, N = itr
(; x, N) = itr
i > N && return nothing
f = _sub(_mul(2, x, fp), fpp)
f::T, (i + 1, f, fp)
Expand All @@ -86,13 +86,13 @@ broadened as required. Methods are type stable.
[`grid`](@ref), which is part of the API, this function isn't.
"""
function gridpoint(::Type{T}, basis::Chebyshev{InteriorGrid}, i::Integer) where {T <: Real}
@unpack N = basis
(; N) = basis
@argcheck 1 i N # FIXME use boundscheck
sinpi((N - 2 * i + 1) / T(2 * N))::T # use formula from Xu (2016)
end

function gridpoint(::Type{T}, basis::Chebyshev{EndpointGrid}, i::Integer) where {T <: Real}
@unpack N = basis
(; N)= basis
@argcheck 1 i N # FIXME use boundscheck
if N == 1
cospi(1/T(2))::T # 0.0 as a fallback, even though it does not have endpoints
Expand All @@ -102,7 +102,7 @@ function gridpoint(::Type{T}, basis::Chebyshev{EndpointGrid}, i::Integer) where
end

function gridpoint(::Type{T}, basis::Chebyshev{InteriorGrid2}, i::Integer) where {T <: Real}
@unpack N = basis
(; N)= basis
@argcheck 1 i N # FIXME use boundscheck
cospi(((N - i + 1) ./ T(N + 1)))::T
end
Expand All @@ -116,7 +116,7 @@ Base.eltype(::Type{<:ChebyshevGridIterator{T}}) where {T} = T
Base.length(itr::ChebyshevGridIterator) = dimension(itr.basis)

function Base.iterate(itr::ChebyshevGridIterator{T}, i = 1) where {T}
@unpack basis = itr
(; basis) = itr
if i dimension(basis)
gridpoint(T, basis, i), i + 1
else
Expand Down
2 changes: 1 addition & 1 deletion src/generic_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ struct TransformedLinearCombination{B,C,T}
end

function (l::TransformedLinearCombination)(x)
@unpack basis, θ, transformation = l
(; basis, θ, transformation) = l
_linear_combination(basis, θ, transform_to(domain(basis), transformation, x), false)
end

Expand Down
21 changes: 10 additions & 11 deletions src/smolyak_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ struct SmolyakIndices{N,H,B,M,Mp1}
end

function Base.show(io::IO, smolyak_indices::SmolyakIndices{N,H,B,M}) where {N,H,B,M}
@unpack len = smolyak_indices
(; len) = smolyak_indices
print(io, "Smolyak indexing, ∑bᵢ ≤ $(B), all bᵢ ≤ $(M), dimension $(len)")
end

Expand Down Expand Up @@ -161,7 +161,7 @@ struct SmolyakBasis{I<:SmolyakIndices,U<:UnivariateBasis} <: MultivariateBasis
end

function Base.show(io::IO, smolyak_basis::SmolyakBasis{<:SmolyakIndices{N}}) where N
@unpack smolyak_indices, univariate_parent = smolyak_basis
(; smolyak_indices, univariate_parent) = smolyak_basis
print(io, "Sparse multivariate basis on ℝ", SuperScript(N), "\n ", smolyak_indices,
"\n using ", univariate_parent)
end
Expand Down Expand Up @@ -219,8 +219,7 @@ end
end

function domain(smolyak_basis::SmolyakBasis{<:SmolyakIndices{N}}) where {N}
@unpack univariate_parent = smolyak_basis
D = domain(univariate_parent)
D = domain(smolyak_basis.univariate_parent)
coordinate_domains(Val(N), D)
end

Expand All @@ -240,17 +239,17 @@ end
function basis_at(smolyak_basis::SmolyakBasis{<:SmolyakIndices{N}},
x::Union{Tuple,AbstractVector}) where {N}
@argcheck length(x) == N
@unpack smolyak_indices= smolyak_basis
SmolyakProduct(smolyak_indices, _univariate_bases_at(smolyak_basis, NTuple{N}(x)),
SmolyakProduct(smolyak_basis.smolyak_indices,
_univariate_bases_at(smolyak_basis, NTuple{N}(x)),
nothing)
end

function basis_at(smolyak_basis::SmolyakBasis{<:SmolyakIndices{N}},
Lx::∂InputLifted) where {N}
(; ∂specification, lifted_x) = Lx
@argcheck length(lifted_x) == N
@unpack smolyak_indices = smolyak_basis
SmolyakProduct(smolyak_indices, _univariate_bases_at(smolyak_basis, lifted_x),
SmolyakProduct(smolyak_basis.smolyak_indices,
_univariate_bases_at(smolyak_basis, lifted_x),
∂specification)
end

Expand All @@ -267,14 +266,14 @@ Base.length(itr::SmolyakGridIterator) = length(itr.smolyak_indices)

function grid(::Type{T},
smolyak_basis::SmolyakBasis{<:SmolyakIndices{N,H}}) where {T<:Real,N,H}
@unpack smolyak_indices, univariate_parent = smolyak_basis
(; smolyak_indices, univariate_parent) = smolyak_basis
sources = sacollect(SVector{H}, gridpoint(T, univariate_parent, i)
for i in SmolyakGridShuffle(univariate_parent.grid_kind, H))
SmolyakGridIterator{NTuple{N,T},typeof(smolyak_indices),typeof(sources)}(smolyak_indices, sources)
end

function Base.iterate(itr::SmolyakGridIterator, state...)
@unpack smolyak_indices, sources = itr
(; smolyak_indices, sources) = itr
result = iterate(smolyak_indices, state...)
result nothing && return nothing
ι, state′ = result
Expand Down Expand Up @@ -328,7 +327,7 @@ Base.eltype(itr::PaddingIterator) = eltype(itr.θ1)

function Base.iterate(itr::PaddingIterator, state = (firstindex(itr.θ1),
iterate(itr.itr1)...))
@unpack θ1, itr1, itr2 = itr
(; θ1, itr1, itr2) = itr
i, ι1, state1, state2... = state
res2 = iterate(itr2, state2...)
res2 nothing && return nothing
Expand Down
15 changes: 6 additions & 9 deletions src/smolyak_traversal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,12 @@ Length of each block `b`.
end

function Base.iterate::SmolyakGridShuffle{EndpointGrid})
@unpack len = ι
i = (len + 1) ÷ 2
i =.len + 1) ÷ 2
i, (0, 0) # step = 0 is special-cased
end

function Base.iterate::SmolyakGridShuffle{EndpointGrid}, (i, step))
@unpack len = ι
(; len) = ι
i == 0 && return len > 1 ? (1, (1, len - 1)) : nothing
i′ = i + step
if i′ len
Expand All @@ -83,15 +82,15 @@ end
end

function Base.iterate::SmolyakGridShuffle{InteriorGrid})
@unpack len = ι
(; len) = ι
i0 = (len + 1) ÷ 2 # first index at this level
Δ = len # basis for step size
a = 2 # alternating as 2Δa and Δa
i0, (i0, i0, Δ, a)
end

function Base.iterate::SmolyakGridShuffle{InteriorGrid}, (i, i0, Δ, a))
@unpack len = ι
(; len) = ι
i′ = i + a * Δ
if i′ len
i′, (i′, i0, Δ, 3 - a)
Expand All @@ -115,15 +114,13 @@ end
@inline nesting_block_length(::Type{Chebyshev}, ::InteriorGrid2, b::Int) = 1 << b

function Base.iterate::SmolyakGridShuffle{InteriorGrid2})
@unpack len = ι
i = (len + 1) ÷ 2
i =.len + 1) ÷ 2
i, (i, 2 * i)
end

function Base.iterate::SmolyakGridShuffle{InteriorGrid2}, (i, step))
@unpack len = ι
i′ = i + step
if i′ len
if i′ ι.len
i′, (i′, step)
else
step′ = step ÷ 2
Expand Down
24 changes: 12 additions & 12 deletions src/transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ end
coordinate_transformations(transformations...) = coordinate_transformations(transformations)

function transform_to(domain::CoordinateDomains, ct::CoordinateTransformations, x::Tuple)
@unpack domains = domain
@unpack transformations = ct
(; domains) = domain
(; transformations) = ct
@argcheck length(domains) == length(transformations) == length(x)
map((d, t, x) -> transform_to(d, t, x), domains, transformations, x)
end
Expand All @@ -125,8 +125,8 @@ function transform_to(domain::CoordinateDomains, ct::CoordinateTransformations,
end

function transform_from(domain::CoordinateDomains, ct::CoordinateTransformations, x::Tuple)
@unpack domains = domain
@unpack transformations = ct
(; domains) = domain
(; transformations) = ct
@argcheck length(domains) == length(transformations) == length(x)
map((d, t, x) -> transform_from(d, t, x), domains, transformations, x)
end
Expand Down Expand Up @@ -160,7 +160,7 @@ struct BoundedLinear{T <: Real} <: AbstractUnivariateTransformation
end

function Base.show(io::IO, transformation::BoundedLinear)
@unpack m, s = transformation
(; m, s) = transformation
print(io, "(", m - s, ",", m + s, ") ↔ domain [linear transformation]")
end

Expand All @@ -174,12 +174,12 @@ Transform the domain to `y ∈ (a, b)`, using ``y = x ⋅ s + m``.
BoundedLinear(a::Real, b::Real) = BoundedLinear(promote(a, b)...)

function transform_from(::PM1, t::BoundedLinear, x::Scalar)
@unpack m, s = t
(; m, s) = t
x * s + m
end

function transform_to(::PM1, t::BoundedLinear, y::Real)
@unpack m, s = t
(; m, s) = t
(y - m) / s
end

Expand All @@ -193,7 +193,7 @@ function transform_to(domain::PM1, t::BoundedLinear, y::Derivatives{N}) where N
end

function domain(t::BoundedLinear)
@unpack m, s = t
(; m, s) = t
UnivariateDomain(m - s, m + s)
end

Expand All @@ -214,7 +214,7 @@ struct SemiInfRational{T<:Real} <: AbstractUnivariateTransformation
end

function Base.show(io::IO, transformation::SemiInfRational)
@unpack A, L = transformation
(; A, L) = transformation
if L > 0
D = "($A,∞)"
else
Expand Down Expand Up @@ -265,7 +265,7 @@ function transform_to(domain::PM1, t::SemiInfRational, y::Derivatives{N}) where
end

function domain(t::SemiInfRational)
@unpack L, A = t
(; L, A) = t
= oftype(A, Inf)
L > 0 ? UnivariateDomain(A, ∞) : UnivariateDomain(-∞, A)
end
Expand All @@ -287,7 +287,7 @@ struct InfRational{T <: Real} <: AbstractUnivariateTransformation
end

function Base.show(io::IO, transformation::InfRational)
@unpack A, L = transformation
(; A, L) = transformation
print(io, "(-∞,∞) ↔ domain [rational transformation with center ", A, ", scale ", L, "]")
end

Expand All @@ -306,7 +306,7 @@ InfRational(A::Real, L::Real) = InfRational(promote(A, L)...)
transform_from(::PM1, T::InfRational, x::Real) = T.A + T.L * x / (1 - abs2(x))

function transform_to(::PM1, t::InfRational, y::Real)
@unpack A, L = t
(; A, L) = t
z = y - A
x = z / hypot(z, L)
if isinf(y)
Expand Down

0 comments on commit 5835827

Please sign in to comment.