Skip to content

Commit

Permalink
Use mul! in Diagonal*Matrix (#42321)
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub authored Oct 5, 2021
1 parent 7fbbaef commit a8db751
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 61 deletions.
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
asin, asinh, atan, atanh, axes, big, broadcast, ceil, cis, conj, convert, copy, copyto!, cos,
cosh, cot, coth, csc, csch, eltype, exp, fill!, floor, getindex, hcat,
getproperty, imag, inv, isapprox, isequal, isone, iszero, IndexStyle, kron, kron!, length, log, map, ndims,
oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
one, oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
setindex!, show, similar, sin, sincos, sinh, size, sqrt,
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec, zero
using Base: IndexLinear, promote_eltype, promote_op, promote_typeof,
@propagate_inbounds, @pure, reduce, typed_hvcat, typed_vcat, require_one_based_indexing,
splat
Expand Down
140 changes: 85 additions & 55 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,19 +198,37 @@ Base.literal_pow(::typeof(^), D::Diagonal, valp::Val) =
Diagonal(Base.literal_pow.(^, D.diag, valp)) # for speed
Base.literal_pow(::typeof(^), D::Diagonal, ::Val{-1}) = inv(D) # for disambiguation

function _muldiag_size_check(A, B)
nA = size(A, 2)
mB = size(B, 1)
@noinline throw_dimerr(::AbstractMatrix, nA, mB) = throw(DimensionMismatch("second dimension of A, $nA, does not match first dimension of B, $mB"))
@noinline throw_dimerr(::AbstractVector, nA, mB) = throw(DimensionMismatch("second dimension of D, $nA, does not match length of V, $mB"))
nA == mB || throw_dimerr(B, nA, mB)
return nothing
end
# the output matrix should have the same size as the non-diagonal input matrix or vector
@noinline throw_dimerr(szC, szA) = throw(DimensionMismatch("output matrix has size: $szC, but should have size $szA"))
_size_check_out(C, ::Diagonal, A) = _size_check_out(C, A)
_size_check_out(C, A, ::Diagonal) = _size_check_out(C, A)
_size_check_out(C, A::Diagonal, ::Diagonal) = _size_check_out(C, A)
function _size_check_out(C, A)
szA = size(A)
szC = size(C)
szA == szC || throw_dimerr(szC, szA)
return nothing
end
function _muldiag_size_check(C, A, B)
_muldiag_size_check(A, B)
_size_check_out(C, A, B)
end

function (*)(Da::Diagonal, Db::Diagonal)
nDa, mDb = size(Da, 2), size(Db, 1)
if nDa != mDb
throw(DimensionMismatch("second dimension of Da, $nDa, does not match first dimension of Db, $mDb"))
end
_muldiag_size_check(Da, Db)
return Diagonal(Da.diag .* Db.diag)
end

function (*)(D::Diagonal, V::AbstractVector)
nD = size(D, 2)
if nD != length(V)
throw(DimensionMismatch("second dimension of D, $nD, does not match length of V, $(length(V))"))
end
_muldiag_size_check(D, V)
return D.diag .* V
end

Expand All @@ -220,29 +238,12 @@ end
lmul!(D, copy_oftype(B, promote_op(*, eltype(B), eltype(D.diag))))

(*)(A::AbstractMatrix, D::Diagonal) =
rmul!(copy_similar(A, promote_op(*, eltype(A), eltype(D.diag))), D)
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D)
(*)(D::Diagonal, A::AbstractMatrix) =
lmul!(D, copy_similar(A, promote_op(*, eltype(A), eltype(D.diag))))
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A)

function rmul!(A::AbstractMatrix, D::Diagonal)
require_one_based_indexing(A)
nA, nD = size(A, 2), length(D.diag)
if nA != nD
throw(DimensionMismatch("second dimension of A, $nA, does not match the first of D, $nD"))
end
A .= A .* permutedims(D.diag)
return A
end

function lmul!(D::Diagonal, B::AbstractVecOrMat)
require_one_based_indexing(B)
nB, nD = size(B, 1), length(D.diag)
if nB != nD
throw(DimensionMismatch("second dimension of D, $nD, does not match the first of B, $nB"))
end
B .= D.diag .* B
return B
end
rmul!(A::AbstractMatrix, D::Diagonal) = mul!(A, A, D)
lmul!(D::Diagonal, B::AbstractVecOrMat) = mul!(B, D, B)

rmul!(A::Union{LowerTriangular,UpperTriangular}, D::Diagonal) = typeof(A)(rmul!(A.data, D))
function rmul!(A::UnitLowerTriangular, D::Diagonal)
Expand Down Expand Up @@ -306,37 +307,66 @@ function *(D::Diagonal, transA::Transpose{<:Any,<:AbstractMatrix})
lmul!(D, At)
end

rmul!(A::Diagonal, B::Diagonal) = Diagonal(A.diag .*= B.diag)
lmul!(A::Diagonal, B::Diagonal) = Diagonal(B.diag .= A.diag .* B.diag)
@inline function __muldiag!(out, D::Diagonal, B, alpha, beta)
if iszero(beta)
out .= (D.diag .* B) .*ₛ alpha
else
out .= (D.diag .* B) .*ₛ alpha .+ out .* beta
end
return out
end

@inline function __muldiag!(out, A, D::Diagonal, alpha, beta)
if iszero(beta)
out .= (A .* permutedims(D.diag)) .*ₛ alpha
else
out .= (A .* permutedims(D.diag)) .*ₛ alpha .+ out .* beta
end
return out
end

@inline function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha, beta)
if iszero(beta)
out.diag .= (D1.diag .* D2.diag) .*ₛ alpha
else
out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .* beta
end
return out
end

# only needed for ambiguity resolution, as mul! is explicitly defined for these arguments
@inline __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) =
mul!(out, D1, D2, alpha, beta)

@inline function _muldiag!(out, A, B, alpha, beta)
_muldiag_size_check(out, A, B)
__muldiag!(out, A, B, alpha, beta)
return out
end

# Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat
@inline mul!(out::AbstractVector, A::Diagonal, in::AbstractVector, alpha::Number, beta::Number) =
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, A::Diagonal, in::AbstractMatrix, alpha::Number, beta::Number) =
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, A::Diagonal, in::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, A::Diagonal, in::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) =
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta

@inline mul!(out::AbstractMatrix, in::AbstractMatrix, A::Diagonal, alpha::Number, beta::Number) =
out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, in::Adjoint{<:Any,<:AbstractVecOrMat}, A::Diagonal,
alpha::Number, beta::Number) =
out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, in::Transpose{<:Any,<:AbstractVecOrMat}, A::Diagonal,
alpha::Number, beta::Number) =
out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta
@inline mul!(out::AbstractVector, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) =
_muldiag!(out, D, V, alpha, beta)
@inline mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, alpha::Number, beta::Number) =
_muldiag!(out, D, B, alpha, beta)
@inline mul!(out::AbstractMatrix, D::Diagonal, B::Adjoint{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) = _muldiag!(out, D, B, alpha, beta)
@inline mul!(out::AbstractMatrix, D::Diagonal, B::Transpose{<:Any,<:AbstractVecOrMat},
alpha::Number, beta::Number) = _muldiag!(out, D, B, alpha, beta)

@inline mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, alpha::Number, beta::Number) =
_muldiag!(out, A, D, alpha, beta)
@inline mul!(out::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, D::Diagonal,
alpha::Number, beta::Number) = _muldiag!(out, A, D, alpha, beta)
@inline mul!(out::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, D::Diagonal,
alpha::Number, beta::Number) = _muldiag!(out, A, D, alpha, beta)
@inline mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
_muldiag!(C, Da, Db, alpha, beta)

function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number)
mA = size(Da, 1)
mB = size(Db, 1)
mA == mB || throw(DimensionMismatch("A has dimensions ($mA,$mA) but B has dimensions ($mB,$mB)"))
mC, nC = size(C)
mC == nC == mA || throw(DimensionMismatch("output matrix has size: ($mC,$nC), but should have size ($mA,$mA)"))
_muldiag_size_check(C, Da, Db)
require_one_based_indexing(C)
mA = size(Da, 1)
da = Da.diag
db = Db.diag
_rmul_or_fill!(C, beta)
Expand Down
3 changes: 2 additions & 1 deletion stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
# inside this function.
function *end
Broadcast.broadcasted(::typeof(*ₛ), out, beta) =
iszero(beta::Number) ? false : broadcasted(*, out, beta)
iszero(beta::Number) ? false :
isone(beta::Number) ? broadcasted(identity, out) : broadcasted(*, out, beta)

"""
MulAddMul(alpha, beta)
Expand Down
5 changes: 4 additions & 1 deletion stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,16 @@ function fill!(A::Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}, x)
not be filled with $x, since some of its entries are constrained."))
end

one(A::Diagonal{T}) where T = Diagonal(fill!(similar(A.diag, typeof(one(T))), one(T)))
one(D::Diagonal) = Diagonal(one.(D.diag))
one(A::Bidiagonal{T}) where T = Bidiagonal(fill!(similar(A.dv, typeof(one(T))), one(T)), fill!(similar(A.ev, typeof(one(T))), zero(one(T))), A.uplo)
one(A::Tridiagonal{T}) where T = Tridiagonal(fill!(similar(A.du, typeof(one(T))), zero(one(T))), fill!(similar(A.d, typeof(one(T))), one(T)), fill!(similar(A.dl, typeof(one(T))), zero(one(T))))
one(A::SymTridiagonal{T}) where T = SymTridiagonal(fill!(similar(A.dv, typeof(one(T))), one(T)), fill!(similar(A.ev, typeof(one(T))), zero(one(T))))
# equals and approx equals methods for structured matrices
# SymTridiagonal == Tridiagonal is already defined in tridiag.jl

zero(D::Diagonal) = Diagonal(zero.(D.diag))
oneunit(D::Diagonal) = Diagonal(oneunit.(D.diag))

# SymTridiagonal and Bidiagonal have the same field names
==(A::Diagonal, B::Union{SymTridiagonal, Bidiagonal}) = iszero(B.ev) && A.diag == B.dv
==(B::Bidiagonal, A::Diagonal) = A == B
Expand Down
80 changes: 78 additions & 2 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,41 @@ let D1 = Diagonal(rand(5)), D2 = Diagonal(rand(5))
@test LinearAlgebra.lmul!(adjoint(D1),copy(D2)) == adjoint(D1)*D2
end

@testset "multiplication of a Diagonal with a Matrix" begin
A = collect(reshape(1:8, 4, 2));
B = BigFloat.(A);
DL = Diagonal(collect(axes(A, 1)));
DR = Diagonal(Float16.(collect(axes(A, 2))));

@test DL * A == collect(DL) * A
@test A * DR == A * collect(DR)
@test DL * B == collect(DL) * B
@test B * DR == B * collect(DR)

A = reshape([ones(2,2), ones(2,2)*2, ones(2,2)*3, ones(2,2)*4], 2, 2)
Ac = collect(A)
D = Diagonal([collect(reshape(1:4, 2, 2)), collect(reshape(5:8, 2, 2))])
Dc = collect(D)
@test A * D == Ac * Dc
@test D * A == Dc * Ac
@test D * D == Dc * Dc

AS = similar(A)
mul!(AS, A, D, true, false)
@test AS == A * D

D2 = similar(D)
mul!(D2, D, D)
@test D2 == D * D

D2[diagind(D2)] .= D[diagind(D)]
lmul!(D, D2)
@test D2 == D * D
D2[diagind(D2)] .= D[diagind(D)]
rmul!(D2, D)
@test D2 == D * D
end

@testset "multiplication of QR Q-factor and Diagonal (#16615 spot test)" begin
D = Diagonal(randn(5))
Q = qr(randn(5, 5)).Q
Expand Down Expand Up @@ -686,12 +721,35 @@ end
xt = transpose(x)
A = reshape([[1 2; 3 4], zeros(Int,2,2), zeros(Int, 2, 2), [5 6; 7 8]], 2, 2)
D = Diagonal(A)
@test x'*D == x'*A == copy(x')*D == copy(x')*A
@test xt*D == xt*A == copy(xt)*D == copy(xt)*A
@test x'*D == x'*A == collect(x')*D == collect(x')*A
@test xt*D == xt*A == collect(xt)*D == collect(xt)*A
outadjxD = similar(x'*D); outtrxD = similar(xt*D);
mul!(outadjxD, x', D)
@test outadjxD == x'*D
mul!(outtrxD, xt, D)
@test outtrxD == xt*D

D1 = Diagonal([[1 2; 3 4]])
@test D1 * x' == D1 * collect(x') == collect(D1) * collect(x')
@test D1 * xt == D1 * collect(xt) == collect(D1) * collect(xt)
outD1adjx = similar(D1 * x'); outD1trx = similar(D1 * xt);
mul!(outadjxD, D1, x')
@test outadjxD == D1*x'
mul!(outtrxD, D1, xt)
@test outtrxD == D1*xt

y = [x, x]
yt = transpose(y)
@test y'*D*y == (y'*D)*y == (y'*A)*y
@test yt*D*y == (yt*D)*y == (yt*A)*y
outadjyD = similar(y'*D); outtryD = similar(yt*D);
outadjyD2 = similar(collect(y'*D)); outtryD2 = similar(collect(yt*D));
mul!(outadjyD, y', D)
mul!(outadjyD2, y', D)
@test outadjyD == outadjyD2 == y'*D
mul!(outtryD, yt, D)
mul!(outtryD2, yt, D)
@test outtryD == outtryD2 == yt*D
end

@testset "Multiplication of single element Diagonal (#36746, #40726)" begin
Expand Down Expand Up @@ -826,4 +884,22 @@ end
@test \(x, B) == /(B, x)
end

@testset "zero and one" begin
D1 = Diagonal(rand(3))
@test D1 + zero(D1) == D1
@test D1 * one(D1) == D1
@test D1 * oneunit(D1) == D1
@test oneunit(D1) isa typeof(D1)
D2 = Diagonal([collect(reshape(1:4, 2, 2)), collect(reshape(5:8, 2, 2))])
@test D2 + zero(D2) == D2
@test D2 * one(D2) == D2
@test D2 * oneunit(D2) == D2
@test oneunit(D2) isa typeof(D2)
D3 = Diagonal([D2, D2]);
@test D3 + zero(D3) == D3
@test D3 * one(D3) == D3
@test D3 * oneunit(D3) == D3
@test oneunit(D3) isa typeof(D3)
end

end # module TestDiagonal

0 comments on commit a8db751

Please sign in to comment.