Skip to content

Commit a20e547

Browse files
authored
Fix kron with Diagonal (#40509)
1 parent 36a048c commit a20e547

File tree

2 files changed

+14
-17
lines changed

2 files changed

+14
-17
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -472,14 +472,13 @@ rdiv!(A::AbstractMatrix{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} =
472472
invoke(\, Tuple{Union{QR,QRCompactWY,QRPivoted}, AbstractVecOrMat}, A, B)
473473

474474

475-
@inline function kron!(C::AbstractMatrix{T}, A::Diagonal, B::Diagonal) where T
476-
fill!(C, zero(T))
475+
@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)
477476
valA = A.diag; nA = length(valA)
478477
valB = B.diag; nB = length(valB)
479478
nC = checksquare(C)
480479
@boundscheck nC == nA*nB ||
481480
throw(DimensionMismatch("expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)"))
482-
481+
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
483482
@inbounds for i = 1:nA, j = 1:nB
484483
idx = (i-1)*nB+j
485484
C[idx, idx] = valA[i] * valB[j]
@@ -497,9 +496,12 @@ end
497496

498497
@inline function kron!(C::AbstractMatrix, A::Diagonal, B::AbstractMatrix)
499498
Base.require_one_based_indexing(B)
500-
(mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C);
499+
(mA, nA) = size(A)
500+
(mB, nB) = size(B)
501+
(mC, nC) = size(C)
501502
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
502503
throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
504+
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
503505
m = 1
504506
@inbounds for j = 1:nA
505507
A_jj = A[j,j]
@@ -517,9 +519,12 @@ end
517519

518520
@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::Diagonal)
519521
require_one_based_indexing(A)
520-
(mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C);
522+
(mA, nA) = size(A)
523+
(mB, nB) = size(B)
524+
(mC, nC) = size(C)
521525
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
522526
throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
527+
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
523528
m = 1
524529
@inbounds for j = 1:nA
525530
for l = 1:mB
@@ -535,18 +540,6 @@ end
535540
return C
536541
end
537542

538-
function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number}
539-
(mA, nA) = size(A); (mB, nB) = size(B)
540-
R = zeros(Base.promote_op(*, T, S), mA * mB, nA * nB)
541-
return @inbounds kron!(R, A, B)
542-
end
543-
544-
function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number}
545-
(mA, nA) = size(A); (mB, nB) = size(B)
546-
R = zeros(promote_op(*, T, S), mA * mB, nA * nB)
547-
return @inbounds kron!(R, A, B)
548-
end
549-
550543
conj(D::Diagonal) = Diagonal(conj(D.diag))
551544
transpose(D::Diagonal{<:Number}) = D
552545
transpose(D::Diagonal) = Diagonal(transpose.(D.diag))

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,10 @@ Random.seed!(1)
299299
M4 = rand(elty, n÷2, n÷2)
300300
@test kron(D3, M4) kron(DM3, M4)
301301
@test kron(M4, D3) kron(M4, DM3)
302+
X = [ones(1,1) for i in 1:2, j in 1:2]
303+
@test kron(I(2), X)[1,3] == zeros(1,1)
304+
X = [ones(2,2) for i in 1:2, j in 1:2]
305+
@test kron(I(2), X)[1,3] == zeros(2,2)
302306
end
303307
@testset "iszero, isone, triu, tril" begin
304308
Dzero = Diagonal(zeros(elty, 10))

0 commit comments

Comments
 (0)