diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index b9737bf36d0c5..8ae4fe846a266 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -12,9 +12,9 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as asin, asinh, atan, atanh, axes, big, broadcast, ceil, cis, conj, convert, copy, copyto!, cos, cosh, cot, coth, csc, csch, eltype, exp, fill!, floor, getindex, hcat, getproperty, imag, inv, isapprox, isequal, isone, iszero, IndexStyle, kron, kron!, length, log, map, ndims, - oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech, + one, oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech, setindex!, show, similar, sin, sincos, sinh, size, sqrt, - strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec + strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec, zero using Base: IndexLinear, promote_eltype, promote_op, promote_typeof, @propagate_inbounds, @pure, reduce, typed_hvcat, typed_vcat, require_one_based_indexing, splat diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index b31360e233a51..75a221f8d2045 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -198,19 +198,37 @@ Base.literal_pow(::typeof(^), D::Diagonal, valp::Val) = Diagonal(Base.literal_pow.(^, D.diag, valp)) # for speed Base.literal_pow(::typeof(^), D::Diagonal, ::Val{-1}) = inv(D) # for disambiguation +function _muldiag_size_check(A, B) + nA = size(A, 2) + mB = size(B, 1) + @noinline throw_dimerr(::AbstractMatrix, nA, mB) = throw(DimensionMismatch("second dimension of A, $nA, does not match first dimension of B, $mB")) + @noinline throw_dimerr(::AbstractVector, nA, mB) = throw(DimensionMismatch("second dimension of D, $nA, does not match length of V, $mB")) + nA == mB || throw_dimerr(B, nA, mB) + return nothing +end +# the output matrix should have the same size as the non-diagonal input matrix or vector +@noinline throw_dimerr(szC, szA) = throw(DimensionMismatch("output matrix has size: $szC, but should have size $szA")) +_size_check_out(C, ::Diagonal, A) = _size_check_out(C, A) +_size_check_out(C, A, ::Diagonal) = _size_check_out(C, A) +_size_check_out(C, A::Diagonal, ::Diagonal) = _size_check_out(C, A) +function _size_check_out(C, A) + szA = size(A) + szC = size(C) + szA == szC || throw_dimerr(szC, szA) + return nothing +end +function _muldiag_size_check(C, A, B) + _muldiag_size_check(A, B) + _size_check_out(C, A, B) +end + function (*)(Da::Diagonal, Db::Diagonal) - nDa, mDb = size(Da, 2), size(Db, 1) - if nDa != mDb - throw(DimensionMismatch("second dimension of Da, $nDa, does not match first dimension of Db, $mDb")) - end + _muldiag_size_check(Da, Db) return Diagonal(Da.diag .* Db.diag) end function (*)(D::Diagonal, V::AbstractVector) - nD = size(D, 2) - if nD != length(V) - throw(DimensionMismatch("second dimension of D, $nD, does not match length of V, $(length(V))")) - end + _muldiag_size_check(D, V) return D.diag .* V end @@ -220,29 +238,12 @@ end lmul!(D, copy_oftype(B, promote_op(*, eltype(B), eltype(D.diag)))) (*)(A::AbstractMatrix, D::Diagonal) = - rmul!(copy_similar(A, promote_op(*, eltype(A), eltype(D.diag))), D) + mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D) (*)(D::Diagonal, A::AbstractMatrix) = - lmul!(D, copy_similar(A, promote_op(*, eltype(A), eltype(D.diag)))) + mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A) -function rmul!(A::AbstractMatrix, D::Diagonal) - require_one_based_indexing(A) - nA, nD = size(A, 2), length(D.diag) - if nA != nD - throw(DimensionMismatch("second dimension of A, $nA, does not match the first of D, $nD")) - end - A .= A .* permutedims(D.diag) - return A -end - -function lmul!(D::Diagonal, B::AbstractVecOrMat) - require_one_based_indexing(B) - nB, nD = size(B, 1), length(D.diag) - if nB != nD - throw(DimensionMismatch("second dimension of D, $nD, does not match the first of B, $nB")) - end - B .= D.diag .* B - return B -end +rmul!(A::AbstractMatrix, D::Diagonal) = mul!(A, A, D) +lmul!(D::Diagonal, B::AbstractVecOrMat) = mul!(B, D, B) rmul!(A::Union{LowerTriangular,UpperTriangular}, D::Diagonal) = typeof(A)(rmul!(A.data, D)) function rmul!(A::UnitLowerTriangular, D::Diagonal) @@ -306,37 +307,66 @@ function *(D::Diagonal, transA::Transpose{<:Any,<:AbstractMatrix}) lmul!(D, At) end -rmul!(A::Diagonal, B::Diagonal) = Diagonal(A.diag .*= B.diag) -lmul!(A::Diagonal, B::Diagonal) = Diagonal(B.diag .= A.diag .* B.diag) +@inline function __muldiag!(out, D::Diagonal, B, alpha, beta) + if iszero(beta) + out .= (D.diag .* B) .*ₛ alpha + else + out .= (D.diag .* B) .*ₛ alpha .+ out .* beta + end + return out +end + +@inline function __muldiag!(out, A, D::Diagonal, alpha, beta) + if iszero(beta) + out .= (A .* permutedims(D.diag)) .*ₛ alpha + else + out .= (A .* permutedims(D.diag)) .*ₛ alpha .+ out .* beta + end + return out +end + +@inline function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha, beta) + if iszero(beta) + out.diag .= (D1.diag .* D2.diag) .*ₛ alpha + else + out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .* beta + end + return out +end + +# only needed for ambiguity resolution, as mul! is explicitly defined for these arguments +@inline __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) = + mul!(out, D1, D2, alpha, beta) + +@inline function _muldiag!(out, A, B, alpha, beta) + _muldiag_size_check(out, A, B) + __muldiag!(out, A, B, alpha, beta) + return out +end # Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat -@inline mul!(out::AbstractVector, A::Diagonal, in::AbstractVector, alpha::Number, beta::Number) = - out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta -@inline mul!(out::AbstractMatrix, A::Diagonal, in::AbstractMatrix, alpha::Number, beta::Number) = - out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta -@inline mul!(out::AbstractMatrix, A::Diagonal, in::Adjoint{<:Any,<:AbstractVecOrMat}, - alpha::Number, beta::Number) = - out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta -@inline mul!(out::AbstractMatrix, A::Diagonal, in::Transpose{<:Any,<:AbstractVecOrMat}, - alpha::Number, beta::Number) = - out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta - -@inline mul!(out::AbstractMatrix, in::AbstractMatrix, A::Diagonal, alpha::Number, beta::Number) = - out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta -@inline mul!(out::AbstractMatrix, in::Adjoint{<:Any,<:AbstractVecOrMat}, A::Diagonal, - alpha::Number, beta::Number) = - out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta -@inline mul!(out::AbstractMatrix, in::Transpose{<:Any,<:AbstractVecOrMat}, A::Diagonal, - alpha::Number, beta::Number) = - out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta +@inline mul!(out::AbstractVector, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) = + _muldiag!(out, D, V, alpha, beta) +@inline mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, alpha::Number, beta::Number) = + _muldiag!(out, D, B, alpha, beta) +@inline mul!(out::AbstractMatrix, D::Diagonal, B::Adjoint{<:Any,<:AbstractVecOrMat}, + alpha::Number, beta::Number) = _muldiag!(out, D, B, alpha, beta) +@inline mul!(out::AbstractMatrix, D::Diagonal, B::Transpose{<:Any,<:AbstractVecOrMat}, + alpha::Number, beta::Number) = _muldiag!(out, D, B, alpha, beta) + +@inline mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, alpha::Number, beta::Number) = + _muldiag!(out, A, D, alpha, beta) +@inline mul!(out::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, D::Diagonal, + alpha::Number, beta::Number) = _muldiag!(out, A, D, alpha, beta) +@inline mul!(out::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, D::Diagonal, + alpha::Number, beta::Number) = _muldiag!(out, A, D, alpha, beta) +@inline mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) = + _muldiag!(C, Da, Db, alpha, beta) function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) - mA = size(Da, 1) - mB = size(Db, 1) - mA == mB || throw(DimensionMismatch("A has dimensions ($mA,$mA) but B has dimensions ($mB,$mB)")) - mC, nC = size(C) - mC == nC == mA || throw(DimensionMismatch("output matrix has size: ($mC,$nC), but should have size ($mA,$mA)")) + _muldiag_size_check(C, Da, Db) require_one_based_indexing(C) + mA = size(Da, 1) da = Da.diag db = Db.diag _rmul_or_fill!(C, beta) diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index cf7e474468785..7ab9b6412cb5b 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -8,7 +8,8 @@ # inside this function. function *ₛ end Broadcast.broadcasted(::typeof(*ₛ), out, beta) = - iszero(beta::Number) ? false : broadcasted(*, out, beta) + iszero(beta::Number) ? false : + isone(beta::Number) ? broadcasted(identity, out) : broadcasted(*, out, beta) """ MulAddMul(alpha, beta) diff --git a/stdlib/LinearAlgebra/src/special.jl b/stdlib/LinearAlgebra/src/special.jl index 5c25c0993e9cc..1b0576f5cbad3 100644 --- a/stdlib/LinearAlgebra/src/special.jl +++ b/stdlib/LinearAlgebra/src/special.jl @@ -308,13 +308,16 @@ function fill!(A::Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}, x) not be filled with $x, since some of its entries are constrained.")) end -one(A::Diagonal{T}) where T = Diagonal(fill!(similar(A.diag, typeof(one(T))), one(T))) +one(D::Diagonal) = Diagonal(one.(D.diag)) one(A::Bidiagonal{T}) where T = Bidiagonal(fill!(similar(A.dv, typeof(one(T))), one(T)), fill!(similar(A.ev, typeof(one(T))), zero(one(T))), A.uplo) one(A::Tridiagonal{T}) where T = Tridiagonal(fill!(similar(A.du, typeof(one(T))), zero(one(T))), fill!(similar(A.d, typeof(one(T))), one(T)), fill!(similar(A.dl, typeof(one(T))), zero(one(T)))) one(A::SymTridiagonal{T}) where T = SymTridiagonal(fill!(similar(A.dv, typeof(one(T))), one(T)), fill!(similar(A.ev, typeof(one(T))), zero(one(T)))) # equals and approx equals methods for structured matrices # SymTridiagonal == Tridiagonal is already defined in tridiag.jl +zero(D::Diagonal) = Diagonal(zero.(D.diag)) +oneunit(D::Diagonal) = Diagonal(oneunit.(D.diag)) + # SymTridiagonal and Bidiagonal have the same field names ==(A::Diagonal, B::Union{SymTridiagonal, Bidiagonal}) = iszero(B.ev) && A.diag == B.dv ==(B::Bidiagonal, A::Diagonal) = A == B diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index d782fd358bad5..3f78549bbd54a 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -578,6 +578,41 @@ let D1 = Diagonal(rand(5)), D2 = Diagonal(rand(5)) @test LinearAlgebra.lmul!(adjoint(D1),copy(D2)) == adjoint(D1)*D2 end +@testset "multiplication of a Diagonal with a Matrix" begin + A = collect(reshape(1:8, 4, 2)); + B = BigFloat.(A); + DL = Diagonal(collect(axes(A, 1))); + DR = Diagonal(Float16.(collect(axes(A, 2)))); + + @test DL * A == collect(DL) * A + @test A * DR == A * collect(DR) + @test DL * B == collect(DL) * B + @test B * DR == B * collect(DR) + + A = reshape([ones(2,2), ones(2,2)*2, ones(2,2)*3, ones(2,2)*4], 2, 2) + Ac = collect(A) + D = Diagonal([collect(reshape(1:4, 2, 2)), collect(reshape(5:8, 2, 2))]) + Dc = collect(D) + @test A * D == Ac * Dc + @test D * A == Dc * Ac + @test D * D == Dc * Dc + + AS = similar(A) + mul!(AS, A, D, true, false) + @test AS == A * D + + D2 = similar(D) + mul!(D2, D, D) + @test D2 == D * D + + D2[diagind(D2)] .= D[diagind(D)] + lmul!(D, D2) + @test D2 == D * D + D2[diagind(D2)] .= D[diagind(D)] + rmul!(D2, D) + @test D2 == D * D +end + @testset "multiplication of QR Q-factor and Diagonal (#16615 spot test)" begin D = Diagonal(randn(5)) Q = qr(randn(5, 5)).Q @@ -686,12 +721,35 @@ end xt = transpose(x) A = reshape([[1 2; 3 4], zeros(Int,2,2), zeros(Int, 2, 2), [5 6; 7 8]], 2, 2) D = Diagonal(A) - @test x'*D == x'*A == copy(x')*D == copy(x')*A - @test xt*D == xt*A == copy(xt)*D == copy(xt)*A + @test x'*D == x'*A == collect(x')*D == collect(x')*A + @test xt*D == xt*A == collect(xt)*D == collect(xt)*A + outadjxD = similar(x'*D); outtrxD = similar(xt*D); + mul!(outadjxD, x', D) + @test outadjxD == x'*D + mul!(outtrxD, xt, D) + @test outtrxD == xt*D + + D1 = Diagonal([[1 2; 3 4]]) + @test D1 * x' == D1 * collect(x') == collect(D1) * collect(x') + @test D1 * xt == D1 * collect(xt) == collect(D1) * collect(xt) + outD1adjx = similar(D1 * x'); outD1trx = similar(D1 * xt); + mul!(outadjxD, D1, x') + @test outadjxD == D1*x' + mul!(outtrxD, D1, xt) + @test outtrxD == D1*xt + y = [x, x] yt = transpose(y) @test y'*D*y == (y'*D)*y == (y'*A)*y @test yt*D*y == (yt*D)*y == (yt*A)*y + outadjyD = similar(y'*D); outtryD = similar(yt*D); + outadjyD2 = similar(collect(y'*D)); outtryD2 = similar(collect(yt*D)); + mul!(outadjyD, y', D) + mul!(outadjyD2, y', D) + @test outadjyD == outadjyD2 == y'*D + mul!(outtryD, yt, D) + mul!(outtryD2, yt, D) + @test outtryD == outtryD2 == yt*D end @testset "Multiplication of single element Diagonal (#36746, #40726)" begin @@ -826,4 +884,22 @@ end @test \(x, B) == /(B, x) end +@testset "zero and one" begin + D1 = Diagonal(rand(3)) + @test D1 + zero(D1) == D1 + @test D1 * one(D1) == D1 + @test D1 * oneunit(D1) == D1 + @test oneunit(D1) isa typeof(D1) + D2 = Diagonal([collect(reshape(1:4, 2, 2)), collect(reshape(5:8, 2, 2))]) + @test D2 + zero(D2) == D2 + @test D2 * one(D2) == D2 + @test D2 * oneunit(D2) == D2 + @test oneunit(D2) isa typeof(D2) + D3 = Diagonal([D2, D2]); + @test D3 + zero(D3) == D3 + @test D3 * one(D3) == D3 + @test D3 * oneunit(D3) == D3 + @test oneunit(D3) isa typeof(D3) +end + end # module TestDiagonal