Skip to content

Commit

Permalink
Replace MulAddMul by alpha,beta in __muldiag (JuliaLang#56360)
Browse files Browse the repository at this point in the history
This PR replaces `MulAddMul` arguments by `alpha, beta` pairs in the
multiplication methods involving `Diagonal` matrices, and constructs the
objects exactly where they are required. Such an approach improves
latency.
```julia
julia> D = Diagonal(1:2000); A = rand(size(D)...); C = similar(A);

julia> @time mul!(C, A, D, 1, 2); # first-run latency is reduced
  0.129741 seconds (180.18 k allocations: 9.607 MiB, 88.87% compilation time) # nightly v"1.12.0-DEV.1505"
  0.083005 seconds (146.68 k allocations: 7.442 MiB, 82.94% compilation time) # this PR

julia> @btime mul!($C, $A, $D, 1, 2); # runtime performance is unaffected
  4.983 ms (0 allocations: 0 bytes) # nightly
  4.938 ms (0 allocations: 0 bytes) # this PR
```

This PR sets the stage for a similar change for
`Bidiagonal`/`Tridiagonal` matrices, which would lead to a bigger
reduction in latencies.
  • Loading branch information
jishnub authored Nov 11, 2024
1 parent ad24368 commit 1e0cee5
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 53 deletions.
16 changes: 12 additions & 4 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,14 @@ const BiTri = Union{Bidiagonal,Tridiagonal}
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) =
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) =
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) =
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
# Generate the (alpha, beta) entry points for products in which one operand is a
# `BandedMatrix` and the other is either an `AbstractMatrix` or a `Diagonal`,
# with the `BandedMatrix` on the left and on the right.
# Each generated method builds `MulAddMul(alpha, beta)` and forwards to the
# `_mul!` method that takes that object as its fourth argument; the call is
# wrapped in `@stable_muladdmul` (defined elsewhere).
# NOTE(review): the explicit `Diagonal` variants presumably exist to avoid
# dispatch ambiguities with the Diagonal-specific `_mul!` methods — confirm.
for T in (:AbstractMatrix, :Diagonal)
@eval begin
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::$T, alpha::Number, beta::Number) =
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
@inline _mul!(C::AbstractMatrix, A::$T, B::BandedMatrix, alpha::Number, beta::Number) =
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
end
end
# (alpha, beta) entry point for the case where both operands are `BandedMatrix`:
# builds `MulAddMul(alpha, beta)` and forwards to the MulAddMul-based `_mul!`.
# NOTE(review): a dedicated method for this combination presumably resolves the
# ambiguity between the left-banded and right-banded methods — confirm.
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) =
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))

Expand Down Expand Up @@ -831,6 +835,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add)
C
end

# (alpha, beta) entry point for `BiTriSym * Diagonal`: wraps the two scalars in
# `MulAddMul(alpha, beta)` and forwards to the `_mul!` method for the same
# operand types that accepts a `MulAddMul` as its last argument.
_mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) =
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul)
require_one_based_indexing(C)
check_A_mul_B!_sizes(size(C), size(A), size(B))
Expand Down Expand Up @@ -1067,6 +1073,8 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd
C
end

# (alpha, beta) entry point for `Diagonal * BiTriSym`: wraps the two scalars in
# `MulAddMul(alpha, beta)` and forwards to the MulAddMul-based `_mul!` methods
# for `Diagonal` times a bidiagonal / tridiagonal-type operand.
_mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, alpha::Number, beta::Number) =
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
_mul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add::MulAddMul) =
_dibimul!(C, A, B, _add)
_mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul) =
Expand Down
100 changes: 51 additions & 49 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -397,13 +397,13 @@ function lmul!(D::Diagonal, T::Tridiagonal)
return T
end

@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul)
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Number, beta::Number)
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
@stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j))
end
end
out
return out
end
# Return `true` when `out` and `A` are both upper-triangular wrappers, or both
# lower-triangular wrappers (unit or non-unit). Callers use this to decide
# whether they may read from `parent(A)` and write to `parent(out)` directly.
# NOTE(review): the fallback for all other combinations is not visible here —
# presumably it returns `false`; confirm.
_has_matching_zeros(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true
_has_matching_zeros(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true
Expand All @@ -418,116 +418,118 @@ function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col)
end
# Row range, within column `col`, of the entries that callers treat as the
# zeros of the triangular matrix `B`:
# - upper triangular: the rows strictly below the diagonal, `col+1:size(B,1)`
# - lower triangular: the rows strictly above the diagonal, `1:col-1`
# The range is empty for the last column (upper) and the first column (lower).
_rowrange_tri_zeros(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1)
_rowrange_tri_zeros(B::LowerOrUnitLowerTriangular, col) = 1:col-1
function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, alpha::Number, beta::Number)
isunit = B isa UnitUpperOrUnitLowerTriangular
out_maybeparent, B_maybeparent = _has_matching_zeros(out, B) ? (parent(out), parent(B)) : (out, B)
for j in axes(B, 2)
# store the diagonal separately for unit triangular matrices
if isunit
@inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j))
@inbounds @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[j] * B[j,j], out, (j,j))
end
# The indices of out corresponding to the stored indices of B
rowrange = _rowrange_tri_stored(B, j)
@inbounds @simd for i in rowrange
_modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
@stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
end
# Fill the indices of out corresponding to the zeros of B
# we only fill these if out and B don't have matching zeros
if !_has_matching_zeros(out, B)
rowrange = _rowrange_tri_zeros(B, j)
@inbounds @simd for i in rowrange
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
@stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j))
end
end
end
return out
end

@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
beta = _add.beta
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
@inline function __muldiag_nonzeroalpha_right!(out, A, D::Diagonal, alpha::Number, beta::Number)
@inbounds for j in axes(A, 2)
dja = _add(D.diag[j])
dja = @stable_muladdmul MulAddMul(alpha,false)(D.diag[j])
@simd for i in axes(A, 1)
_modify!(_add_aisone, A[i,j] * dja, out, (i,j))
@stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j))
end
end
out
return out
end

# Generic `A * D` with a `Diagonal` as the right-hand factor: a thin wrapper that
# delegates to the `__muldiag_nonzeroalpha_right!` kernel with the same arguments
# and returns whatever that kernel returns.
# NOTE(review): keeping the kernel under a separate name presumably lets other
# `__muldiag_nonzeroalpha!` methods call it directly without re-dispatching —
# confirm.
function __muldiag_nonzeroalpha!(out, A, D::Diagonal, alpha::Number, beta::Number)
__muldiag_nonzeroalpha_right!(out, A, D, alpha, beta)
end
function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, alpha::Number, beta::Number)
isunit = A isa UnitUpperOrUnitLowerTriangular
beta = _add.beta
# since alpha is multiplied to the diagonal element of D,
# we may skip alpha in the second multiplication by setting ais1 to true
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
# if both A and out have the same upper/lower triangular structure,
# we may directly read and write from the parents
out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A)
out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A)
for j in axes(A, 2)
dja = _add(@inbounds D.diag[j])
dja = @stable_muladdmul MulAddMul(alpha,false)(@inbounds D.diag[j])
# store the diagonal separately for unit triangular matrices
if isunit
@inbounds _modify!(_add_aisone, A[j,j] * dja, out, (j,j))
# since alpha is multiplied to the diagonal element of D,
# we may skip alpha in the second multiplication by setting ais1 to true
@inbounds @stable_muladdmul _modify!(MulAddMul(true,beta), A[j,j] * dja, out, (j,j))
end
# indices of out corresponding to the stored indices of A
rowrange = _rowrange_tri_stored(A, j)
@inbounds @simd for i in rowrange
_modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
# since alpha is multiplied to the diagonal element of D,
# we may skip alpha in the second multiplication by setting ais1 to true
@stable_muladdmul _modify!(MulAddMul(true,beta), A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
end
# Fill the indices of out corresponding to the zeros of A
# we only fill these if out and A don't have matching zeros
if !_has_matching_zeros(out, A)
rowrange = _rowrange_tri_zeros(A, j)
@inbounds @simd for i in rowrange
_modify!(_add_aisone, A[i,j] * dja, out, (i,j))
@stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j))
end
end
end
out
return out
end

# ambiguity resolution
# `Diagonal * Diagonal` written into an `out` of unrestricted type matches both
# the "Diagonal on the left" and the "Diagonal on the right" methods, so this
# method picks one explicitly: `D1` is passed as the generic left factor to the
# right-multiplication kernel `__muldiag_nonzeroalpha_right!`.
function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number)
__muldiag_nonzeroalpha_right!(out, D1, D2, alpha, beta)
end

@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number)
d1 = D1.diag
d2 = D2.diag
outd = out.diag
@inbounds @simd for i in eachindex(d1, d2, outd)
_modify!(_add, d1[i] * d2[i], outd, i)
@stable_muladdmul _modify!(MulAddMul(alpha,beta), d1[i] * d2[i], outd, i)
end
out
end

# ambiguity resolution
@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
@inbounds for j in axes(D2, 2), i in axes(D2, 1)
_modify!(_add, D1.diag[i] * D2[i,j], out, (i,j))
end
out
return out
end

# muldiag mainly handles the zero-alpha case, so that we need only
# muldiag handles the zero-alpha case, so that we need only
# specialize the non-trivial case
function _mul_diag!(out, A, B, _add)
function _mul_diag!(out, A, B, alpha, beta)
require_one_based_indexing(out, A, B)
_muldiag_size_check(size(out), size(A), size(B))
alpha, beta = _add.alpha, _add.beta
if iszero(alpha)
_rmul_or_fill!(out, beta)
else
__muldiag_nonzeroalpha!(out, A, B, _add)
__muldiag_nonzeroalpha!(out, A, B, alpha, beta)
end
return out
end

_mul!(out::AbstractVecOrMat, D::Diagonal, V::AbstractVector, _add) =
_mul_diag!(out, D, V, _add)
_mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, _add) =
_mul_diag!(out, D, B, _add)
_mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, _add) =
_mul_diag!(out, A, D, _add)
_mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, _add) =
_mul_diag!(C, Da, Db, _add)
_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, _add) =
_mul_diag!(C, Da, Db, _add)
# `Diagonal * AbstractVector` with (alpha, beta) scalars, for a vector `out` and
# for a matrix `out` respectively. Both methods delegate to `_mul_diag!`,
# passing `alpha` and `beta` through unchanged.
_mul!(out::AbstractVector, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) =
_mul_diag!(out, D, V, alpha, beta)
_mul!(out::AbstractMatrix, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) =
_mul_diag!(out, D, V, alpha, beta)
# Generate the (alpha, beta) `_mul!` methods for a `Diagonal` multiplied with an
# `AbstractMatrix` or an `AbstractTriangular`, with the `Diagonal` on the left
# and on the right. Every generated method delegates to `_mul_diag!`.
# NOTE(review): the separate `AbstractTriangular` methods presumably exist to
# resolve dispatch ambiguities with triangular-specific `_mul!` methods defined
# elsewhere — confirm.
for MT in (:AbstractMatrix, :AbstractTriangular)
@eval begin
_mul!(out::AbstractMatrix, D::Diagonal, B::$MT, alpha::Number, beta::Number) =
_mul_diag!(out, D, B, alpha, beta)
_mul!(out::AbstractMatrix, A::$MT, D::Diagonal, alpha::Number, beta::Number) =
_mul_diag!(out, A, D, alpha, beta)
end
end
# `Diagonal * Diagonal` written into any `AbstractMatrix` destination, with
# (alpha, beta) scalars: delegates to `_mul_diag!`.
_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
_mul_diag!(C, Da, Db, alpha, beta)

function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal)
_muldiag_size_check(size(Da), size(A))
Expand Down

0 comments on commit 1e0cee5

Please sign in to comment.