Skip to content

Commit

Permalink
Fix performance issue with diagonal multiplication (#44651)
Browse files Browse the repository at this point in the history
(cherry picked from commit 03af781)
  • Loading branch information
dkarrasch authored and KristofferC committed Apr 19, 2022
1 parent d49e065 commit 477dfb9
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 37 deletions.
99 changes: 64 additions & 35 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ end
(*)(D::Diagonal, A::AbstractMatrix) =
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A)

rmul!(A::AbstractMatrix, D::Diagonal) = mul!(A, A, D)
lmul!(D::Diagonal, B::AbstractVecOrMat) = mul!(B, D, B)
rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)

#TODO: It seems better to call (D' * adjA')' directly?
function *(adjA::Adjoint{<:Any,<:AbstractMatrix}, D::Diagonal)
Expand Down Expand Up @@ -285,35 +285,80 @@ function *(D::Diagonal, transA::Transpose{<:Any,<:AbstractMatrix})
end

@inline function __muldiag!(out, D::Diagonal, B, alpha, beta)
if iszero(beta)
out .= (D.diag .* B) .*ₛ alpha
require_one_based_indexing(out)
if iszero(alpha)
_rmul_or_fill!(out, beta)
else
out .= (D.diag .* B) .*ₛ alpha .+ out .* beta
if iszero(beta)
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
out[i,j] = D.diag[i] * B[i,j] * alpha
end
end
else
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
out[i,j] = D.diag[i] * B[i,j] * alpha + out[i,j] * beta
end
end
end
end
return out
end

@inline function __muldiag!(out, A, D::Diagonal, alpha, beta)
if iszero(beta)
out .= (A .* permutedims(D.diag)) .*ₛ alpha
require_one_based_indexing(out)
if iszero(alpha)
_rmul_or_fill!(out, beta)
else
out .= (A .* permutedims(D.diag)) .*ₛ alpha .+ out .* beta
if iszero(beta)
@inbounds for j in axes(A, 2)
dja = D.diag[j] * alpha
@simd for i in axes(A, 1)
out[i,j] = A[i,j] * dja
end
end
else
@inbounds for j in axes(A, 2)
dja = D.diag[j] * alpha
@simd for i in axes(A, 1)
out[i,j] = A[i,j] * dja + out[i,j] * beta
end
end
end
end
return out
end

@inline function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha, beta)
if iszero(beta)
out.diag .= (D1.diag .* D2.diag) .*ₛ alpha
d1 = D1.diag
d2 = D2.diag
if iszero(alpha)
_rmul_or_fill!(out.diag, beta)
else
out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .* beta
if iszero(beta)
@inbounds @simd for i in eachindex(out.diag)
out.diag[i] = d1[i] * d2[i] * alpha
end
else
@inbounds @simd for i in eachindex(out.diag)
out.diag[i] = d1[i] * d2[i] * alpha + out.diag[i] * beta
end
end
end
return out
end
@inline function __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta)
require_one_based_indexing(out)
mA = size(D1, 1)
d1 = D1.diag
d2 = D2.diag
_rmul_or_fill!(out, beta)
if !iszero(alpha)
@inbounds @simd for i in 1:mA
out[i,i] += d1[i] * d2[i] * alpha
end
end
return out
end

# only needed for ambiguity resolution, as mul! is explicitly defined for these arguments
@inline __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) =
mul!(out, D1, D2, alpha, beta)

@inline function _muldiag!(out, A, B, alpha, beta)
_muldiag_size_check(out, A, B)
Expand All @@ -340,24 +385,8 @@ end
@inline mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
_muldiag!(C, Da, Db, alpha, beta)

function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number)
_muldiag_size_check(C, Da, Db)
require_one_based_indexing(C)
mA = size(Da, 1)
da = Da.diag
db = Db.diag
_rmul_or_fill!(C, beta)
if iszero(beta)
@inbounds @simd for i in 1:mA
C[i,i] = Ref(da[i] * db[i]) .*ₛ alpha
end
else
@inbounds @simd for i in 1:mA
C[i,i] += Ref(da[i] * db[i]) .*ₛ alpha
end
end
return C
end
mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
_muldiag!(C, Da, Db, alpha, beta)

_init(op, A::AbstractArray{<:Number}, B::AbstractArray{<:Number}) =
(_ -> zero(typeof(op(oneunit(eltype(A)), oneunit(eltype(B))))))
Expand Down
3 changes: 1 addition & 2 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
# inside this function.
function *end
Broadcast.broadcasted(::typeof(*ₛ), out, beta) =
iszero(beta::Number) ? false :
isone(beta::Number) ? broadcasted(identity, out) : broadcasted(*, out, beta)
iszero(beta::Number) ? false : broadcasted(*, out, beta)

"""
MulAddMul(alpha, beta)
Expand Down

0 comments on commit 477dfb9

Please sign in to comment.