diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index aabfb3e8ba114..17ff232f5b262 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -686,16 +686,33 @@ for Tri in (:UpperTriangular, :LowerTriangular) end @inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal) - valA = A.diag; nA = length(valA) - valB = B.diag; nB = length(valB) + valA = A.diag; mA, nA = size(A) + valB = B.diag; mB, nB = size(B) nC = checksquare(C) @boundscheck nC == nA*nB || throw(DimensionMismatch(lazy"expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)")) - isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1])) + zerofilled = false + if !(isempty(A) || isempty(B)) + z = A[1,1] * B[1,1] + if haszero(typeof(z)) + # in this case, the zero is unique + fill!(C, zero(z)) + zerofilled = true + end + end @inbounds for i = 1:nA, j = 1:nB idx = (i-1)*nB+j C[idx, idx] = valA[i] * valB[j] end + if !zerofilled + for j in 1:nA, i in 1:mA + Δrow, Δcol = (i-1)*mB, (j-1)*nB + for k in 1:nB, l in 1:mB + i == j && k == l && continue + C[Δrow + l, Δcol + k] = A[i,j] * B[l,k] + end + end + end return C end @@ -722,7 +739,15 @@ end (mC, nC) = size(C) @boundscheck (mC, nC) == (mA * mB, nA * nB) || throw(DimensionMismatch(lazy"expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)")) - isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1])) + zerofilled = false + if !(isempty(A) || isempty(B)) + z = A[1,1] * B[1,1] + if haszero(typeof(z)) + # in this case, the zero is unique + fill!(C, zero(z)) + zerofilled = true + end + end m = 1 @inbounds for j = 1:nA A_jj = A[j,j] @@ -733,6 +758,18 @@ end end m += (nA - 1) * mB end + if !zerofilled + # populate the zero elements + for i in 1:mA + i == j && continue + A_ij = A[i, j] + Δrow, Δcol = (i-1)*mB, (j-1)*nB + for k in 1:nB, l in 1:nA + B_lk = B[l, k] + C[Δrow + l, Δcol + k] = A_ij * B_lk + end + end + end m += mB end return C @@ -745,17 +782,36 @@ end (mC, nC) = size(C) @boundscheck (mC, nC) == (mA * mB, nA * nB) || throw(DimensionMismatch(lazy"expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)")) - isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1])) + zerofilled = false + if !(isempty(A) || isempty(B)) + z = A[1,1] * B[1,1] + if haszero(typeof(z)) + # in this case, the zero is unique + fill!(C, zero(z)) + zerofilled = true + end + end m = 1 @inbounds for j = 1:nA for l = 1:mB Bll = B[l,l] - for k = 1:mA - C[m] = A[k,j] * Bll + for i = 1:mA + C[m] = A[i,j] * Bll m += nB end m += 1 end + if !zerofilled + for i in 1:mA + A_ij = A[i, j] + Δrow, Δcol = (i-1)*mB, (j-1)*nB + for k in 1:nB, l in 1:mB + l == k && continue + B_lk = B[l, k] + C[Δrow + l, Δcol + k] = A_ij * B_lk + end + end + end m -= nB end return C diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index f7a9ccb705de9..8b56ee15e56e3 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -1391,4 +1391,14 @@ end @test checkbounds(Bool, D, diagind(D, IndexCartesian())) end +@testset "zeros in kron with block matrices" begin + D = Diagonal(1:2) + B = reshape([ones(2,2), ones(3,2), ones(2,3), ones(3,3)], 2, 2) + @test kron(D, B) == kron(Array(D), B) + @test kron(B, D) == kron(B, Array(D)) + D2 = Diagonal([ones(2,2), ones(3,3)]) + @test kron(D, D2) == kron(D, Array{eltype(D2)}(D2)) + @test kron(D2, D) == kron(Array{eltype(D2)}(D2), D) +end + end # module TestDiagonal