Skip to content

Commit

Permalink
Fix kron with Diagonal (#40509)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored Apr 18, 2021
1 parent 36a048c commit a20e547
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 17 deletions.
27 changes: 10 additions & 17 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -472,14 +472,13 @@ rdiv!(A::AbstractMatrix{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} =
invoke(\, Tuple{Union{QR,QRCompactWY,QRPivoted}, AbstractVecOrMat}, A, B)


@inline function kron!(C::AbstractMatrix{T}, A::Diagonal, B::Diagonal) where T
fill!(C, zero(T))
@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)
valA = A.diag; nA = length(valA)
valB = B.diag; nB = length(valB)
nC = checksquare(C)
@boundscheck nC == nA*nB ||
throw(DimensionMismatch("expect C to be a $(nA*nB)x$(nA*nB) matrix, got size $(nC)x$(nC)"))

isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
@inbounds for i = 1:nA, j = 1:nB
idx = (i-1)*nB+j
C[idx, idx] = valA[i] * valB[j]
Expand All @@ -497,9 +496,12 @@ end

@inline function kron!(C::AbstractMatrix, A::Diagonal, B::AbstractMatrix)
Base.require_one_based_indexing(B)
(mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C);
(mA, nA) = size(A)
(mB, nB) = size(B)
(mC, nC) = size(C)
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
m = 1
@inbounds for j = 1:nA
A_jj = A[j,j]
Expand All @@ -517,9 +519,12 @@ end

@inline function kron!(C::AbstractMatrix, A::AbstractMatrix, B::Diagonal)
require_one_based_indexing(A)
(mA, nA) = size(A); (mB, nB) = size(B); (mC, nC) = size(C);
(mA, nA) = size(A)
(mB, nB) = size(B)
(mC, nC) = size(C)
@boundscheck (mC, nC) == (mA * mB, nA * nB) ||
throw(DimensionMismatch("expect C to be a $(mA * mB)x$(nA * nB) matrix, got size $(mC)x$(nC)"))
isempty(A) || isempty(B) || fill!(C, zero(A[1,1] * B[1,1]))
m = 1
@inbounds for j = 1:nA
for l = 1:mB
Expand All @@ -535,18 +540,6 @@ end
return C
end

function kron(A::Diagonal{T}, B::AbstractMatrix{S}) where {T<:Number, S<:Number}
(mA, nA) = size(A); (mB, nB) = size(B)
R = zeros(Base.promote_op(*, T, S), mA * mB, nA * nB)
return @inbounds kron!(R, A, B)
end

function kron(A::AbstractMatrix{T}, B::Diagonal{S}) where {T<:Number, S<:Number}
(mA, nA) = size(A); (mB, nB) = size(B)
R = zeros(promote_op(*, T, S), mA * mB, nA * nB)
return @inbounds kron!(R, A, B)
end

conj(D::Diagonal) = Diagonal(conj(D.diag))
transpose(D::Diagonal{<:Number}) = D
transpose(D::Diagonal) = Diagonal(transpose.(D.diag))
Expand Down
4 changes: 4 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,10 @@ Random.seed!(1)
M4 = rand(elty, n÷2, n÷2)
@test kron(D3, M4) kron(DM3, M4)
@test kron(M4, D3) kron(M4, DM3)
X = [ones(1,1) for i in 1:2, j in 1:2]
@test kron(I(2), X)[1,3] == zeros(1,1)
X = [ones(2,2) for i in 1:2, j in 1:2]
@test kron(I(2), X)[1,3] == zeros(2,2)
end
@testset "iszero, isone, triu, tril" begin
Dzero = Diagonal(zeros(elty, 10))
Expand Down

0 comments on commit a20e547

Please sign in to comment.