Skip to content

Commit

Permalink
some progress
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Dec 26, 2024
1 parent 6a9b74b commit a323813
Show file tree
Hide file tree
Showing 9 changed files with 343 additions and 355 deletions.
9 changes: 6 additions & 3 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
module MatrixAlgebraKit

using LinearAlgebra: LinearAlgebra
using LinearAlgebra: Diagonal
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, triu!

export qr_compact!, qr_full!
export eigh_full!, eigh_vals!, eigh_trunc!
export svd_compact!, svd_full!, svd_vals!, svd_trunc!
# export eigh_full!, eigh_vals!, eigh_trunc!
export truncrank, trunctol, TruncationKeepSorted, TruncationKeepFiltered

include("auxiliary.jl")
include("backend.jl")
include("algorithms.jl")
include("truncation.jl")
include("yalapack.jl")
include("qr.jl")
include("svd.jl")
include("eigh.jl")
# include("eigh.jl")

end
32 changes: 32 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
struct Algorithm{name,K}
kwargs::K
end

macro algdef(name)
esc(quote
const $name{K} = Algorithm{$(QuoteNode(name)),K}
export $name
function $name(; kwargs...)
return $name{typeof(kwargs)}(kwargs)
end
function Base.print(io::IO, alg::$name)
print(io, $name, "(")
next = iterate(alg.kwargs)
isnothing(next) && return print(io, ")")
(k, v), state = next
print(io, "; ", string(k), "=", string(v))
next = iterate(alg.kwargs, state)
while !isnothing(next)
(k, v), state = next
print(io, ", ", string(k), "=", string(v))
next = iterate(alg.kwargs, state)
end
return print(io, ")")
end
end)
end

@algdef LAPACK_QRIteration
@algdef LAPACK_DivideAndConquer
@algdef LAPACK_RobustRepresentations
@algdef LAPACK_HouseholderQR
154 changes: 72 additions & 82 deletions src/eigh.jl
Original file line number Diff line number Diff line change
@@ -1,39 +1,45 @@
# TODO: do not export but mark as public ?
# TODO: export? or not export but mark as public ?
function eigh!(A::AbstractMatrix, args...; kwargs...)
return eigh_full!(A, args...; kwargs...)
end

function eigh_full!(A::AbstractMatrix,
D::AbstractVector=similar(A, real(eltype(A)), size(A, 1)),
V::AbstractMatrix=similar(A, size(A));
kwargs...)
return eigh_full!(A, D, V, default_backend(eigh_full!, A; kwargs...); kwargs...)
function eigh_full!(A::AbstractMatrix, DV=eigh_full_init(A); kwargs...)
return eigh_full!(A, DV, default_algorithm(eigh_full!, A; kwargs...))
end
function eigh_vals!(A::AbstractMatrix,
D::AbstractVector=similar(A, real(eltype(A)), size(A, 1));
kwargs...)
return eigh_vals!(A, D, default_backend(eigh_vals!, A; kwargs...); kwargs...)
function eigh_vals!(A::AbstractMatrix, D=eigh_vals_init(A); kwargs...)
return eigh_vals!(A, D, default_algorithm(eigh_vals!, A; kwargs...))
end
function eigh_trunc!(A::AbstractMatrix;
kwargs...)
return eigh_trunc!(A, default_backend(eigh_trunc!, A; kwargs...); kwargs...)
function eigh_trunc!(A::AbstractMatrix; kwargs...)
return eigh_trunc!(A, default_algorithm(eigh_trunc!, A; kwargs...))
end

function default_backend(::typeof(eigh_full!), A::AbstractMatrix; kwargs...)
return default_eigh_backend(A; kwargs...)
function eigh_full_init(A::AbstractMatrix)
n = size(A, 1) # square check will happen later
D = similar(A, real(eltype(A)), n)
V = similar(A, (n, n))
return (D, V)
end
function default_backend(::typeof(eigh_vals!), A::AbstractMatrix; kwargs...)
return default_eigh_backend(A; kwargs...)
function eigh_vals_init(A::AbstractMatrix)
n = size(A, 1) # square check will happen later
D = similar(A, real(eltype(A)), n)
return D
end

function default_algorithm(::typeof(eigh_full!), A::AbstractMatrix; kwargs...)
return default_eigh_algorithm(A; kwargs...)
end
function default_backend(::typeof(eigh_trunc!), A::AbstractMatrix; kwargs...)
return default_eigh_backend(A; kwargs...)
function default_algorithm(::typeof(eigh_vals!), A::AbstractMatrix; kwargs...)
return default_eigh_algorithm(A; kwargs...)
end
function default_algorithm(::typeof(eigh_trunc!), A::AbstractMatrix; kwargs...)
return default_eigh_algorithm(A; kwargs...)
end

function default_eigh_backend(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
return LAPACKBackend()
function default_eigh_algorithm(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
return LAPACK_RobustRepresentations(; kwargs...)
end

function check_eigh_full_input(A, D, V)
function check_eigh_full_input(A::AbstractMatrix, (D, V))
m, n = size(A)
m == n || throw(ArgumentError("Eigenvalue decompsition requires square matrix"))
size(D) == (n,) ||
Expand All @@ -42,82 +48,66 @@ function check_eigh_full_input(A, D, V)
throw(DimensionMismatch("Eigenvector matrix `V` must have size equal to A"))
return nothing
end
function check_eigh_vals_input(A, D)
function check_eigh_vals_input(A::AbstractMatrix, (D, V))
m, n = size(A)
m == n || throw(ArgumentError("Eigenvalue decompsition requires square matrix"))
size(D) == (n,) ||
throw(DimensionMismatch("Eigenvalue vector `D` must have length equal to size(A, 1)"))
return nothing
end

@static if VERSION >= v"1.12-DEV.0"
const RobustRepresentations = LinearAlgebra.RobustRepresentations
else
struct RobustRepresentations end
end

function eigh_full!(A::AbstractMatrix,
D::AbstractVector,
V::AbstractMatrix,
backend::LAPACKBackend;
alg=RobustRepresentations(),
kwargs...)
check_eigh_full_input(A, D, V)
if alg == RobustRepresentations()
YALAPACK.heevr!(A, D, V; kwargs...)
elseif alg == LinearAlgebra.DivideAndConquer()
YALAPACK.heevd!(A, D, V; kwargs...)
elseif alg == LinearAlgebra.QRIteration()
YALAPACK.heev!(A, D, V; kwargs...)
const LAPACK_EighAlgorithm = Union{LAPACK_RobustRepresentations,LAPACK_QRIteration,
LAPACK_DivideAndConquer}
function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm)
check_eigh_full_input(A, DV)
D, V = DV
if alg isa LAPACK_RobustRepresentations
YALAPACK.heevr!(A, D, V; alg.kwargs...)
elseif alg isa LAPACK_DivideAndConquer
YALAPACK.heevd!(A, D, V; alg.kwargs...)
else
throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
YALAPACK.heev!(A, D, V; alg.kwargs...)
end
return D, V
end

function eigh_vals!(A::AbstractMatrix,
D::AbstractVector,
backend::LAPACKBackend;
alg=RobustRepresentations(),
kwargs...)
function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm)
check_eigh_vals_input(A, D)
V = similar(A, (size(A, 1), 0))
if alg == RobustRepresentations()
YALAPACK.heevr!(A, D, V; kwargs...)
elseif alg == LinearAlgebra.DivideAndConquer()
YALAPACK.heevd!(A, D, V; kwargs...)
elseif alg == LinearAlgebra.QRIteration()
YALAPACK.heev!(A, D, V; kwargs...)
if alg isa LAPACK_RobustRepresentations
YALAPACK.heevr!(A, D, V; alg.kwargs...)
elseif alg isa LAPACK_DivideAndConquer
YALAPACK.heevd!(A, D, V; alg.kwargs...)
else
throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
YALAPACK.heev!(A, D, V; alg.kwargs...)
end
return D
return D, V
end

# for eigh_trunc!, it doesn't make sense to preallocate D and V as we don't know their sizes
function eigh_trunc!(A::AbstractMatrix,
backend::LAPACKBackend;
alg=RobustRepresentations(),
atol=zero(real(eltype(A))),
rtol=zero(real(eltype(A))),
rank=size(A, 1),
kwargs...)
if alg == RobustRepresentations()
D, V = YALAPACK.heevr!(A; kwargs...)
elseif alg == LinearAlgebra.DivideAndConquer()
D, V = YALAPACK.heevd!(A; kwargs...)
elseif alg == LinearAlgebra.QRIteration()
D, V = YALAPACK.heev!(A; kwargs...)
else
throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
end
# eigenvalues are sorted in ascending order
# TODO: do we assume that they are positive, or should we check for this?
# or do we want to truncate based on absolute value and thus sort differently?
n = length(D)
tol = convert(eltype(D), max(atol, rtol * D[n]))
s = max(n - rank + 1, findfirst(>=(tol), D))
# TODO: do we want views here, such that we do not need extra allocations if we later
# copy them into other storage
return D[n:-1:s], V[:, n:-1:s]
end
# function eigh_trunc!(A::AbstractMatrix,
# backend::LAPACKBackend;
# alg=RobustRepresentations(),
# atol=zero(real(eltype(A))),
# rtol=zero(real(eltype(A))),
# rank=size(A, 1),
# kwargs...)
# if alg == RobustRepresentations()
# D, V = YALAPACK.heevr!(A; kwargs...)
# elseif alg == LinearAlgebra.DivideAndConquer()
# D, V = YALAPACK.heevd!(A; kwargs...)
# elseif alg == LinearAlgebra.QRIteration()
# D, V = YALAPACK.heev!(A; kwargs...)
# else
# throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg"))
# end
# # eigenvalues are sorted in ascending order
# # TODO: do we assume that they are positive, or should we check for this?
# # or do we want to truncate based on absolute value and thus sort differently?
# n = length(D)
# tol = convert(eltype(D), max(atol, rtol * D[n]))
# s = max(n - rank + 1, findfirst(>=(tol), D))
# # TODO: do we want views here, such that we do not need extra allocations if we later
# # copy them into other storage
# return D[n:-1:s], V[:, n:-1:s]
# end
86 changes: 43 additions & 43 deletions src/qr.jl
Original file line number Diff line number Diff line change
@@ -1,72 +1,72 @@
function qr_full!(A::AbstractMatrix,
Q::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))),
R::AbstractMatrix=similar(A, (size(A, 1), size(A, 2)));
kwargs...)
return qr_full!(A, Q, R, default_backend(qr_full!, A; kwargs...); kwargs...)
function qr_full!(A::AbstractMatrix, QR=qr_full_init(A); kwargs...)
return qr_full!(A, QR, default_algorithm(qr_full!, A; kwargs...))
end
function qr_compact!(A::AbstractMatrix,
Q::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))),
R::AbstractMatrix=similar(A, (size(A, 1), size(A, 2)));
kwargs...)
return qr_compact!(A, Q, R, default_backend(qr_compact!, A; kwargs...); kwargs...)
function qr_compact!(A::AbstractMatrix, QR=qr_compact_init(A); kwargs...)
return qr_compact!(A, QR, default_algorithm(qr_compact!, A; kwargs...))
end

function default_backend(::typeof(qr_full!), A::AbstractMatrix; kwargs...)
return default_qr_backend(A; kwargs...)
function qr_full_init(A::AbstractMatrix)
m, n = size(A)
Q = similar(A, (m, m))
R = similar(A, (m, n))
return (Q, R)
end
function qr_compact_init(A::AbstractMatrix)
m, n = size(A)
minmn = min(m, n)
Q = similar(A, (m, minmn))
R = similar(A, (minmn, n))
return (Q, R)
end

function default_algorithm(::typeof(qr_full!), A::AbstractMatrix; kwargs...)
return default_qr_algorithm(A; kwargs...)
end
function default_backend(::typeof(qr_compact!), A::AbstractMatrix; kwargs...)
return default_qr_backend(A; kwargs...)
function default_algorithm(::typeof(qr_compact!), A::AbstractMatrix; kwargs...)
return default_qr_algorithm(A; kwargs...)
end

function default_qr_backend(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
return LAPACKBackend()
function default_qr_algorithm(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
return LAPACK_HouseholderQR(; kwargs...)
end

function check_qr_full_input(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix)
function check_qr_full_input(A::AbstractMatrix, QR)
m, n = size(A)
size(Q) == (m, m) ||
Q, R = QR
(Q isa AbstractMatrix && eltype(Q) == eltype(A) && size(Q) == (m, m)) ||
throw(DimensionMismatch("Full unitary matrix `Q` must be square with equal number of rows as A"))
isempty(R) || size(R) == (m, n) ||
(R isa AbstractMatrix && eltype(R) == eltype(A) && (isempty(R) || size(R) == (m, n))) ||
throw(DimensionMismatch("Upper triangular matrix `R` must have size equal to A"))
return nothing
end
function check_qr_compact_input(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix)
function check_qr_compact_input(A::AbstractMatrix, QR)
m, n = size(A)
if n <= m
size(Q) == (m, n) ||
Q, R = QR
(Q isa AbstractMatrix && eltype(Q) == eltype(A) && size(Q) == (m, n)) ||
throw(DimensionMismatch("Isometric `Q` must have size equal to A"))
isempty(R) || size(R) == (n, n) ||
(R isa AbstractMatrix && eltype(R) == eltype(A) &&
(isempty(R) || size(R) == (n, n))) ||
throw(DimensionMismatch("Upper triangular matrix `R` must be square with equal number of columns as A"))
else
check_qr_full_input(A, Q, R)
check_qr_full_input(A, QR)
end
end

function qr_full!(A::AbstractMatrix,
Q::AbstractMatrix,
R::AbstractMatrix,
backend::LAPACKBackend;
positive=false,
pivoted=false,
blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)))
check_qr_full_input(A, Q, R)
_unsafe_qr!(A, Q, R; positive=positive, pivoted=pivoted, blocksize=blocksize)
function qr_full!(A::AbstractMatrix, QR, alg::LAPACK_HouseholderQR)
check_qr_full_input(A, QR)
Q, R = QR
_lapack_qr!(A, Q, R; alg.kwargs...)
return Q, R
end

function qr_compact!(A::AbstractMatrix,
Q::AbstractMatrix,
R::AbstractMatrix,
backend::LAPACKBackend;
positive=false,
pivoted=false,
blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)))
check_qr_compact_input(A, Q, R)
_unsafe_qr!(A, Q, R; positive=positive, pivoted=pivoted, blocksize=blocksize)
function qr_compact!(A::AbstractMatrix, QR, alg::LAPACK_HouseholderQR)
check_qr_compact_input(A, QR)
Q, R = QR
_lapack_qr!(A, Q, R; alg.kwargs...)
return Q, R
end

function _unsafe_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
function _lapack_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive=false,
pivoted=false,
blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)))
Expand Down
Loading

0 comments on commit a323813

Please sign in to comment.