Skip to content

Commit

Permalink
Fix zero elements for block-matrix kron involving Diagonal (#55941)
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Oct 15, 2024
1 parent 66b620f commit 4e03986
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 7 deletions.
70 changes: 63 additions & 7 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -634,16 +634,33 @@ for Tri in (:UpperTriangular, :LowerTriangular)
end

@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)
valA = A.diag; nA = length(valA)
valB = B.diag; nB = length(valB)
valA = A.diag; mA, nA = size(A)
valB = B.diag; mB, nB = size(B)
nC = checksquare(C)
@boundscheck nC == nA*nB ||
throw(DimensionMismatch(lazy"expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)"))
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
zerofilled = false
if !(isempty(A) || isempty(B))
z = A[1,1] * B[1,1]
if haszero(typeof(z))
# in this case, the zero is unique
fill!(C, zero(z))
zerofilled = true
end
end
@inbounds for i = 1:nA, j = 1:nB
idx = (i-1)*nB+j
C[idx, idx] = valA[i] * valB[j]
end
if !zerofilled
for j in 1:nA, i in 1:mA
Δrow, Δcol = (i-1)*mB, (j-1)*nB
for k in 1:nB, l in 1:mB
i == j && k == l && continue
C[Δrow + l, Δcol + k] = A[i,j] * B[l,k]
end
end
end
return C
end

Expand All @@ -670,7 +687,15 @@ end
(mC, nC) = size(C)
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
throw(DimensionMismatch(lazy"expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
zerofilled = false
if !(isempty(A) || isempty(B))
z = A[1,1] * B[1,1]
if haszero(typeof(z))
# in this case, the zero is unique
fill!(C, zero(z))
zerofilled = true
end
end
m = 1
@inbounds for j = 1:nA
A_jj = A[j,j]
Expand All @@ -681,6 +706,18 @@ end
end
m += (nA - 1) * mB
end
if !zerofilled
# populate the zero elements
for i in 1:mA
i == j && continue
A_ij = A[i, j]
Δrow, Δcol = (i-1)*mB, (j-1)*nB
for k in 1:nB, l in 1:nA
B_lk = B[l, k]
C[Δrow + l, Δcol + k] = A_ij * B_lk
end
end
end
m += mB
end
return C
Expand All @@ -693,17 +730,36 @@ end
(mC, nC) = size(C)
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
throw(DimensionMismatch(lazy"expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
zerofilled = false
if !(isempty(A) || isempty(B))
z = A[1,1] * B[1,1]
if haszero(typeof(z))
# in this case, the zero is unique
fill!(C, zero(z))
zerofilled = true
end
end
m = 1
@inbounds for j = 1:nA
for l = 1:mB
Bll = B[l,l]
for k = 1:mA
C[m] = A[k,j] * Bll
for i = 1:mA
C[m] = A[i,j] * Bll
m += nB
end
m += 1
end
if !zerofilled
for i in 1:mA
A_ij = A[i, j]
Δrow, Δcol = (i-1)*mB, (j-1)*nB
for k in 1:nB, l in 1:mB
l == k && continue
B_lk = B[l, k]
C[Δrow + l, Δcol + k] = A_ij * B_lk
end
end
end
m -= nB
end
return C
Expand Down
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1323,4 +1323,14 @@ end
@test checkbounds(Bool, D, diagind(D, IndexCartesian()))
end

@testset "zeros in kron with block matrices" begin
D = Diagonal(1:2)
B = reshape([ones(2,2), ones(3,2), ones(2,3), ones(3,3)], 2, 2)
@test kron(D, B) == kron(Array(D), B)
@test kron(B, D) == kron(B, Array(D))
D2 = Diagonal([ones(2,2), ones(3,3)])
@test kron(D, D2) == Diagonal([diag(D2); 2diag(D2)])
@test kron(D2, D) == Diagonal([ones(2,2), fill(2.0,2,2), ones(3,3), fill(2.0,3,3)])
end

end # module TestDiagonal

0 comments on commit 4e03986

Please sign in to comment.