diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 7a92e19..d4d9c58 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -1,17 +1,20 @@ module MatrixAlgebraKit using LinearAlgebra: LinearAlgebra +using LinearAlgebra: Diagonal using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, triu! export qr_compact!, qr_full! -export eigh_full!, eigh_vals!, eigh_trunc! export svd_compact!, svd_full!, svd_vals!, svd_trunc! +# export eigh_full!, eigh_vals!, eigh_trunc! +export truncrank, trunctol, TruncationKeepSorted, TruncationKeepFiltered include("auxiliary.jl") -include("backend.jl") +include("algorithms.jl") +include("truncation.jl") include("yalapack.jl") include("qr.jl") include("svd.jl") -include("eigh.jl") +# include("eigh.jl") end diff --git a/src/algorithms.jl b/src/algorithms.jl new file mode 100644 index 0000000..8214c5a --- /dev/null +++ b/src/algorithms.jl @@ -0,0 +1,32 @@ +struct Algorithm{name,K} + kwargs::K +end + +macro algdef(name) + esc(quote + const $name{K} = Algorithm{$(QuoteNode(name)),K} + export $name + function $name(; kwargs...) + return $name{typeof(kwargs)}(kwargs) + end + function Base.print(io::IO, alg::$name) + print(io, $name, "(") + next = iterate(alg.kwargs) + isnothing(next) && return print(io, ")") + (k, v), state = next + print(io, "; ", string(k), "=", string(v)) + next = iterate(alg.kwargs, state) + while !isnothing(next) + (k, v), state = next + print(io, ", ", string(k), "=", string(v)) + next = iterate(alg.kwargs, state) + end + return print(io, ")") + end + end) +end + +@algdef LAPACK_QRIteration +@algdef LAPACK_DivideAndConquer +@algdef LAPACK_RobustRepresentations +@algdef LAPACK_HouseholderQR diff --git a/src/eigh.jl b/src/eigh.jl index 7ff1915..3327151 100644 --- a/src/eigh.jl +++ b/src/eigh.jl @@ -1,39 +1,45 @@ -# TODO: do not export but mark as public ? +# TODO: export? or not export but mark as public ? function eigh!(A::AbstractMatrix, args...; kwargs...) return eigh_full!(A, args...; kwargs...) end -function eigh_full!(A::AbstractMatrix, - D::AbstractVector=similar(A, real(eltype(A)), size(A, 1)), - V::AbstractMatrix=similar(A, size(A)); - kwargs...) - return eigh_full!(A, D, V, default_backend(eigh_full!, A; kwargs...); kwargs...) +function eigh_full!(A::AbstractMatrix, DV=eigh_full_init(A); kwargs...) + return eigh_full!(A, DV, default_algorithm(eigh_full!, A; kwargs...)) end -function eigh_vals!(A::AbstractMatrix, - D::AbstractVector=similar(A, real(eltype(A)), size(A, 1)); - kwargs...) - return eigh_vals!(A, D, default_backend(eigh_vals!, A; kwargs...); kwargs...) +function eigh_vals!(A::AbstractMatrix, D=eigh_vals_init(A); kwargs...) + return eigh_vals!(A, D, default_algorithm(eigh_vals!, A; kwargs...)) end -function eigh_trunc!(A::AbstractMatrix; - kwargs...) - return eigh_trunc!(A, default_backend(eigh_trunc!, A; kwargs...); kwargs...) +function eigh_trunc!(A::AbstractMatrix; kwargs...) + return eigh_trunc!(A, default_algorithm(eigh_trunc!, A; kwargs...)) end -function default_backend(::typeof(eigh_full!), A::AbstractMatrix; kwargs...) - return default_eigh_backend(A; kwargs...) +function eigh_full_init(A::AbstractMatrix) + n = size(A, 1) # square check will happen later + D = similar(A, real(eltype(A)), n) + V = similar(A, (n, n)) + return (D, V) end -function default_backend(::typeof(eigh_vals!), A::AbstractMatrix; kwargs...) - return default_eigh_backend(A; kwargs...) +function eigh_vals_init(A::AbstractMatrix) + n = size(A, 1) # square check will happen later + D = similar(A, real(eltype(A)), n) + return D +end + +function default_algorithm(::typeof(eigh_full!), A::AbstractMatrix; kwargs...) + return default_eigh_algorithm(A; kwargs...) end -function default_backend(::typeof(eigh_trunc!), A::AbstractMatrix; kwargs...) - return default_eigh_backend(A; kwargs...) +function default_algorithm(::typeof(eigh_vals!), A::AbstractMatrix; kwargs...) + return default_eigh_algorithm(A; kwargs...) +end +function default_algorithm(::typeof(eigh_trunc!), A::AbstractMatrix; kwargs...) + return default_eigh_algorithm(A; kwargs...) end -function default_eigh_backend(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} - return LAPACKBackend() +function default_eigh_algorithm(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} + return LAPACK_RobustRepresentations(; kwargs...) end -function check_eigh_full_input(A, D, V) +function check_eigh_full_input(A::AbstractMatrix, (D, V)) m, n = size(A) m == n || throw(ArgumentError("Eigenvalue decompsition requires square matrix")) size(D) == (n,) || @@ -42,7 +48,7 @@ function check_eigh_full_input(A, D, V) throw(DimensionMismatch("Eigenvector matrix `V` must have size equal to A")) return nothing end -function check_eigh_vals_input(A, D) +function check_eigh_vals_input(A::AbstractMatrix, (D, V)) m, n = size(A) m == n || throw(ArgumentError("Eigenvalue decompsition requires square matrix")) size(D) == (n,) || @@ -50,74 +56,58 @@ function check_eigh_vals_input(A, D) return nothing end -@static if VERSION >= v"1.12-DEV.0" - const RobustRepresentations = LinearAlgebra.RobustRepresentations -else - struct RobustRepresentations end -end - -function eigh_full!(A::AbstractMatrix, - D::AbstractVector, - V::AbstractMatrix, - backend::LAPACKBackend; - alg=RobustRepresentations(), - kwargs...) - check_eigh_full_input(A, D, V) - if alg == RobustRepresentations() - YALAPACK.heevr!(A, D, V; kwargs...) - elseif alg == LinearAlgebra.DivideAndConquer() - YALAPACK.heevd!(A, D, V; kwargs...) - elseif alg == LinearAlgebra.QRIteration() - YALAPACK.heev!(A, D, V; kwargs...) +const LAPACK_EighAlgorithm = Union{LAPACK_RobustRepresentations,LAPACK_QRIteration, + LAPACK_DivideAndConquer} +function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm) + check_eigh_full_input(A, DV) + D, V = DV + if alg isa LAPACK_RobustRepresentations + YALAPACK.heevr!(A, D, V; alg.kwargs...) + elseif alg isa LAPACK_DivideAndConquer + YALAPACK.heevd!(A, D, V; alg.kwargs...) else - throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg")) + YALAPACK.heev!(A, D, V; alg.kwargs...) end return D, V end -function eigh_vals!(A::AbstractMatrix, - D::AbstractVector, - backend::LAPACKBackend; - alg=RobustRepresentations(), - kwargs...) +function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm) check_eigh_vals_input(A, D) V = similar(A, (size(A, 1), 0)) - if alg == RobustRepresentations() - YALAPACK.heevr!(A, D, V; kwargs...) - elseif alg == LinearAlgebra.DivideAndConquer() - YALAPACK.heevd!(A, D, V; kwargs...) - elseif alg == LinearAlgebra.QRIteration() - YALAPACK.heev!(A, D, V; kwargs...) + if alg isa LAPACK_RobustRepresentations + YALAPACK.heevr!(A, D, V; alg.kwargs...) + elseif alg isa LAPACK_DivideAndConquer + YALAPACK.heevd!(A, D, V; alg.kwargs...) else - throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg")) + YALAPACK.heev!(A, D, V; alg.kwargs...) end - return D + return D, V end # for eigh_trunc!, it doesn't make sense to preallocate D and V as we don't know their sizes -function eigh_trunc!(A::AbstractMatrix, - backend::LAPACKBackend; - alg=RobustRepresentations(), - atol=zero(real(eltype(A))), - rtol=zero(real(eltype(A))), - rank=size(A, 1), - kwargs...) - if alg == RobustRepresentations() - D, V = YALAPACK.heevr!(A; kwargs...) - elseif alg == LinearAlgebra.DivideAndConquer() - D, V = YALAPACK.heevd!(A; kwargs...) - elseif alg == LinearAlgebra.QRIteration() - D, V = YALAPACK.heev!(A; kwargs...) - else - throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg")) - end - # eigenvalues are sorted in ascending order - # TODO: do we assume that they are positive, or should we check for this? - # or do we want to truncate based on absolute value and thus sort differently? - n = length(D) - tol = convert(eltype(D), max(atol, rtol * D[n])) - s = max(n - rank + 1, findfirst(>=(tol), D)) - # TODO: do we want views here, such that we do not need extra allocations if we later - # copy them into other storage - return D[n:-1:s], V[:, n:-1:s] -end +# function eigh_trunc!(A::AbstractMatrix, +# backend::LAPACKBackend; +# alg=RobustRepresentations(), +# atol=zero(real(eltype(A))), +# rtol=zero(real(eltype(A))), +# rank=size(A, 1), +# kwargs...) +# if alg == RobustRepresentations() +# D, V = YALAPACK.heevr!(A; kwargs...) +# elseif alg == LinearAlgebra.DivideAndConquer() +# D, V = YALAPACK.heevd!(A; kwargs...) +# elseif alg == LinearAlgebra.QRIteration() +# D, V = YALAPACK.heev!(A; kwargs...) +# else +# throw(ArgumentError("Unknown LAPACK eigenvalue algorithm $alg")) +# end +# # eigenvalues are sorted in ascending order +# # TODO: do we assume that they are positive, or should we check for this? +# # or do we want to truncate based on absolute value and thus sort differently? +# n = length(D) +# tol = convert(eltype(D), max(atol, rtol * D[n])) +# s = max(n - rank + 1, findfirst(>=(tol), D)) +# # TODO: do we want views here, such that we do not need extra allocations if we later +# # copy them into other storage +# return D[n:-1:s], V[:, n:-1:s] +# end diff --git a/src/qr.jl b/src/qr.jl index d01e508..7197222 100644 --- a/src/qr.jl +++ b/src/qr.jl @@ -1,72 +1,72 @@ -function qr_full!(A::AbstractMatrix, - Q::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))), - R::AbstractMatrix=similar(A, (size(A, 1), size(A, 2))); - kwargs...) - return qr_full!(A, Q, R, default_backend(qr_full!, A; kwargs...); kwargs...) +function qr_full!(A::AbstractMatrix, QR=qr_full_init(A); kwargs...) + return qr_full!(A, QR, default_algorithm(qr_full!, A; kwargs...)) end -function qr_compact!(A::AbstractMatrix, - Q::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))), - R::AbstractMatrix=similar(A, (size(A, 1), size(A, 2))); - kwargs...) - return qr_compact!(A, Q, R, default_backend(qr_compact!, A; kwargs...); kwargs...) +function qr_compact!(A::AbstractMatrix, QR=qr_compact_init(A); kwargs...) + return qr_compact!(A, QR, default_algorithm(qr_compact!, A; kwargs...)) end -function default_backend(::typeof(qr_full!), A::AbstractMatrix; kwargs...) - return default_qr_backend(A; kwargs...) +function qr_full_init(A::AbstractMatrix) + m, n = size(A) + Q = similar(A, (m, m)) + R = similar(A, (m, n)) + return (Q, R) +end +function qr_compact_init(A::AbstractMatrix) + m, n = size(A) + minmn = min(m, n) + Q = similar(A, (m, minmn)) + R = similar(A, (minmn, n)) + return (Q, R) +end + +function default_algorithm(::typeof(qr_full!), A::AbstractMatrix; kwargs...) + return default_qr_algorithm(A; kwargs...) end -function default_backend(::typeof(qr_compact!), A::AbstractMatrix; kwargs...) - return default_qr_backend(A; kwargs...) +function default_algorithm(::typeof(qr_compact!), A::AbstractMatrix; kwargs...) + return default_qr_algorithm(A; kwargs...) end -function default_qr_backend(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} - return LAPACKBackend() +function default_qr_algorithm(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} + return LAPACK_HouseholderQR(; kwargs...) end -function check_qr_full_input(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix) +function check_qr_full_input(A::AbstractMatrix, QR) m, n = size(A) - size(Q) == (m, m) || + Q, R = QR + (Q isa AbstractMatrix && eltype(Q) == eltype(A) && size(Q) == (m, m)) || throw(DimensionMismatch("Full unitary matrix `Q` must be square with equal number of rows as A")) - isempty(R) || size(R) == (m, n) || + (R isa AbstractMatrix && eltype(R) == eltype(A) && (isempty(R) || size(R) == (m, n))) || throw(DimensionMismatch("Upper triangular matrix `R` must have size equal to A")) return nothing end -function check_qr_compact_input(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix) +function check_qr_compact_input(A::AbstractMatrix, QR) m, n = size(A) if n <= m - size(Q) == (m, n) || + Q, R = QR + (Q isa AbstractMatrix && eltype(Q) == eltype(A) && size(Q) == (m, n)) || throw(DimensionMismatch("Isometric `Q` must have size equal to A")) - isempty(R) || size(R) == (n, n) || + (R isa AbstractMatrix && eltype(R) == eltype(A) && + (isempty(R) || size(R) == (n, n))) || throw(DimensionMismatch("Upper triangular matrix `R` must be square with equal number of columns as A")) else - check_qr_full_input(A, Q, R) + check_qr_full_input(A, QR) end end -function qr_full!(A::AbstractMatrix, - Q::AbstractMatrix, - R::AbstractMatrix, - backend::LAPACKBackend; - positive=false, - pivoted=false, - blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))) - check_qr_full_input(A, Q, R) - _unsafe_qr!(A, Q, R; positive=positive, pivoted=pivoted, blocksize=blocksize) +function qr_full!(A::AbstractMatrix, QR, alg::LAPACK_HouseholderQR) + check_qr_full_input(A, QR) + Q, R = QR + _lapack_qr!(A, Q, R; alg.kwargs...) return Q, R end - -function qr_compact!(A::AbstractMatrix, - Q::AbstractMatrix, - R::AbstractMatrix, - backend::LAPACKBackend; - positive=false, - pivoted=false, - blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))) - check_qr_compact_input(A, Q, R) - _unsafe_qr!(A, Q, R; positive=positive, pivoted=pivoted, blocksize=blocksize) +function qr_compact!(A::AbstractMatrix, QR, alg::LAPACK_HouseholderQR) + check_qr_compact_input(A, QR) + Q, R = QR + _lapack_qr!(A, Q, R; alg.kwargs...) return Q, R end -function _unsafe_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; +function _lapack_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; positive=false, pivoted=false, blocksize=((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))) diff --git a/src/svd.jl b/src/svd.jl index e97fe56..1faf3f7 100644 --- a/src/svd.jl +++ b/src/svd.jl @@ -3,144 +3,135 @@ function svd!(A::AbstractMatrix, args...; kwargs...) return svd_compact!(A, args...; kwargs...) end -function svd_full!(A::AbstractMatrix, - U::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))), - S::AbstractVector=similar(A, real(eltype(A)), (min(size(A)...),)), - Vᴴ::AbstractMatrix=similar(A, (size(A, 2), size(A, 2))); - kwargs...) - return svd_full!(A, U, S, Vᴴ, default_backend(svd_full!, A; kwargs...); kwargs...) -end -function svd_compact!(A::AbstractMatrix, - U::AbstractMatrix=similar(A, (size(A, 1), size(A, 1))), - S::AbstractVector=similar(A, real(eltype(A)), (min(size(A)...),)), - Vᴴ::AbstractMatrix=similar(A, (size(A, 2), size(A, 2))); - kwargs...) - return svd_compact!(A, U, S, Vᴴ, default_backend(svd_compact!, A; kwargs...); kwargs...) -end -function svd_vals!(A::AbstractMatrix, - S::AbstractVector=similar(A, real(eltype(A)), (min(size(A)...),)); - kwargs...) - return svd_vals!(A, S, default_backend(svd_vals!, A; kwargs...); kwargs...) +function svd_full!(A::AbstractMatrix, USVᴴ=svd_full_init(A); kwargs...) + return svd_full!(A, USVᴴ, default_algorithm(svd_full!, A; kwargs...)) +end +function svd_compact!(A::AbstractMatrix, USVᴴ=svd_compact_init(A); kwargs...) + return svd_compact!(A, USVᴴ, default_algorithm(svd_compact!, A; kwargs...)) +end +function svd_vals!(A::AbstractMatrix, S=svd_vals_init(A); kwargs...) + return svd_vals!(A, S, default_algorithm(svd_vals!, A; kwargs...)) +end +function svd_trunc!(A::AbstractMatrix, trunc::TruncationStrategy; kwargs...) + return svd_trunc!(A, default_algorithm(svd_trunc!, A; kwargs...), trunc) end -function svd_trunc!(A::AbstractMatrix; - kwargs...) - return svd_trunc!(A, default_backend(svd_trunc!, A; kwargs...); kwargs...) +function svd_full_init(A::AbstractMatrix) + m, n = size(A) + minmn = min(m, n) + U = similar(A, (m, m)) + S = similar(A, real(eltype(A)), (m, n)) + Vᴴ = similar(A, (n, n)) + return (U, S, Vᴴ) +end +function svd_compact_init(A::AbstractMatrix) + m, n = size(A) + minmn = min(m, n) + U = similar(A, (m, minmn)) + S = Diagonal(similar(A, real(eltype(A)), (minmn,))) + Vᴴ = similar(A, (minmn, n)) + return (U, S, Vᴴ) +end +function svd_vals_init(A::AbstractMatrix) + return similar(A, real(eltype(A)), (min(size(A)...),)) end -function default_backend(::typeof(svd_full!), A::AbstractMatrix; kwargs...) - return default_svd_backend(A; kwargs...) +function default_algorithm(::typeof(svd_full!), A::AbstractMatrix; kwargs...) + return default_svd_algorithm(A; kwargs...) end -function default_backend(::typeof(svd_compact!), A::AbstractMatrix; kwargs...) - return default_svd_backend(A; kwargs...) +function default_algorithm(::typeof(svd_compact!), A::AbstractMatrix; kwargs...) + return default_svd_algorithm(A; kwargs...) end -function default_backend(::typeof(svd_vals!), A::AbstractMatrix; kwargs...) - return default_svd_backend(A; kwargs...) +function default_algorithm(::typeof(svd_vals!), A::AbstractMatrix; kwargs...) + return default_svd_algorithm(A; kwargs...) end -function default_backend(::typeof(svd_trunc!), A::AbstractMatrix; kwargs...) - return default_svd_backend(A; kwargs...) +function default_algorithm(::typeof(svd_trunc!), A::AbstractMatrix; kwargs...) + return default_svd_algorithm(A; kwargs...) end -function default_svd_backend(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} - return LAPACKBackend() +function default_svd_algorithm(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} + return LAPACK_DivideAndConquer(; kwargs...) end -function check_svd_full_input(A, U, S, Vᴴ) +function check_svd_full_input(A::AbstractMatrix, USVᴴ) m, n = size(A) minmn = min(m, n) - size(U) == (m, m) || - throw(DimensionMismatch("`svd_full!` requires square U matrix with equal number of rows as A")) - size(Vᴴ) == (n, n) || - throw(DimensionMismatch("`svd_full!` requires square Vᴴ matrix with equal number of columns as A")) - size(S) == (minmn,) || - throw(DimensionMismatch("`svd_full!` requires vector S of length min(size(A)...")) + U, S, Vᴴ = USVᴴ + (U isa AbstractMatrix && eltype(U) == eltype(A) && size(U) == (m, m)) || + throw(DimensionMismatch("`svd_full!` requires square U matrix with equal number of rows and same `eltype` as A")) + (Vᴴ isa AbstractMatrix && eltype(Vᴴ) == eltype(A) && size(Vᴴ) == (n, n)) || + throw(DimensionMismatch("`svd_full!` requires square Vᴴ matrix with equal number of columns and same `eltype` as A")) + (S isa AbstractMatrix && eltype(S) == real(eltype(A)) && size(S) == (m, n)) || + throw(DimensionMismatch("`svd_full!` requires a matrix S of the same size as A with a real `eltype`")) return nothing end -function check_svd_compact_input(A, U, S, Vᴴ) +function check_svd_compact_input(A::AbstractMatrix, USVᴴ) m, n = size(A) minmn = min(m, n) - size(U) == (m, minmn) || - throw(DimensionMismatch("`svd_compact!` requires square U matrix with equal number of rows as A")) - size(Vᴴ) == (minmn, n) || - throw(DimensionMismatch("`svd_compact!` requires square Vᴴ matrix with equal number of columns as A")) - size(S) == (minmn,) || - throw(DimensionMismatch("`svd_compact!` requires vector S of length min(size(A)...")) + U, S, Vᴴ = USVᴴ + (U isa AbstractMatrix && eltype(U) == eltype(A) && size(U) == (m, minmn)) || + throw(DimensionMismatch("`svd_full!` requires square U matrix with equal number of rows and same `eltype` as A")) + (Vᴴ isa AbstractMatrix && eltype(Vᴴ) == eltype(A) && size(Vᴴ) == (minmn, n)) || + throw(DimensionMismatch("`svd_full!` requires square Vᴴ matrix with equal number of columns and same `eltype` as A")) + (S isa Diagonal && eltype(S) == real(eltype(A)) && size(S) == (minmn, minmn)) || + throw(DimensionMismatch("`svd_compact!` requires Diagonal matrix S with number of rows equal to min(size(A)...) with a real `eltype`")) return nothing end -function check_svd_vals_input(A, S) +function check_svd_vals_input(A::AbstractMatrix, S) m, n = size(A) minmn = min(m, n) - size(S) == (minmn,) || - throw(DimensionMismatch("`svd_vals!` requires vector S of length min(size(A)...")) + (S isa AbstractVector && eltype(S) == real(eltype(A)) && size(S) == (minmn,)) || + throw(DimensionMismatch("`svd_vals!` requires vector S of length min(size(A)...) with a real `eltype`")) return nothing end -function svd_full!(A::AbstractMatrix, - U::AbstractMatrix, - S::AbstractVector, - Vᴴ::AbstractMatrix, - backend::LAPACKBackend; - alg=LinearAlgebra.DivideAndConquer()) - check_svd_full_input(A, U, S, Vᴴ) - if alg == LinearAlgebra.DivideAndConquer() - YALAPACK.gesdd!(A, S, U, Vᴴ) - elseif alg == LinearAlgebra.QRIteration() - YALAPACK.gesvd!(A, S, U, Vᴴ) +const LAPACK_SVDAlgorithm = Union{LAPACK_QRIteration,LAPACK_DivideAndConquer} + +function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) + check_svd_full_input(A, USVᴴ) + U, S, Vᴴ = USVᴴ + fill!(S, zero(eltype(S))) + minmn = min(size(A)...) + if alg isa LAPACK_QRIteration + YALAPACK.gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) else - throw(ArgumentError("Unknown LAPACK singular value algorithm $alg")) + YALAPACK.gesdd!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) + end + for i in 2:minmn + S[i, i] = S[i, 1] + S[i, 1] = zero(eltype(S)) end - return U, S, Vᴴ -end -function svd_compact!(A::AbstractMatrix, - U::AbstractMatrix, - S::AbstractVector, - Vᴴ::AbstractMatrix, - backend::LAPACKBackend; - alg=LinearAlgebra.DivideAndConquer()) - check_svd_compact_input(A, U, S, Vᴴ) - if alg == LinearAlgebra.DivideAndConquer() - YALAPACK.gesdd!(A, S, U, Vᴴ) - elseif alg == LinearAlgebra.QRIteration() - YALAPACK.gesvd!(A, S, U, Vᴴ) + return USVᴴ +end +function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) + check_svd_compact_input(A, USVᴴ) + U, S, Vᴴ = USVᴴ + if alg isa LAPACK_QRIteration + YALAPACK.gesvd!(A, S.diag, U, Vᴴ; alg.kwargs...) else - throw(ArgumentError("Unknown LAPACK singular value algorithm $alg")) + YALAPACK.gesdd!(A, S.diag, U, Vᴴ; alg.kwargs...) end - return U, S, Vᴴ + return USVᴴ end - -function svd_vals!(A::AbstractMatrix, - S::AbstractVector, - backend::LAPACKBackend; - alg=LinearAlgebra.DivideAndConquer()) +function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) check_svd_vals_input(A, S) - m, n = size(A) - if alg == LinearAlgebra.DivideAndConquer() - YALAPACK.gesdd!(A, S, similar(A, m, 0), similar(A, n, 0)) - elseif alg == LinearAlgebra.QRIteration() - YALAPACK.gesvd!(A, S, similar(A, m, 0), similar(A, n, 0)) + if alg isa LAPACK_QRIteration + YALAPACK.gesvd!(A, S; alg.kwargs...) else - throw(ArgumentError("Unknown LAPACK singular value algorithm $alg")) + YALAPACK.gesdd!(A, S; alg.kwargs...) end return S end -# for svd_trunc!, it doesn't make sense to preallocate U, S, Vᴴ as we don't know their sizes -function svd_trunc!(A::AbstractMatrix, - backend::LAPACKBackend; - alg=LinearAlgebra.DivideAndConquer(), - atol=zero(real(eltype(A))), - rtol=zero(real(eltype(A))), - rank=min(size(A)...)) - if alg == LinearAlgebra.DivideAndConquer() - S, U, Vᴴ = YALAPACK.gesdd!(A) - elseif alg == LinearAlgebra.QRIteration() - S, U, Vᴴ = YALAPACK.gesvd!(A) - else - throw(ArgumentError("Unknown LAPACK singular value algorithm $alg")) - end - tol = convert(eltype(S), max(atol, rtol * S[1])) - r = min(rank, findlast(>=(tol), S)) - # TODO: do we want views here, such that we do not need extra allocations if we later - # copy them into other storage - return U[:, 1:r], S[1:r], Vᴴ[1:r, :] +# # for svd_trunc!, it doesn't make sense to preallocate U, S, Vᴴ as we don't know their sizes +function svd_trunc!(A::AbstractMatrix, alg::LAPACK_SVDAlgorithm, trunc::TruncationStrategy) + USVᴴ = svd_compact_init(A) + U, S, Vᴴ = svd_compact!(A, USVᴴ, alg) + + Sd = S.diag + ind = findtruncated(Sd, trunc) + U′ = U[:, ind] + S′ = Diagonal(Sd[ind]) + Vᴴ′ = Vᴴ[ind, :] + return (U′, S′, Vᴴ′) end diff --git a/src/truncation.jl b/src/truncation.jl new file mode 100644 index 0000000..adaa1e8 --- /dev/null +++ b/src/truncation.jl @@ -0,0 +1,37 @@ +abstract type TruncationStrategy end + +""" + TruncationKeepSorted(howmany::Int, sortby::Function, rev::Bool) + +Truncation strategy to keep the first `howmany` values when sorted according to `sortby` or the last `howmany` if `rev` is true. +""" +struct TruncationKeepSorted{F} <: TruncationStrategy + howmany::Int + sortby::F + rev::Bool +end + +""" + TruncationKeepFiltered(filter::Function) + +Truncation strategy to keep the values for which `filter` returns true. +""" +struct TruncationKeepFiltered{F} <: TruncationStrategy + filter::F +end + +# TODO: better names for these functions of the above types +truncrank(howmany::Int, by=abs, rev=true) = TruncationKeepSorted(howmany, by, rev) +trunctol(atol) = TruncationKeepFiltered(>=(atol)) + +function findtruncated(values::AbstractVector, strategy::TruncationKeepSorted) + sorted = sortperm(values; by=strategy.sortby, rev=strategy.rev) + howmany = min(strategy.howmany, length(sorted)) + ind = sorted[1:howmany] + return ind +end + +function findtruncated(values::AbstractVector, strategy::TruncationKeepFiltered) + ind = findall(strategy.filter, values) + return ind +end \ No newline at end of file diff --git a/test/qr.jl b/test/qr.jl index 56e3761..8747339 100644 --- a/test/qr.jl +++ b/test/qr.jl @@ -7,49 +7,49 @@ R = similar(A, min(m, n), n) Q2 = similar(Q) noR = similar(A, min(m, n), 0) - qr_compact!(copy!(Ac, A), Q, R) + qr_compact!(copy!(Ac, A), (Q, R)) @test Q * R ≈ A @test Q' * Q ≈ I - qr_compact!(copy!(Ac, A), Q2, noR) + qr_compact!(copy!(Ac, A), (Q2, noR)) @test Q == Q2 # unblocked algorithm - qr_compact!(copy!(Ac, A), Q, R; blocksize=1) + qr_compact!(copy!(Ac, A), (Q, R); blocksize=1) @test Q * R ≈ A @test Q' * Q ≈ I - qr_compact!(copy!(Ac, A), Q2, noR; blocksize=1) + qr_compact!(copy!(Ac, A), (Q2, noR); blocksize=1) @test Q == Q2 if n <= m - qr_compact!(copy!(Q2, A), Q2, noR; blocksize=1) # in-place Q + qr_compact!(copy!(Q2, A), (Q2, noR); blocksize=1) # in-place Q @test Q ≈ Q2 end # other blocking - qr_compact!(copy!(Ac, A), Q, R; blocksize=8) + qr_compact!(copy!(Ac, A), (Q, R); blocksize=8) @test Q * R ≈ A @test Q' * Q ≈ I - qr_compact!(copy!(Ac, A), Q2, noR; blocksize=8) + qr_compact!(copy!(Ac, A), (Q2, noR); blocksize=8) @test Q == Q2 # pivoted - qr_compact!(copy!(Ac, A), Q, R; pivoted=true) + qr_compact!(copy!(Ac, A), (Q, R); pivoted=true) @test Q * R ≈ A @test Q' * Q ≈ I - qr_compact!(copy!(Ac, A), Q2, noR; pivoted=true) + qr_compact!(copy!(Ac, A), (Q2, noR); pivoted=true) @test Q == Q2 # positive - qr_compact!(copy!(Ac, A), Q, R; positive=true) + qr_compact!(copy!(Ac, A), (Q, R); positive=true) @test Q * R ≈ A @test Q' * Q ≈ I @test all(>=(zero(real(T))), real(diag(R))) - qr_compact!(copy!(Ac, A), Q2, noR; positive=true) + qr_compact!(copy!(Ac, A), (Q2, noR); positive=true) @test Q == Q2 # positive and blocksize 1 - qr_compact!(copy!(Ac, A), Q, R; positive=true, blocksize=1) + qr_compact!(copy!(Ac, A), (Q, R); positive=true, blocksize=1) @test Q * R ≈ A @test Q' * Q ≈ I @test all(>=(zero(real(T))), real(diag(R))) - qr_compact!(copy!(Ac, A), Q2, noR; positive=true, blocksize=1) + qr_compact!(copy!(Ac, A), (Q2, noR); positive=true, blocksize=1) @test Q == Q2 # positive and pivoted - qr_compact!(copy!(Ac, A), Q, R; positive=true, pivoted=true) + qr_compact!(copy!(Ac, A), (Q, R); positive=true, pivoted=true) @test Q * R ≈ A @test Q' * Q ≈ I if n <= m @@ -61,7 +61,7 @@ @test real(R[i, j]) >= zero(real(T)) end end - qr_compact!(copy!(Ac, A), Q2, noR; positive=true, pivoted=true) + qr_compact!(copy!(Ac, A), (Q2, noR); positive=true, pivoted=true) @test Q == Q2 end end @@ -75,49 +75,49 @@ end R = similar(A) Q2 = similar(Q) noR = similar(A, min(m, n), 0) - qr_full!(copy!(Ac, A), Q, R) + qr_full!(copy!(Ac, A), (Q, R)) @test Q * R ≈ A @test Q' * Q ≈ I - qr_full!(copy!(Ac, A), Q2, noR) + qr_full!(copy!(Ac, A), (Q2, noR)) @test Q == Q2 # unblocked algorithm - qr_full!(copy!(Ac, A), Q, R; blocksize=1) + qr_full!(copy!(Ac, A), (Q, R); blocksize=1) @test Q * R ≈ A @test Q' * Q ≈ I - qr_full!(copy!(Ac, A), Q2, noR; blocksize=1) + qr_full!(copy!(Ac, A), (Q2, noR); blocksize=1) @test Q == Q2 if n == m - qr_full!(copy!(Q2, A), Q2, noR; blocksize=1) # in-place Q + qr_full!(copy!(Q2, A), (Q2, noR); blocksize=1) # in-place Q @test Q ≈ Q2 end # other blocking - qr_full!(copy!(Ac, A), Q, R; blocksize=8) + qr_full!(copy!(Ac, A), (Q, R); blocksize=8) @test Q * R ≈ A @test Q' * Q ≈ I - qr_full!(copy!(Ac, A), Q2, noR; blocksize=8) + qr_full!(copy!(Ac, A), (Q2, noR); blocksize=8) @test Q == Q2 # pivoted - qr_full!(copy!(Ac, A), Q, R; pivoted=true) + qr_full!(copy!(Ac, A), (Q, R); pivoted=true) @test Q * R ≈ A @test Q' * Q ≈ I - qr_full!(copy!(Ac, A), Q2, noR; pivoted=true) + qr_full!(copy!(Ac, A), (Q2, noR); pivoted=true) @test Q == Q2 # positive - qr_full!(copy!(Ac, A), Q, R; positive=true) + qr_full!(copy!(Ac, A), (Q, R); positive=true) @test Q * R ≈ A @test Q' * Q ≈ I @test all(>=(zero(real(T))), real(diag(R))) - qr_full!(copy!(Ac, A), Q2, noR; positive=true) + qr_full!(copy!(Ac, A), (Q2, noR); positive=true) @test Q == Q2 # positive and blocksize 1 - qr_full!(copy!(Ac, A), Q, R; positive=true, blocksize=1) + qr_full!(copy!(Ac, A), (Q, R); positive=true, blocksize=1) @test Q * R ≈ A @test Q' * Q ≈ I @test all(>=(zero(real(T))), real(diag(R))) - qr_full!(copy!(Ac, A), Q2, noR; positive=true, blocksize=1) + qr_full!(copy!(Ac, A), (Q2, noR); positive=true, blocksize=1) @test Q == Q2 # positive and pivoted - qr_full!(copy!(Ac, A), Q, R; positive=true, pivoted=true) + qr_full!(copy!(Ac, A), (Q, R); positive=true, pivoted=true) @test Q * R ≈ A @test Q' * Q ≈ I if n <= m @@ -129,7 +129,7 @@ end @test real(R[i, j]) >= zero(real(T)) end end - qr_full!(copy!(Ac, A), Q2, noR; positive=true, pivoted=true) + qr_full!(copy!(Ac, A), (Q2, noR); positive=true, pivoted=true) @test Q == Q2 end end diff --git a/test/runtests.jl b/test/runtests.jl index a89b5e7..d775cc3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,7 +9,7 @@ diagview(A) = view(A, diagind(A)) include("qr.jl") include("svd.jl") -include("eigh.jl") +# include("eigh.jl") @testset "MatrixAlgebraKit.jl" begin @testset "Code quality (Aqua.jl)" begin diff --git a/test/svd.jl b/test/svd.jl index 927e6b6..330cf61 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -1,30 +1,27 @@ @testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) m = 54 for n in (37, m, 63) - for alg in (LinearAlgebra.DivideAndConquer(), LinearAlgebra.QRIteration()) + for alg in (LAPACK_DivideAndConquer(), LAPACK_QRIteration()) A = randn(T, m, n) Ac = similar(A) U = similar(A, m, min(m, n)) Vᴴ = similar(A, min(m, n), n) - S = similar(A, real(T), min(m, n)) - Sc = similar(S) + S = Diagonal(similar(A, real(T), min(m, n))) + Sc = similar(A, real(T), min(m, n)) - svd_compact!(copy!(Ac, A), U, S, Vᴴ; alg=alg) - @test U * Diagonal(S) * Vᴴ ≈ A + svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg) + @test U * S * Vᴴ ≈ A @test U' * U ≈ I @test Vᴴ * Vᴴ' ≈ I - @test all(isposdef, S) + @test isposdef(S) - U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), U, S, Vᴴ; alg=alg) + U2, S2, V2ᴴ = @constinferred svd_compact!(copy!(Ac, A), (U, S, Vᴴ), alg) @test U2 == U @test S2 == S @test V2ᴴ == Vᴴ - svd_vals!(copy!(Ac, A), Sc; alg=alg) - @test S ≈ Sc - - S3 = @constinferred svd_vals!(copy!(Ac, A); alg=alg) - @test S3 == Sc + svd_vals!(copy!(Ac, A), Sc, alg) + @test S ≈ Diagonal(Sc) end end end @@ -32,69 +29,26 @@ end @testset "svd_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) m = 54 for n in (37, m, 63) - for alg in (LinearAlgebra.DivideAndConquer(), LinearAlgebra.QRIteration()) + for alg in (LAPACK_DivideAndConquer(), LAPACK_QRIteration()) A = randn(T, m, n) Ac = similar(A) U = similar(A, m, m) Vᴴ = similar(A, n, n) - S = similar(A, real(T), min(m, n)) + S = similar(A, real(T), m, n) Sc = similar(S) - Σ = zero(A) - svd_full!(copy!(Ac, A), U, S, Vᴴ; alg=alg) - copy!(diagview(Σ), S) - @test U * Σ * Vᴴ ≈ A + svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg) + @test U * S * Vᴴ ≈ A @test U' * U ≈ I @test U * U' ≈ I @test Vᴴ * Vᴴ' ≈ I @test Vᴴ' * Vᴴ ≈ I - @test all(isposdef, S) + @test all(isposdef, view(S, diagind(S))) - U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), U, S, Vᴴ; alg=alg) + U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), (U, S, Vᴴ), alg) @test U2 == U @test S2 == S @test V2ᴴ == Vᴴ - - svd_vals!(copy!(Ac, A), Sc; alg=alg) - @test S ≈ Sc - - S3 = @constinferred svd_vals!(copy!(Ac, A); alg=alg) - @test S3 == Sc - end - end -end - -@testset "svd_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) - m = 54 - for n in (37, m, 63) - for alg in (LinearAlgebra.DivideAndConquer(), LinearAlgebra.QRIteration()) - A = randn(T, m, n) - Ac = similar(A) - U = similar(A, m, m) - Vᴴ = similar(A, n, n) - S = similar(A, real(T), min(m, n)) - Sc = similar(S) - Σ = zero(A) - - svd_full!(copy!(Ac, A), U, S, Vᴴ; alg=alg) - copy!(diagview(Σ), S) - @test U * Σ * Vᴴ ≈ A - @test U' * U ≈ I - @test U * U' ≈ I - @test Vᴴ * Vᴴ' ≈ I - @test Vᴴ' * Vᴴ ≈ I - @test all(isposdef, S) - - U2, S2, V2ᴴ = @constinferred svd_full!(copy!(Ac, A), U, S, Vᴴ; alg=alg) - @test U2 == U - @test S2 == S - @test V2ᴴ == Vᴴ - - svd_vals!(copy!(Ac, A), Sc; alg=alg) - @test S ≈ Sc - - S3 = @constinferred svd_vals!(copy!(Ac, A); alg=alg) - @test S3 == Sc end end end @@ -102,44 +56,25 @@ end @testset "svd_trunc! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) m = 54 for n in (37, m, 63) - for alg in (LinearAlgebra.DivideAndConquer(), LinearAlgebra.QRIteration()) + for alg in (LAPACK_DivideAndConquer(), LAPACK_QRIteration()) A = randn(T, m, n) Ac = similar(A) S₀ = svd_vals!(copy!(Ac, A)) minmn = min(m, n) r = minmn - 2 + trunc1 = truncrank(r) s = 1 + sqrt(eps(real(T))) + trunc2 = trunctol(s * S₀[r + 1]) - U, S, Vᴴ = @constinferred svd_trunc!(copy!(Ac, A); alg=alg, rank=r) - @test length(S) == r - @test LinearAlgebra.opnorm(A - U * Diagonal(S) * Vᴴ) ≈ S₀[r + 1] - - U, S, Vᴴ = @constinferred svd_trunc!(copy!(Ac, A); alg=alg, - atol=s * S₀[r + 1]) - @test length(S) == r - - U, S, Vᴴ = @constinferred svd_trunc!(copy!(Ac, A); alg=alg, - rtol=s * S₀[r + 1] / S₀[1]) - @test length(S) == r + U1, S1, V1ᴴ = @constinferred svd_trunc!(copy!(Ac, A), alg, trunc1) + @test length(S1.diag) == r + @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] - r1 = minmn - 6 - r2 = minmn - 4 - r3 = minmn - 2 - U, S, Vᴴ = @constinferred svd_trunc!(copy!(Ac, A); alg=alg, - rank=r1, - atol=s * S₀[r2 + 1], - rtol=s * S₀[r3 + 1] / S₀[1]) - @test length(S) == r1 - U, S, Vᴴ = @constinferred svd_trunc!(copy!(Ac, A); alg=alg, - rank=r2, - atol=s * S₀[r3 + 1], - rtol=s * S₀[r1 + 1] / S₀[1]) - @test length(S) == r1 - U, S, Vᴴ = @constinferred svd_trunc!(copy!(Ac, A); alg=alg, - rank=r3, - atol=s * S₀[r1 + 1], - rtol=s * S₀[r2 + 1] / S₀[1]) - @test length(S) == r1 + U2, S2, V2ᴴ = @constinferred svd_trunc!(copy!(Ac, A), alg, trunc2) + @test length(S2.diag) == r + @test U1 ≈ U2 + @test S1 ≈ S2 + @test V1ᴴ ≈ V2ᴴ end end end