Skip to content

Commit

Permalink
Fix ==/sum/issymmetric for (Sym)Tridiagonal with non-number elt…
Browse files Browse the repository at this point in the history
…ype (#43066)
  • Loading branch information
N5N3 authored Nov 16, 2021
1 parent 9affe2f commit bc8337a
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 24 deletions.
51 changes: 27 additions & 24 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,17 +125,13 @@ AbstractMatrix{T}(S::SymTridiagonal) where {T} =
function Matrix{T}(M::SymTridiagonal) where T
n = size(M, 1)
Mf = zeros(T, n, n)
if n == 0
return Mf
end
@inbounds begin
@simd for i = 1:n-1
Mf[i,i] = M.dv[i]
Mf[i+1,i] = M.ev[i]
Mf[i,i+1] = M.ev[i]
end
Mf[n,n] = M.dv[n]
n == 0 && return Mf
@inbounds for i = 1:n-1
Mf[i,i] = symmetric(M.dv[i], :U)
Mf[i+1,i] = transpose(M.ev[i])
Mf[i,i+1] = M.ev[i]
end
Mf[n,n] = symmetric(M.dv[n], :U)
return Mf
end
Matrix(M::SymTridiagonal{T}) where {T} = Matrix{T}(M)
Expand Down Expand Up @@ -612,8 +608,8 @@ transpose(S::Tridiagonal{<:Number}) = Tridiagonal(S.du, S.d, S.dl)
Base.copy(aS::Adjoint{<:Any,<:Tridiagonal}) = (S = aS.parent; Tridiagonal(map(x -> copy.(adjoint.(x)), (S.du, S.d, S.dl))...))
Base.copy(tS::Transpose{<:Any,<:Tridiagonal}) = (S = tS.parent; Tridiagonal(map(x -> copy.(transpose.(x)), (S.du, S.d, S.dl))...))

ishermitian(S::Tridiagonal) = isreal(S.d) && S.du == adjoint.(S.dl)
issymmetric(S::Tridiagonal) = S.du == S.dl
ishermitian(S::Tridiagonal) = all(ishermitian, S.d) && all(Iterators.map((x, y) -> x == y', S.du, S.dl))
issymmetric(S::Tridiagonal) = all(issymmetric, S.d) && all(Iterators.map((x, y) -> x == transpose(y), S.du, S.dl))

\(A::Adjoint{<:Any,<:Tridiagonal}, B::Adjoint{<:Any,<:StridedVecOrMat}) = copy(A) \ B

Expand Down Expand Up @@ -744,8 +740,12 @@ end
\(B::Number, A::Tridiagonal) = Tridiagonal(B\A.dl, B\A.d, B\A.du)

==(A::Tridiagonal, B::Tridiagonal) = (A.dl==B.dl) && (A.d==B.d) && (A.du==B.du)
==(A::Tridiagonal, B::SymTridiagonal) = (A.dl==A.du==B.ev) && (A.d==B.dv)
==(A::SymTridiagonal, B::Tridiagonal) = (B.dl==B.du==A.ev) && (B.d==A.dv)
function ==(A::Tridiagonal, B::SymTridiagonal)
iseq = all(Iterators.map((x, y) -> x == transpose(y), A.du, A.dl))
iseq = iseq && A.du == _evview(B)
iseq && all(Iterators.map((x, y) -> x == symmetric(y, :U), A.d, B.dv))
end
==(A::SymTridiagonal, B::Tridiagonal) = B == A

det(A::Tridiagonal) = det_usmani(A.dl, A.d, A.du)

Expand All @@ -760,7 +760,10 @@ function SymTridiagonal{T}(M::Tridiagonal) where T
end

Base._sum(A::Tridiagonal, ::Colon) = sum(A.d) + sum(A.dl) + sum(A.du)
Base._sum(A::SymTridiagonal, ::Colon) = sum(A.dv) + 2sum(A.ev)
function Base._sum(A::SymTridiagonal, ::Colon)
se = sum(_evview(A))
symmetric(sum(A.dv), :U) + se + transpose(se)
end

function Base._sum(A::Tridiagonal, dims::Integer)
res = Base.reducedim_initarray(A, dims, zero(eltype(A)))
Expand Down Expand Up @@ -807,24 +810,24 @@ function Base._sum(A::SymTridiagonal, dims::Integer)
end
@inbounds begin
if dims == 1
res[1] = A.ev[1] + A.dv[1]
res[1] = transpose(A.ev[1]) + symmetric(A.dv[1], :U)
for i = 2:n-1
res[i] = A.ev[i] + A.dv[i] + A.ev[i-1]
res[i] = transpose(A.ev[i]) + symmetric(A.dv[i], :U) + A.ev[i-1]
end
res[n] = A.dv[n] + A.ev[n-1]
res[n] = symmetric(A.dv[n], :U) + A.ev[n-1]
elseif dims == 2
res[1] = A.dv[1] + A.ev[1]
res[1] = symmetric(A.dv[1], :U) + A.ev[1]
for i = 2:n-1
res[i] = A.ev[i-1] + A.dv[i] + A.ev[i]
res[i] = transpose(A.ev[i-1]) + symmetric(A.dv[i], :U) + A.ev[i]
end
res[n] = A.ev[n-1] + A.dv[n]
res[n] = transpose(A.ev[n-1]) + symmetric(A.dv[n], :U)
elseif dims >= 3
for i = 1:n-1
res[i,i+1] = A.ev[i]
res[i,i] = A.dv[i]
res[i+1,i] = A.ev[i]
res[i,i] = symmetric(A.dv[i], :U)
res[i+1,i] = transpose(A.ev[i])
end
res[n,n] = A.dv[n]
res[n,n] = symmetric(A.dv[n], :U)
end
end
res
Expand Down
31 changes: 31 additions & 0 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -695,4 +695,35 @@ end
end
end

isdefined(Main, :SizedArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SizedArrays.jl"))
using .Main.SizedArrays
@testset "non-number eltype" begin
@testset "sum for SymTridiagonal" begin
dv = [SizedArray{(2,2)}(rand(1:2048,2,2)) for i in 1:10]
ev = [SizedArray{(2,2)}(rand(1:2048,2,2)) for i in 1:10]
S = SymTridiagonal(dv, ev)
Sdense = Matrix(S)
@test Sdense == collect(S)
@test sum(S) == sum(Sdense)
@test sum(S, dims = 1) == sum(Sdense, dims = 1)
@test sum(S, dims = 2) == sum(Sdense, dims = 2)
end
@testset "issymmetric/ishermitian for Tridiagonal" begin
@test !issymmetric(Tridiagonal([[1 2;3 4]], [[1 2;2 3], [1 2;2 3]], [[1 2;3 4]]))
@test !issymmetric(Tridiagonal([[1 3;2 4]], [[1 2;3 4], [1 2;3 4]], [[1 2;3 4]]))
@test issymmetric(Tridiagonal([[1 3;2 4]], [[1 2;2 3], [1 2;2 3]], [[1 2;3 4]]))

@test ishermitian(Tridiagonal([[1 3;2 4].+im], [[1 2;2 3].+0im, [1 2;2 3].+0im], [[1 2;3 4].-im]))
@test !ishermitian(Tridiagonal([[1 3;2 4].+im], [[1 2;2 3].+0im, [1 2;2 3].+0im], [[1 2;3 4].+im]))
@test !ishermitian(Tridiagonal([[1 3;2 4].+im], [[1 2;2 3].+im, [1 2;2 3].+0im], [[1 2;3 4].-im]))
end
@testset "== between Tridiagonal and SymTridiagonal" begin
dv = [SizedArray{(2,2)}([1 2;3 4]) for i in 1:4]
ev = [SizedArray{(2,2)}([3 4;1 2]) for i in 1:4]
S = SymTridiagonal(dv, ev)
Sdense = Matrix(S)
@test S == Tridiagonal(diag(Sdense, -1), diag(Sdense), diag(Sdense, 1)) == S
@test S !== Tridiagonal(diag(Sdense, 1), diag(Sdense), diag(Sdense, 1)) !== S
end
end
end # module TestTridiagonal
40 changes: 40 additions & 0 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# SizedArrays

# This test file defines an array wrapper with statical size. It can be used to
# test the action of LinearAlgebra with non-number eltype.

module SizedArrays

import Base: +, *, ==

export SizedArray

struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N}
data::A
function SizedArray{SZ}(data::AbstractArray{T,N}) where {SZ,T,N}
SZ == size(data) || throw(ArgumentError("size mismatch!"))
new{SZ,T,N,typeof(data)}(data)
end
function SizedArray{SZ,T,N,A}(data::AbstractArray{T,N}) where {SZ,T,N,A}
SZ == size(data) || throw(ArgumentError("size mismatch!"))
new{SZ,T,N,A}(A(data))
end
end
Base.convert(::Type{SizedArray{SZ,T,N,A}}, data::AbstractArray) where {SZ,T,N,A} = SizedArray{SZ,T,N,A}(data)

# Minimal AbstractArray interface
Base.size(a::SizedArray) = size(typeof(a))
Base.size(::Type{<:SizedArray{SZ}}) where {SZ} = SZ
Base.getindex(A::SizedArray, i...) = getindex(A.data, i...)
Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T)))
+(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data)
==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data
function *(S1::SizedArray, S2::SizedArray)
0 < ndims(S1) < 3 && 0 < ndims(S2) < 3 && size(S1, 2) == size(S2, 1) || throw(ArgumentError("size mismatch!"))
data = S1.data * S2.data
SZ = ndims(data) == 1 ? (size(S1, 1), ) : (size(S1, 1), size(S2, 2))
SizedArray{SZ}(data)
end
end

0 comments on commit bc8337a

Please sign in to comment.