Skip to content

Commit

Permalink
Optimizations for diagonal HermOrSym (#48189)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored Jan 30, 2023
1 parent 45b7e7a commit 6ab660d
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 10 deletions.
5 changes: 5 additions & 0 deletions stdlib/LinearAlgebra/src/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ real(H::UpperHessenberg{<:Real}) = H
real(H::UpperHessenberg{<:Complex}) = UpperHessenberg(triu!(real(H.data),-1))
imag(H::UpperHessenberg) = UpperHessenberg(triu!(imag(H.data),-1))

function istriu(A::UpperHessenberg, k::Integer=0)
k <= -1 && return true
return _istriu(A, k)
end

function Matrix{T}(H::UpperHessenberg) where T
m,n = size(H)
return triu!(copyto!(Matrix{T}(undef, m, n), H.data), -1)
Expand Down
4 changes: 4 additions & 0 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,10 @@ end
zero(D::Diagonal) = Diagonal(zero.(D.diag))
oneunit(D::Diagonal) = Diagonal(oneunit.(D.diag))

isdiag(A::HermOrSym{<:Any,<:Diagonal}) = isdiag(parent(A))
dot(x::AbstractVector, A::RealHermSymComplexSym{<:Real,<:Diagonal}, y::AbstractVector) =
dot(x, A.data, y)

# equals and approx equals methods for structured matrices
# SymTridiagonal == Tridiagonal is already defined in tridiag.jl

Expand Down
15 changes: 11 additions & 4 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ end
diag(A::Symmetric) = symmetric.(diag(parent(A)), sym_uplo(A.uplo))
diag(A::Hermitian) = hermitian.(diag(parent(A)), sym_uplo(A.uplo))

isdiag(A::HermOrSym) = isdiag(A.uplo == 'U' ? UpperTriangular(A.data) : LowerTriangular(A.data))

# For A<:Union{Symmetric,Hermitian}, similar(A[, neweltype]) should yield a matrix with the same
# symmetry type, uplo flag, and underlying storage type as A. The following methods cover these cases.
similar(A::Symmetric, ::Type{T}) where {T} = Symmetric(similar(parent(A), T), ifelse(A.uplo == 'U', :U, :L))
Expand Down Expand Up @@ -316,6 +318,7 @@ function fillstored!(A::HermOrSym{T}, x) where T
return A
end

Base.isreal(A::HermOrSym{<:Real}) = true
function Base.isreal(A::HermOrSym)
n = size(A, 1)
@inbounds if A.uplo == 'U'
Expand Down Expand Up @@ -578,9 +581,11 @@ end

function dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector)
require_one_based_indexing(x, y)
(length(x) == length(y) == size(A, 1)) || throw(DimensionMismatch())
n = length(x)
(n == length(y) == size(A, 1)) || throw(DimensionMismatch())
data = A.data
r = zero(eltype(x)) * zero(eltype(A)) * zero(eltype(y))
r = dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
iszero(n) && return r
if A.uplo == 'U'
@inbounds for j = 1:length(y)
r += dot(x[j], real(data[j,j]), y[j])
Expand Down Expand Up @@ -612,7 +617,9 @@ end
factorize(A::HermOrSym) = _factorize(A)
function _factorize(A::HermOrSym{T}; check::Bool=true) where T
TT = typeof(sqrt(oneunit(T)))
if TT <: BlasFloat
if isdiag(A)
return Diagonal(A)
elseif TT <: BlasFloat
return bunchkaufman(A; check=check)
else # fallback
return lu(A; check=check)
Expand All @@ -626,7 +633,7 @@ det(A::Symmetric) = det(_factorize(A; check=false))
\(A::HermOrSym, B::AbstractVector) = \(factorize(A), B)
# Bunch-Kaufman solves can not utilize BLAS-3 for multiple right hand sides
# so using LU is faster for AbstractMatrix right hand side
\(A::HermOrSym, B::AbstractMatrix) = \(lu(A), B)
\(A::HermOrSym, B::AbstractMatrix) = \(isdiag(A) ? Diagonal(A) : lu(A), B)

function _inv(A::HermOrSym)
n = checksquare(A)
Expand Down
8 changes: 4 additions & 4 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,10 @@ function istriu(A::Union{UpperTriangular,UnitUpperTriangular}, k::Integer=0)
k <= 0 && return true
return _istriu(A, k)
end
istril(A::Adjoint) = istriu(A.parent)
istril(A::Transpose) = istriu(A.parent)
istriu(A::Adjoint) = istril(A.parent)
istriu(A::Transpose) = istril(A.parent)
istril(A::Adjoint, k::Integer=0) = istriu(A.parent, -k)
istril(A::Transpose, k::Integer=0) = istriu(A.parent, -k)
istriu(A::Adjoint, k::Integer=0) = istril(A.parent, -k)
istriu(A::Transpose, k::Integer=0) = istril(A.parent, -k)

function tril!(A::UpperTriangular, k::Integer=0)
n = size(A,1)
Expand Down
5 changes: 5 additions & 0 deletions stdlib/LinearAlgebra/test/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ let n = 10
A = Areal
H = UpperHessenberg(A)
AH = triu(A,-1)
for k in -2:2
@test istril(H, k) == istril(AH, k)
@test istriu(H, k) == istriu(AH, k)
@test (k <= -1 ? istriu(H, k) : !istriu(H, k))
end
@test UpperHessenberg(H) === H
@test parent(H) === A
@test Matrix(H) == Array(H) == H == AH
Expand Down
18 changes: 16 additions & 2 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,16 @@ end
end
@testset "diag" begin
D = Diagonal(x)
@test diag(Symmetric(D, :U))::Vector == x
@test diag(Hermitian(D, :U))::Vector == real(x)
DM = Matrix(D)
B = diagm(-1 => x, 1 => x)
for uplo in (:U, :L)
@test diag(Symmetric(D, uplo))::Vector == x
@test diag(Hermitian(D, uplo))::Vector == real(x)
@test isdiag(Symmetric(DM, uplo))
@test isdiag(Hermitian(DM, uplo))
@test !isdiag(Symmetric(B, uplo))
@test !isdiag(Hermitian(B, uplo))
end
end
@testset "similar" begin
@test isa(similar(Symmetric(asym)), Symmetric{eltya})
Expand Down Expand Up @@ -394,13 +402,19 @@ end
@test Hermitian(aherm)\b aherm\b
@test Symmetric(asym)\x asym\x
@test Symmetric(asym)\b asym\b
@test Hermitian(Diagonal(aherm))\x Diagonal(aherm)\x
@test Hermitian(Matrix(Diagonal(aherm)))\b Diagonal(aherm)\b
@test Symmetric(Diagonal(asym))\x Diagonal(asym)\x
@test Symmetric(Matrix(Diagonal(asym)))\b Diagonal(asym)\b
end
end
@testset "generalized dot product" begin
for uplo in (:U, :L)
@test dot(x, Hermitian(aherm, uplo), y) dot(x, Hermitian(aherm, uplo)*y) dot(x, Matrix(Hermitian(aherm, uplo)), y)
@test dot(x, Hermitian(aherm, uplo), x) dot(x, Hermitian(aherm, uplo)*x) dot(x, Matrix(Hermitian(aherm, uplo)), x)
end
@test dot(x, Hermitian(Diagonal(a)), y) dot(x, Hermitian(Diagonal(a))*y) dot(x, Matrix(Hermitian(Diagonal(a))), y)
@test dot(x, Hermitian(Diagonal(a)), x) dot(x, Hermitian(Diagonal(a))*x) dot(x, Matrix(Hermitian(Diagonal(a))), x)
if eltya <: Real
for uplo in (:U, :L)
@test dot(x, Symmetric(aherm, uplo), y) dot(x, Symmetric(aherm, uplo)*y) dot(x, Matrix(Symmetric(aherm, uplo)), y)
Expand Down
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFlo
@test !istriu(A1')
@test !istriu(transpose(A1))
end
M = copy(parent(A1))
for trans in (adjoint, transpose), k in -1:1
triu!(M, k)
@test istril(trans(M), -k) == istril(copy(trans(M)), -k) == true
end
M = copy(parent(A1))
for trans in (adjoint, transpose), k in 1:-1:-1
tril!(M, k)
@test istriu(trans(M), -k) == istriu(copy(trans(M)), -k) == true
end

#tril/triu
if uplo1 === :L
Expand Down

0 comments on commit 6ab660d

Please sign in to comment.