diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 4646520ecd1ab..e47b0248fbe53 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -320,15 +320,6 @@ end rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D) lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B) -function (*)(A::AdjOrTransAbsMat, D::Diagonal) - Ac = copy_similar(A, promote_op(*, eltype(A), eltype(D.diag))) - rmul!(Ac, D) -end -function (*)(D::Diagonal, A::AdjOrTransAbsMat) - Ac = copy_similar(A, promote_op(*, eltype(A), eltype(D.diag))) - lmul!(D, Ac) -end - function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} require_one_based_indexing(out, B) alpha, beta = _add.alpha, _add.beta diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index caefc7889a8d6..72c99f742a2f4 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -18,6 +18,9 @@ using .Main.InfiniteArrays isdefined(Main, :FillArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "FillArrays.jl")) using .Main.FillArrays +isdefined(Main, :SizedArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SizedArrays.jl")) +using .Main.SizedArrays + const n=12 # Size of matrix problem to test Random.seed!(1) @@ -778,6 +781,11 @@ end D = Diagonal(fill(M, n)) @test D == Matrix{eltype(D)}(D) end + + S = SizedArray{(2,3)}(reshape([1:6;],2,3)) + D = Diagonal(fill(S,3)) + @test D * fill(S,2,3)' == fill(S * S', 3, 2) + @test fill(S,3,2)' * D == fill(S' * S, 2, 3) end @testset "Eigensystem for block diagonal (issue #30681)" begin diff --git a/test/testhelpers/SizedArrays.jl b/test/testhelpers/SizedArrays.jl index dfcc5b79f1387..fc2862d844b3f 100644 --- a/test/testhelpers/SizedArrays.jl +++ b/test/testhelpers/SizedArrays.jl @@ -9,6 +9,8 @@ module SizedArrays import Base: +, *, == +using LinearAlgebra + export SizedArray struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N} @@ -31,9 +33,16 @@ Base.getindex(A::SizedArray, i...) = getindex(A.data, i...) Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T))) +(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data) ==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data -function *(S1::SizedArray, S2::SizedArray) + +const SizedArrayLike = Union{SizedArray, Transpose{<:Any, <:SizedArray}, Adjoint{<:Any, <:SizedArray}} + +_data(S::SizedArray) = S.data +_data(T::Transpose{<:Any, <:SizedArray}) = transpose(_data(parent(T))) +_data(T::Adjoint{<:Any, <:SizedArray}) = adjoint(_data(parent(T))) + +function *(S1::SizedArrayLike, S2::SizedArrayLike) 0 < ndims(S1) < 3 && 0 < ndims(S2) < 3 && size(S1, 2) == size(S2, 1) || throw(ArgumentError("size mismatch!")) - data = S1.data * S2.data + data = _data(S1) * _data(S2) SZ = ndims(data) == 1 ? (size(S1, 1), ) : (size(S1, 1), size(S2, 2)) SizedArray{SZ}(data) end