Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use mul! in Diagonal*Matrix #42321

Merged
merged 16 commits into from
Oct 5, 2021
Merged
99 changes: 70 additions & 29 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,19 +198,37 @@ Base.literal_pow(::typeof(^), D::Diagonal, valp::Val) =
Diagonal(Base.literal_pow.(^, D.diag, valp)) # for speed
Base.literal_pow(::typeof(^), D::Diagonal, ::Val{-1}) = inv(D) # for disambiguation

function _muldiag_size_check(A, B)
nA = size(A, 2)
mB = size(B, 1)
@noinline throw_dimerr(::AbstractMatrix, nA, mB) = throw(DimensionMismatch("second dimension of A, $nA, does not match first dimension of B, $mB"))
@noinline throw_dimerr(::AbstractVector, nA, mB) = throw(DimensionMismatch("second dimension of D, $nA, does not match length of V, $mB"))
nA == mB || throw_dimerr(B, nA, mB)
return nothing
end
function _muldiag_size_check(C, A, B)
_muldiag_size_check(A, B)
# the output matrix should have the same size as the non-diagonal input matrix or vector
@noinline throw_dimerr(szC, szA) = throw(DimensionMismatch("output matrix has size: $szC, but should have size $szA"))
_size_check_out(C, ::Diagonal, A) = _size_check_out(C, A)
_size_check_out(C, A, ::Diagonal) = _size_check_out(C, A)
_size_check_out(C, A::Diagonal, ::Diagonal) = _size_check_out(C, A)
function _size_check_out(C, A)
szA = size(A)
szC = size(C)
szA == szC || throw_dimerr(szC, szA)
return nothing
end
_size_check_out(C, A, B)
end

function (*)(Da::Diagonal, Db::Diagonal)
nDa, mDb = size(Da, 2), size(Db, 1)
if nDa != mDb
throw(DimensionMismatch("second dimension of Da, $nDa, does not match first dimension of Db, $mDb"))
end
_muldiag_size_check(Da, Db)
return Diagonal(Da.diag .* Db.diag)
end

function (*)(D::Diagonal, V::AbstractVector)
nD = size(D, 2)
if nD != length(V)
throw(DimensionMismatch("second dimension of D, $nD, does not match length of V, $(length(V))"))
end
_muldiag_size_check(D, V)
return D.diag .* V
end

Expand All @@ -220,9 +238,9 @@ end
lmul!(D, copy_oftype(B, promote_op(*, eltype(B), eltype(D.diag))))

(*)(A::AbstractMatrix, D::Diagonal) =
rmul!(copy_similar(A, promote_op(*, eltype(A), eltype(D.diag))), D)
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D)
(*)(D::Diagonal, A::AbstractMatrix) =
lmul!(D, copy_similar(A, promote_op(*, eltype(A), eltype(D.diag))))
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A)

function rmul!(A::AbstractMatrix, D::Diagonal)
require_one_based_indexing(A)
Expand Down Expand Up @@ -309,26 +327,49 @@ end
rmul!(A::Diagonal, B::Diagonal) = Diagonal(A.diag .*= B.diag)
lmul!(A::Diagonal, B::Diagonal) = Diagonal(B.diag .= A.diag .* B.diag)

@inline function __muldiag!(out, D::Diagonal, B, alpha, beta)
if iszero(beta)
out .= (D.diag .* B) .*ₛ alpha
else
out .= (D.diag .* B) .*ₛ alpha .+ out .* beta
end
return out
end

@inline function __muldiag!(out, A, D::Diagonal, alpha, beta)
if iszero(beta)
out .= (A .* permutedims(D.diag)) .*ₛ alpha
else
out .= (A .* permutedims(D.diag)) .*ₛ alpha .+ out .* beta
end
return out
end
# only needed for ambiguity resolution, as mul! is explicitly defined for these arguments
@inline __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) =
mul!(out, D1, D2, alpha, beta)

@inline function _muldiag!(out, A, B, alpha, beta)
_muldiag_size_check(out, A, B)
__muldiag!(out, A, B, alpha, beta)
return out
end

# Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat
@inline mul!(out::AbstractVector, A::Diagonal, in::AbstractVector, alpha::Number, beta::Number) =
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, A::Diagonal, in::AbstractMatrix, alpha::Number, beta::Number) =
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, A::Diagonal, in::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, A::Diagonal, in::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta

@inline mul!(out::AbstractMatrix, in::AbstractMatrix, A::Diagonal, alpha::Number, beta::Number) =
out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, in::Adjoint{<:Any,<:AbstractVecOrMat}, A::Diagonal,
alpha::Number, beta::Number) =
out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, in::Transpose{<:Any,<:AbstractVecOrMat}, A::Diagonal,
alpha::Number, beta::Number) =
out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractVector, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) =
_muldiag!(out, D, V, alpha, beta)
@inline mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, alpha::Number, beta::Number) =
_muldiag!(out, D, B, alpha, beta)
@inline mul!(out::AbstractMatrix, D::Diagonal, B::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) = _muldiag!(out, D, B, alpha, beta)
@inline mul!(out::AbstractMatrix, D::Diagonal, B::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) = _muldiag!(out, D, B, alpha, beta)

@inline mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, alpha::Number, beta::Number) =
_muldiag!(out, A, D, alpha, beta)
@inline mul!(out::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, D::Diagonal,
alpha::Number, beta::Number) = _muldiag!(out, A, D, alpha, beta)
@inline mul!(out::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, D::Diagonal,
alpha::Number, beta::Number) = _muldiag!(out, A, D, alpha, beta)

function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number)
mA = size(Da, 1)
Expand Down
48 changes: 46 additions & 2 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,27 @@ let D1 = Diagonal(rand(5)), D2 = Diagonal(rand(5))
@test LinearAlgebra.lmul!(adjoint(D1),copy(D2)) == adjoint(D1)*D2
end

@testset "multiplication of a Diagonal with a Matrix" begin
A = collect(reshape(1:8, 4, 2));
B = BigFloat.(A);
DL = Diagonal(collect(axes(A, 1)));
DR = Diagonal(Float16.(collect(axes(A, 2))));

@test DL * A == collect(DL) * A
@test A * DR == A * collect(DR)
@test DL * B == collect(DL) * B
@test B * DR == B * collect(DR)

A = reshape([ones(2,2), ones(2,2)*2, ones(2,2)*3, ones(2,2)*4], 2, 2)
D = Diagonal([collect(reshape(1:4, 2, 2)), collect(reshape(5:8, 2, 2))])
@test A * D == collect(A) * collect(D)
@test D * A == collect(D) * collect(A)

AS = similar(A)
mul!(AS, A, D, true, false)
@test AS == A * D
end

@testset "multiplication of QR Q-factor and Diagonal (#16615 spot test)" begin
D = Diagonal(randn(5))
Q = qr(randn(5, 5)).Q
Expand Down Expand Up @@ -686,12 +707,35 @@ end
xt = transpose(x)
A = reshape([[1 2; 3 4], zeros(Int,2,2), zeros(Int, 2, 2), [5 6; 7 8]], 2, 2)
D = Diagonal(A)
@test x'*D == x'*A == copy(x')*D == copy(x')*A
@test xt*D == xt*A == copy(xt)*D == copy(xt)*A
@test x'*D == x'*A == collect(x')*D == collect(x')*A
@test xt*D == xt*A == collect(xt)*D == collect(xt)*A
outadjxD = similar(x'*D); outtrxD = similar(xt*D);
mul!(outadjxD, x', D)
@test outadjxD == x'*D
mul!(outtrxD, xt, D)
@test outtrxD == xt*D

D1 = Diagonal([[1 2; 3 4]])
@test D1 * x' == D1 * collect(x') == collect(D1) * collect(x')
@test D1 * xt == D1 * collect(xt) == collect(D1) * collect(xt)
outD1adjx = similar(D1 * x'); outD1trx = similar(D1 * xt);
mul!(outadjxD, D1, x')
@test outadjxD == D1*x'
mul!(outtrxD, D1, xt)
@test outtrxD == D1*xt

y = [x, x]
yt = transpose(y)
@test y'*D*y == (y'*D)*y == (y'*A)*y
@test yt*D*y == (yt*D)*y == (yt*A)*y
outadjyD = similar(y'*D); outtryD = similar(yt*D);
outadjyD2 = similar(collect(y'*D)); outtryD2 = similar(collect(yt*D));
mul!(outadjyD, y', D)
mul!(outadjyD2, y', D)
@test outadjyD == outadjyD2 == y'*D
mul!(outtryD, yt, D)
mul!(outtryD2, yt, D)
@test outtryD == outtryD2 == yt*D
end

@testset "Multiplication of single element Diagonal (#36746, #40726)" begin
Expand Down