Skip to content
This repository has been archived by the owner on Jul 19, 2023. It is now read-only.

Commit

Permalink
Merge pull request #505 from eschnett/eschnett/sparse
Browse files Browse the repository at this point in the history
Specialize `sparse` for various operators
  • Loading branch information
ChrisRackauckas authored Jan 13, 2022
2 parents cc02902 + 67f9735 commit 4d26700
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 5 deletions.
16 changes: 14 additions & 2 deletions src/composite_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@
# operator types are lazy and maintain the structure used to build them.


# Define a helper function `sparse1` that handles
# `DiffEqArrayOperator` and `DiffEqScaledOperator`.
# We should define `sparse` for these types in `SciMLBase` instead,
# but that package doesn't know anything about sparse arrays yet, so
# we'll introduce a temporary work-around here.
sparse1(A) = sparse(A)
sparse1(A::DiffEqArrayOperator) = sparse1(A.A)
sparse1(A::DiffEqScaledOperator) = A.coeff * sparse1(A.op)


# Linear Combination
struct DiffEqOperatorCombination{T,O<:Tuple{Vararg{AbstractDiffEqLinearOperator{T}}},
C<:AbstractVector{T}} <: AbstractDiffEqCompositeOperator{T}
Expand All @@ -13,7 +23,7 @@ struct DiffEqOperatorCombination{T,O<:Tuple{Vararg{AbstractDiffEqLinearOperator{
for i in 2:length(ops)
@assert size(ops[i]) == size(ops[1]) "Operators must be of the same size to be combined! Mismatch between $(ops[i]) and $(ops[i-1]), which are operators $i and $(i-1) respectively"
end
if cache == nothing
if cache === nothing
cache = zeros(T, size(ops[1], 1))
end
new{T,typeof(ops),typeof(cache)}(ops, cache)
Expand All @@ -36,6 +46,7 @@ getops(L::DiffEqOperatorCombination) = L.ops
Matrix(L::DiffEqOperatorCombination) = sum(Matrix, L.ops)
convert(::Type{AbstractMatrix}, L::DiffEqOperatorCombination) =
sum(op -> convert(AbstractMatrix, op), L.ops)
SparseArrays.sparse(L::DiffEqOperatorCombination) = sum(sparse1, L.ops)

size(L::DiffEqOperatorCombination, args...) = size(L.ops[1], args...)
getindex(L::DiffEqOperatorCombination, i::Int) = sum(op -> op[i], L.ops)
Expand Down Expand Up @@ -64,7 +75,7 @@ struct DiffEqOperatorComposition{T,O<:Tuple{Vararg{AbstractDiffEqLinearOperator{
@assert size(ops[i-1], 1) == size(ops[i], 2) "Operations do not have compatible sizes! Mismatch between $(ops[i]) and $(ops[i-1]), which are operators $i and $(i-1) respectively."
end

if caches == nothing
if caches === nothing
# Construct a list of caches to be used by mul! and ldiv!
caches = []
for op in ops[1:end-1]
Expand All @@ -89,6 +100,7 @@ getops(L::DiffEqOperatorComposition) = L.ops
Matrix(L::DiffEqOperatorComposition) = prod(Matrix, reverse(L.ops))
convert(::Type{AbstractMatrix}, L::DiffEqOperatorComposition) =
prod(op -> convert(AbstractMatrix, op), reverse(L.ops))
SparseArrays.sparse(L::DiffEqOperatorComposition) = prod(sparse1, reverse(L.ops))

size(L::DiffEqOperatorComposition) = (size(L.ops[end], 1), size(L.ops[1], 2))
size(L::DiffEqOperatorComposition, m::Integer) = size(L)[m]
Expand Down
61 changes: 59 additions & 2 deletions src/derivative_operators/concretization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,68 @@ end
LinearAlgebra.Array(A::DerivativeOperator{T}, N::Int=A.len) where T =
copyto!(zeros(T, N, N+2), A, N)

SparseArrays.SparseMatrixCSC(A::DerivativeOperator{T}, N::Int=A.len) where T =
copyto!(spzeros(T, N, N+2), A, N)
function SparseArrays.SparseMatrixCSC(A::DerivativeOperator{T}, N::Int=A.len) where T
bl = A.boundary_point_count
stencil_length = A.stencil_length
stencil_pivot = use_winding(A) ? (1 + stencil_length%2) : div(stencil_length,2)
bstl = A.boundary_stencil_length

coeff = A.coefficients
get_coeff = if coeff isa AbstractVector
i -> coeff[i]
elseif coeff isa Number
i -> coeff
else
i -> true
end

Is = Int[]
Js = Int[]
Vs = T[]

nvalues = 2*bl * bstl + (N - 2*bl) * stencil_length
sizehint!(Is, nvalues)
sizehint!(Js, nvalues)
sizehint!(Vs, nvalues)

for i in 1:bl
cur_coeff = get_coeff(i)
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(A.low_boundary_coefs[i]) : A.low_boundary_coefs[i]
append!(Is, ((i for j in 1:bstl)...))
append!(Js, 1:bstl)
append!(Vs, cur_coeff * cur_stencil)
end

for i in bl+1:N-bl
cur_coeff = get_coeff(i)
stencil = eltype(A.stencil_coefs) <: AbstractVector ? A.stencil_coefs[i-bl] : A.stencil_coefs
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(stencil) : stencil
append!(Is, ((i for j in 1:stencil_length)...))
append!(Js, i+1-stencil_pivot:i-stencil_pivot+stencil_length)
append!(Vs, cur_coeff * cur_stencil)
end

for i in N-bl+1:N
cur_coeff = get_coeff(i)
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(A.high_boundary_coefs[i-N+bl]) : A.high_boundary_coefs[i-N+bl]
append!(Is, ((i for j in N-bstl+3:N+2)...))
append!(Js, N-bstl+3:N+2)
append!(Vs, cur_coeff * cur_stencil)
end

# ensure efficient allocation
@assert length(Is) == nvalues
@assert length(Js) == nvalues
@assert length(Vs) == nvalues

return sparse(Is, Js, Vs, N, N+2)
end

SparseArrays.sparse(A::DerivativeOperator{T}, N::Int=A.len) where T = SparseMatrixCSC(A,N)

Base.copyto!(L::AbstractSparseArray{T}, A::DerivativeOperator{T}, N::Int) where T =
copyto!(L, sparse(A))

function BandedMatrices.BandedMatrix(A::DerivativeOperator{T}, N::Int=A.len) where T
stencil_length = A.stencil_length
bstl = A.boundary_stencil_length
Expand Down
4 changes: 3 additions & 1 deletion test/DerivativeOperators/composite_operators_interface.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Test, LinearAlgebra, Random, DiffEqOperators
using Test, LinearAlgebra, Random, SparseArrays, DiffEqOperators
using DiffEqBase
using DiffEqBase: isconstant
using DiffEqOperators: DiffEqScaledOperator, DiffEqOperatorCombination, DiffEqOperatorComposition
Expand All @@ -22,6 +22,8 @@ using DiffEqOperators: DiffEqScaledOperator, DiffEqOperatorCombination, DiffEqOp
@test opnorm(L) opnorm(Lfull)
@test size(L) == size(Lfull)
@test L[1,2] Lfull[1,2]
Lsparse = sparse(L)
@test Lsparse == Lfull
u = [1.0, 2.0]; du = zeros(2)
@test L * u Lfull * u
mul!(du, L, u); @test du Lfull * u
Expand Down
2 changes: 2 additions & 0 deletions test/DerivativeOperators/generic_operator_validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ for dor in 1:4, aor in 2:2:6
Dr = CenteredDifference(dor,aor,dx[1],length(x)-2)
Dir = CenteredDifference(dor,aor,dx,length(x)-2)

@test sparse(Dr)==Array(Dr)

@test sparse(Dr)sparse(Dir)
@test Array(Dr)Array(Dir)

Expand Down

0 comments on commit 4d26700

Please sign in to comment.