Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make adjoint and solve work for most factorizations #40899

Merged
merged 15 commits into from
May 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,63 @@ const ⋅ = dot
const × = cross
export ⋅, ×

## convenience methods
## return only the solution of a least squares problem while avoiding promoting
## vectors to matrices.
_cut_B(x::AbstractVector, r::UnitRange) = length(x) > length(r) ? x[r] : x
_cut_B(X::AbstractMatrix, r::UnitRange) = size(X, 1) > length(r) ? X[r,:] : X

## append right hand side with zeros if necessary
_zeros(::Type{T}, b::AbstractVector, n::Integer) where {T} = zeros(T, max(length(b), n))
_zeros(::Type{T}, B::AbstractMatrix, n::Integer) where {T} = zeros(T, max(size(B, 1), n), size(B, 2))

# General fallback definition for handling under- and overdetermined system as well as square problems
# While this definition is pretty general, it does e.g. promote to common element type of lhs and rhs
# which is required by LAPACK but not SuiteSpase which allows real-complex solves in some cases. Hence,
# we restrict this method to only the LAPACK factorizations in LinearAlgebra.
# The definition is put here since it explicitly references all the Factorizion structs so it has
# to be located after all the files that define the structs.
const LAPACKFactorizations{T,S} = Union{
BunchKaufman{T,S},
Cholesky{T,S},
LQ{T,S},
LU{T,S},
QR{T,S},
QRCompactWY{T,S},
QRPivoted{T,S},
SVD{T,<:Real,S}}
function (\)(F::Union{<:LAPACKFactorizations,Adjoint{<:Any,<:LAPACKFactorizations}}, B::AbstractVecOrMat)
require_one_based_indexing(B)
m, n = size(F)
if m != size(B, 1)
throw(DimensionMismatch("arguments must have the same number of rows"))
end

TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
FF = Factorization{TFB}(F)

# For wide problem we (often) compute a minimum norm solution. The solution
# is larger than the right hand side so we use size(F, 2).
BB = _zeros(TFB, B, n)

if n > size(B, 1)
# Underdetermined
copyto!(view(BB, 1:m, :), B)
else
copyto!(BB, B)
end

ldiv!(FF, BB)

# For tall problems, we compute a least squares solution so only part
# of the rhs should be returned from \ while ldiv! uses (and returns)
# the complete rhs
return _cut_B(BB, 1:n)
end
# disambiguate
(\)(F::LAPACKFactorizations{T}, B::VecOrMat{Complex{T}}) where {T<:BlasReal} =
invoke(\, Tuple{Factorization{T}, VecOrMat{Complex{T}}}, F, B)

"""
LinearAlgebra.peakflops(n::Integer=2000; parallel::Bool=false)

Expand Down
16 changes: 11 additions & 5 deletions stdlib/LinearAlgebra/src/bunchkaufman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,14 @@ julia> S.L*S.D*S.L' - A[S.p, S.p]
bunchkaufman(A::AbstractMatrix{T}, rook::Bool=false; check::Bool = true) where {T} =
bunchkaufman!(copy_oftype(A, typeof(sqrt(oneunit(T)))), rook; check = check)

convert(::Type{BunchKaufman{T}}, B::BunchKaufman{T}) where {T} = B
convert(::Type{BunchKaufman{T}}, B::BunchKaufman) where {T} =
BunchKaufman{T}(B::BunchKaufman) where {T} =
BunchKaufman(convert(Matrix{T}, B.LD), B.ipiv, B.uplo, B.symmetric, B.rook, B.info)
convert(::Type{Factorization{T}}, B::BunchKaufman{T}) where {T} = B
convert(::Type{Factorization{T}}, B::BunchKaufman) where {T} = convert(BunchKaufman{T}, B)
Factorization{T}(B::BunchKaufman) where {T} = BunchKaufman{T}(B)

size(B::BunchKaufman) = size(getfield(B, :LD))
size(B::BunchKaufman, d::Integer) = size(getfield(B, :LD), d)
issymmetric(B::BunchKaufman) = B.symmetric
ishermitian(B::BunchKaufman) = !B.symmetric
ishermitian(B::BunchKaufman{T}) where T = T<:Real || !B.symmetric

function _ipiv2perm_bk(v::AbstractVector{T}, maxi::Integer, uplo::AbstractChar, rook::Bool) where T
require_one_based_indexing(v)
Expand Down Expand Up @@ -279,6 +277,14 @@ Base.propertynames(B::BunchKaufman, private::Bool=false) =

issuccess(B::BunchKaufman) = B.info == 0

function adjoint(B::BunchKaufman)
if ishermitian(B)
return B
else
throw(ArgumentError("adjoint not implemented for complex symmetric matrices"))
end
end

function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, B::BunchKaufman)
if issuccess(B)
summary(io, B); println(io)
Expand Down
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/src/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,8 @@ Base.propertynames(F::CholeskyPivoted, private::Bool=false) =

issuccess(C::Union{Cholesky,CholeskyPivoted}) = C.info == 0

adjoint(C::Union{Cholesky,CholeskyPivoted}) = C

function show(io::IO, mime::MIME{Symbol("text/plain")}, C::Cholesky{<:Any,<:AbstractMatrix})
if issuccess(C)
summary(io, C); println(io)
Expand Down
8 changes: 0 additions & 8 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -488,14 +488,6 @@ rdiv!(A::AbstractMatrix{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} =
(/)(A::Union{StridedMatrix, AbstractTriangular}, D::Diagonal) =
rdiv!((typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A), D)

(\)(F::Factorization, D::Diagonal) =
ldiv!(F, Matrix{typeof(oneunit(eltype(D))/oneunit(eltype(F)))}(D))
\(adjF::Adjoint{<:Any,<:Factorization}, D::Diagonal) =
(F = adjF.parent; ldiv!(adjoint(F), Matrix{typeof(oneunit(eltype(D))/oneunit(eltype(F)))}(D)))
(\)(A::Union{QR,QRCompactWY,QRPivoted}, B::Diagonal) =
invoke(\, Tuple{Union{QR,QRCompactWY,QRPivoted}, AbstractVecOrMat}, A, B)


@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)
valA = A.diag; nA = length(valA)
valB = B.diag; nB = length(valB)
Expand Down
24 changes: 6 additions & 18 deletions stdlib/LinearAlgebra/src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ convert(::Type{T}, f::Factorization) where {T<:AbstractArray} = T(f)

### General promotion rules
Factorization{T}(F::Factorization{T}) where {T} = F
# This is a bit odd since the return is not a Factorization but it works well in generic code
Factorization{T}(A::Adjoint{<:Any,<:Factorization}) where {T} =
adjoint(Factorization{T}(parent(A)))
inv(F::Factorization{T}) where {T} = (n = size(F, 1); ldiv!(F, Matrix{T}(I, n, n)))

Base.hash(F::Factorization, h::UInt) = mapreduce(f -> hash(getfield(F, f)), hash, 1:nfields(F); init=h)
Expand Down Expand Up @@ -96,40 +99,25 @@ function (/)(B::VecOrMat{Complex{T}}, F::Factorization{T}) where T<:BlasReal
return copy(reinterpret(Complex{T}, x))
end

function \(F::Factorization, B::AbstractVecOrMat)
function \(F::Union{Factorization, Adjoint{<:Any,<:Factorization}}, B::AbstractVecOrMat)
require_one_based_indexing(B)
TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
BB = similar(B, TFB, size(B))
copyto!(BB, B)
ldiv!(F, BB)
end
function \(adjF::Adjoint{<:Any,<:Factorization}, B::AbstractVecOrMat)
require_one_based_indexing(B)
F = adjF.parent
TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
BB = similar(B, TFB, size(B))
copyto!(BB, B)
ldiv!(adjoint(F), BB)
end

function /(B::AbstractMatrix, F::Factorization)
function /(B::AbstractMatrix, F::Union{Factorization, Adjoint{<:Any,<:Factorization}})
require_one_based_indexing(B)
TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
BB = similar(B, TFB, size(B))
copyto!(BB, B)
rdiv!(BB, F)
end
function /(B::AbstractMatrix, adjF::Adjoint{<:Any,<:Factorization})
require_one_based_indexing(B)
F = adjF.parent
TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
BB = similar(B, TFB, size(B))
copyto!(BB, B)
rdiv!(BB, adjoint(F))
end
/(adjB::AdjointAbsVec, adjF::Adjoint{<:Any,<:Factorization}) = adjoint(adjF.parent \ adjB.parent)
/(B::TransposeAbsVec, adjF::Adjoint{<:Any,<:Factorization}) = adjoint(adjF.parent \ adjoint(B))


# support the same 3-arg idiom as in our other in-place A_*_B functions:
function ldiv!(Y::AbstractVecOrMat, A::Factorization, B::AbstractVecOrMat)
require_one_based_indexing(Y, B)
Expand Down
14 changes: 8 additions & 6 deletions stdlib/LinearAlgebra/src/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -564,28 +564,30 @@ function AbstractMatrix(F::Hessenberg)
end
end

# adjoint(Q::HessenbergQ{<:Real})

lmul!(Q::BlasHessenbergQ{T,false}, X::StridedVecOrMat{T}) where {T<:BlasFloat} =
LAPACK.ormhr!('L', 'N', 1, size(Q.factors, 1), Q.factors, Q.τ, X)
rmul!(X::StridedMatrix{T}, Q::BlasHessenbergQ{T,false}) where {T<:BlasFloat} =
rmul!(X::StridedVecOrMat{T}, Q::BlasHessenbergQ{T,false}) where {T<:BlasFloat} =
LAPACK.ormhr!('R', 'N', 1, size(Q.factors, 1), Q.factors, Q.τ, X)
lmul!(adjQ::Adjoint{<:Any,<:BlasHessenbergQ{T,false}}, X::StridedVecOrMat{T}) where {T<:BlasFloat} =
(Q = adjQ.parent; LAPACK.ormhr!('L', ifelse(T<:Real, 'T', 'C'), 1, size(Q.factors, 1), Q.factors, Q.τ, X))
rmul!(X::StridedMatrix{T}, adjQ::Adjoint{<:Any,<:BlasHessenbergQ{T,false}}) where {T<:BlasFloat} =
rmul!(X::StridedVecOrMat{T}, adjQ::Adjoint{<:Any,<:BlasHessenbergQ{T,false}}) where {T<:BlasFloat} =
(Q = adjQ.parent; LAPACK.ormhr!('R', ifelse(T<:Real, 'T', 'C'), 1, size(Q.factors, 1), Q.factors, Q.τ, X))

lmul!(Q::BlasHessenbergQ{T,true}, X::StridedVecOrMat{T}) where {T<:BlasFloat} =
LAPACK.ormtr!('L', Q.uplo, 'N', Q.factors, Q.τ, X)
rmul!(X::StridedMatrix{T}, Q::BlasHessenbergQ{T,true}) where {T<:BlasFloat} =
rmul!(X::StridedVecOrMat{T}, Q::BlasHessenbergQ{T,true}) where {T<:BlasFloat} =
LAPACK.ormtr!('R', Q.uplo, 'N', Q.factors, Q.τ, X)
lmul!(adjQ::Adjoint{<:Any,<:BlasHessenbergQ{T,true}}, X::StridedVecOrMat{T}) where {T<:BlasFloat} =
(Q = adjQ.parent; LAPACK.ormtr!('L', Q.uplo, ifelse(T<:Real, 'T', 'C'), Q.factors, Q.τ, X))
rmul!(X::StridedMatrix{T}, adjQ::Adjoint{<:Any,<:BlasHessenbergQ{T,true}}) where {T<:BlasFloat} =
rmul!(X::StridedVecOrMat{T}, adjQ::Adjoint{<:Any,<:BlasHessenbergQ{T,true}}) where {T<:BlasFloat} =
(Q = adjQ.parent; LAPACK.ormtr!('R', Q.uplo, ifelse(T<:Real, 'T', 'C'), Q.factors, Q.τ, X))

lmul!(Q::HessenbergQ{T}, X::Adjoint{T,<:StridedVecOrMat{T}}) where {T} = rmul!(X', Q')'
rmul!(X::Adjoint{T,<:StridedMatrix{T}}, Q::HessenbergQ{T}) where {T} = lmul!(Q', X')'
rmul!(X::Adjoint{T,<:StridedVecOrMat{T}}, Q::HessenbergQ{T}) where {T} = lmul!(Q', X')'
lmul!(adjQ::Adjoint{<:Any,<:HessenbergQ{T}}, X::Adjoint{T,<:StridedVecOrMat{T}}) where {T} = rmul!(X', adjQ')'
rmul!(X::Adjoint{T,<:StridedMatrix{T}}, adjQ::Adjoint{<:Any,<:HessenbergQ{T}}) where {T} = lmul!(adjQ', X')'
rmul!(X::Adjoint{T,<:StridedVecOrMat{T}}, adjQ::Adjoint{<:Any,<:HessenbergQ{T}}) where {T} = lmul!(adjQ', X')'

# multiply x by the entries of M in the upper-k triangle, which contains
# the entries of the upper-Hessenberg matrix H for k=-1
Expand Down
3 changes: 3 additions & 0 deletions stdlib/LinearAlgebra/src/ldlt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ function getproperty(F::LDLt, d::Symbol)
end
end

adjoint(F::LDLt{<:Real,<:SymTridiagonal}) = F
adjoint(F::LDLt) = LDLt(copy(adjoint(F.data)))

function show(io::IO, mime::MIME{Symbol("text/plain")}, F::LDLt)
summary(io, F); println(io)
println(io, "L factor:")
Expand Down
30 changes: 16 additions & 14 deletions stdlib/LinearAlgebra/src/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ lq_eltype(::Type{T}) where {T} = typeof(zero(T) / sqrt(abs2(one(T))))
copy(A::LQ) = LQ(copy(A.factors), copy(A.τ))

LQ{T}(A::LQ) where {T} = LQ(convert(AbstractMatrix{T}, A.factors), convert(Vector{T}, A.τ))
Factorization{T}(A::LQ{T}) where {T} = A
Factorization{T}(A::LQ) where {T} = LQ{T}(A)

AbstractMatrix(A::LQ) = A.L*A.Q
AbstractArray(A::LQ) = AbstractMatrix(A)
Matrix(A::LQ) = Array(AbstractArray(A))
Expand Down Expand Up @@ -194,7 +194,7 @@ function lmul!(A::LQ, B::StridedVecOrMat)
end
function *(A::LQ{TA}, B::StridedVecOrMat{TB}) where {TA,TB}
TAB = promote_type(TA, TB)
_cut_B(lmul!(Factorization{TAB}(A), copy_oftype(B, TAB)), 1:size(A,1))
_cut_B(lmul!(convert(Factorization{TAB}, A), copy_oftype(B, TAB)), 1:size(A,1))
end

## Multiplication by Q
Expand Down Expand Up @@ -318,17 +318,6 @@ _rightappdimmismatch(rowsorcols) =
"or (2) the number of rows of that (LQPackedQ) matrix's internal representation ",
"(the factorization's originating matrix's number of rows)")))


function (\)(A::LQ{TA},B::StridedVecOrMat{TB}) where {TA,TB}
S = promote_type(TA,TB)
m, n = size(A)
m ≤ n || throw(DimensionMismatch("LQ solver does not support overdetermined systems (more rows than columns)"))
m == size(B,1) || throw(DimensionMismatch("Both inputs should have the same number of rows"))
AA = Factorization{S}(A)
X = _zeros(S, B, n)
X[1:size(B, 1), :] = B
return ldiv!(AA, X)
end
# With a real lhs and complex rhs with the same precision, we can reinterpret
# the complex rhs as a real rhs with twice the number of columns
function (\)(F::LQ{T}, B::VecOrMat{Complex{T}}) where T<:BlasReal
Expand All @@ -342,12 +331,25 @@ function (\)(F::LQ{T}, B::VecOrMat{Complex{T}}) where T<:BlasReal
end


function ldiv!(A::LQ{T}, B::StridedVecOrMat{T}) where T
function ldiv!(A::LQ, B::StridedVecOrMat)
require_one_based_indexing(B)
m, n = size(A)
m ≤ n || throw(DimensionMismatch("LQ solver does not support overdetermined systems (more rows than columns)"))

ldiv!(LowerTriangular(A.L), view(B, 1:size(A,1), axes(B,2)))
return lmul!(adjoint(A.Q), B)
end

function ldiv!(Fadj::Adjoint{<:Any,<:LQ}, B::StridedVecOrMat)
require_one_based_indexing(B)
m, n = size(Fadj)
m >= n || throw(DimensionMismatch("solver does not support underdetermined systems (more columns than rows)"))

F = parent(Fadj)
lmul!(F.Q, B)
ldiv!(UpperTriangular(adjoint(F.L)), view(B, 1:size(F,1), axes(B,2)))
return B
end

# In LQ factorization, `Q` is expressed as the product of the adjoint of the
# reflectors. Thus, `det` has to be conjugated.
Expand Down
45 changes: 27 additions & 18 deletions stdlib/LinearAlgebra/src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,8 @@ end
Base.propertynames(F::QRPivoted, private::Bool=false) =
(:R, :Q, :p, :P, (private ? fieldnames(typeof(F)) : ())...)

adjoint(F::Union{QR,QRPivoted,QRCompactWY}) = Adjoint(F)

abstract type AbstractQ{T} <: AbstractMatrix{T} end

inv(Q::AbstractQ) = Q'
Expand Down Expand Up @@ -939,28 +941,35 @@ function ldiv!(A::QRPivoted, B::StridedMatrix)
B
end

# convenience methods
## return only the solution of a least squares problem while avoiding promoting
## vectors to matrices.
_cut_B(x::AbstractVector, r::UnitRange) = length(x) > length(r) ? x[r] : x
_cut_B(X::AbstractMatrix, r::UnitRange) = size(X, 1) > length(r) ? X[r,:] : X

## append right hand side with zeros if necessary
_zeros(::Type{T}, b::AbstractVector, n::Integer) where {T} = zeros(T, max(length(b), n))
_zeros(::Type{T}, B::AbstractMatrix, n::Integer) where {T} = zeros(T, max(size(B, 1), n), size(B, 2))
function _apply_permutation!(F::QRPivoted, B::AbstractVecOrMat)
# Apply permutation but only to the top part of the solution vector since
# it's padded with zeros for underdetermined problems
B[1:length(F.p), :] = B[F.p, :]
return B
end
_apply_permutation!(F::Factorization, B::AbstractVecOrMat) = B

function (\)(A::Union{QR{TA},QRCompactWY{TA},QRPivoted{TA}}, B::AbstractVecOrMat{TB}) where {TA,TB}
function ldiv!(Fadj::Adjoint{<:Any,<:Union{QR,QRCompactWY,QRPivoted}}, B::AbstractVecOrMat)
require_one_based_indexing(B)
S = promote_type(TA,TB)
m, n = size(A)
m == size(B,1) || throw(DimensionMismatch("Both inputs should have the same number of rows"))
m, n = size(Fadj)

AA = Factorization{S}(A)
# We don't allow solutions overdetermined systems
if m > n
throw(DimensionMismatch("overdetermined systems are not supported"))
end
if n != size(B, 1)
throw(DimensionMismatch("inputs should have the same number of rows"))
end
F = parent(Fadj)

X = _zeros(S, B, n)
X[1:size(B, 1), :] = B
ldiv!(AA, X)
return _cut_B(X, 1:n)
B = _apply_permutation!(F, B)

# For underdetermined system, the triangular solve should only be applied to the top
# part of B that contains the rhs. For square problems, the view corresponds to B itself
ldiv!(LowerTriangular(adjoint(F.R)), view(B, 1:size(F.R, 2), :))
lmul!(F.Q, B)

return B
end

# With a real lhs and complex rhs with the same precision, we can reinterpret the complex
Expand Down
14 changes: 12 additions & 2 deletions stdlib/LinearAlgebra/src/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ function SVD{T}(U::AbstractArray, S::AbstractVector{Tr}, Vt::AbstractArray) wher
convert(AbstractArray{T}, Vt))
end

SVD{T}(F::SVD) where {T} = SVD(
convert(AbstractMatrix{T}, F.U),
convert(AbstractVector{real(T)}, F.S),
convert(AbstractMatrix{T}, F.Vt))
Factorization{T}(F::SVD) where {T} = SVD{T}(F)

# iteration for destructuring into components
Base.iterate(S::SVD) = (S.U, Val(:S))
Expand Down Expand Up @@ -235,10 +240,11 @@ svdvals(A::AbstractVector{<:BlasFloat}) = [norm(A)]
svdvals(x::Number) = abs(x)
svdvals(S::SVD{<:Any,T}) where {T} = (S.S)::Vector{T}

# SVD least squares
### SVD least squares ###
function ldiv!(A::SVD{T}, B::StridedVecOrMat) where T
m, n = size(A)
k = searchsortedlast(A.S, eps(real(T))*A.S[1], rev=true)
view(A.Vt,1:k,:)' * (view(A.S,1:k) .\ (view(A.U,:,1:k)' * B))
return mul!(view(B, 1:n, :), view(A.Vt, 1:k, :)', view(A.S, 1:k) .\ (view(A.U, :, 1:k)' * _cut_B(B, 1:m)))
end

function inv(F::SVD{T}) where T
Expand All @@ -252,6 +258,10 @@ end
size(A::SVD, dim::Integer) = dim == 1 ? size(A.U, dim) : size(A.Vt, dim)
size(A::SVD) = (size(A, 1), size(A, 2))

function adjoint(F::SVD)
return SVD(F.Vt', F.S, F.U')
end

function show(io::IO, mime::MIME{Symbol("text/plain")}, F::SVD{<:Any,<:Any,<:AbstractArray})
summary(io, F); println(io)
println(io, "U factor:")
Expand Down
Loading