diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index aefaf16337d83..5b7264558f9ae 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -472,10 +472,14 @@ const BiTri = Union{Bidiagonal,Tridiagonal} @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) @inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = - @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = - @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +for T in (:AbstractMatrix, :Diagonal) + @eval begin + @inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::$T, alpha::Number, beta::Number) = + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) + @inline _mul!(C::AbstractMatrix, A::$T, B::BandedMatrix, alpha::Number, beta::Number) = + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) + end +end @inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) @@ -831,6 +835,8 @@ function __bibimul!(C, A::Bidiagonal, B::Bidiagonal, _add) C end +_mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) = + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul) require_one_based_indexing(C) check_A_mul_B!_sizes(size(C), size(A), size(B)) @@ -1067,6 +1073,8 @@ function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::Bidiagonal, _add::MulAdd C end +_mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, alpha::Number, beta::Number) = + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) _mul!(C::AbstractMatrix, A::Diagonal, B::Bidiagonal, _add::MulAddMul) = _dibimul!(C, A, B, _add) _mul!(C::AbstractMatrix, A::Diagonal, B::TriSym, _add::MulAddMul) = diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 7594e8bca4f56..243df4d82eec2 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -397,13 +397,13 @@ function lmul!(D::Diagonal, T::Tridiagonal) return T end -@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul) +@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, alpha::Number, beta::Number) @inbounds for j in axes(B, 2) @simd for i in axes(B, 1) - _modify!(_add, D.diag[i] * B[i,j], out, (i,j)) + @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j)) end end - out + return out end _has_matching_zeros(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true _has_matching_zeros(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true @@ -418,116 +418,118 @@ function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col) end _rowrange_tri_zeros(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1) _rowrange_tri_zeros(B::LowerOrUnitLowerTriangular, col) = 1:col-1 -function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul) +function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, alpha::Number, beta::Number) isunit = B isa UnitUpperOrUnitLowerTriangular out_maybeparent, B_maybeparent = _has_matching_zeros(out, B) ? (parent(out), parent(B)) : (out, B) for j in axes(B, 2) # store the diagonal separately for unit triangular matrices if isunit - @inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j)) + @inbounds @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[j] * B[j,j], out, (j,j)) end # The indices of out corresponding to the stored indices of B rowrange = _rowrange_tri_stored(B, j) @inbounds @simd for i in rowrange - _modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j)) + @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j)) end # Fill the indices of out corresponding to the zeros of B # we only fill these if out and B don't have matching zeros if !_has_matching_zeros(out, B) rowrange = _rowrange_tri_zeros(B, j) @inbounds @simd for i in rowrange - _modify!(_add, D.diag[i] * B[i,j], out, (i,j)) + @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j)) end end end return out end -@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} - beta = _add.beta - _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta) +@inline function __muldiag_nonzeroalpha_right!(out, A, D::Diagonal, alpha::Number, beta::Number) @inbounds for j in axes(A, 2) - dja = _add(D.diag[j]) + dja = @stable_muladdmul MulAddMul(alpha,false)(D.diag[j]) @simd for i in axes(A, 1) - _modify!(_add_aisone, A[i,j] * dja, out, (i,j)) + @stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j)) end end - out + return out +end + +function __muldiag_nonzeroalpha!(out, A, D::Diagonal, alpha::Number, beta::Number) + __muldiag_nonzeroalpha_right!(out, A, D, alpha, beta) end -function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0} +function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, alpha::Number, beta::Number) isunit = A isa UnitUpperOrUnitLowerTriangular - beta = _add.beta - # since alpha is multiplied to the diagonal element of D, - # we may skip alpha in the second multiplication by setting ais1 to true - _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta) # if both A and out have the same upper/lower triangular structure, # we may directly read and write from the parents - out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A) + out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A) for j in axes(A, 2) - dja = _add(@inbounds D.diag[j]) + dja = @stable_muladdmul MulAddMul(alpha,false)(@inbounds D.diag[j]) # store the diagonal separately for unit triangular matrices if isunit - @inbounds _modify!(_add_aisone, A[j,j] * dja, out, (j,j)) + # since alpha is multiplied to the diagonal element of D, + # we may skip alpha in the second multiplication by setting ais1 to true + @inbounds @stable_muladdmul _modify!(MulAddMul(true,beta), A[j,j] * dja, out, (j,j)) end # indices of out corresponding to the stored indices of A rowrange = _rowrange_tri_stored(A, j) @inbounds @simd for i in rowrange - _modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j)) + # since alpha is multiplied to the diagonal element of D, + # we may skip alpha in the second multiplication by setting ais1 to true + @stable_muladdmul _modify!(MulAddMul(true,beta), A_maybeparent[i,j] * dja, out_maybeparent, (i,j)) end # Fill the indices of out corresponding to the zeros of A # we only fill these if out and A don't have matching zeros if !_has_matching_zeros(out, A) rowrange = _rowrange_tri_zeros(A, j) @inbounds @simd for i in rowrange - _modify!(_add_aisone, A[i,j] * dja, out, (i,j)) + @stable_muladdmul _modify!(MulAddMul(true,beta), A[i,j] * dja, out, (i,j)) end end end - out + return out +end + +# ambiguity resolution +function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number) + __muldiag_nonzeroalpha_right!(out, D1, D2, alpha, beta) end -@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul) +@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha::Number, beta::Number) d1 = D1.diag d2 = D2.diag outd = out.diag @inbounds @simd for i in eachindex(d1, d2, outd) - _modify!(_add, d1[i] * d2[i], outd, i) + @stable_muladdmul _modify!(MulAddMul(alpha,beta), d1[i] * d2[i], outd, i) end - out -end - -# ambiguity resolution -@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul) - @inbounds for j in axes(D2, 2), i in axes(D2, 1) - _modify!(_add, D1.diag[i] * D2[i,j], out, (i,j)) - end - out + return out end -# muldiag mainly handles the zero-alpha case, so that we need only +# muldiag handles the zero-alpha case, so that we need only # specialize the non-trivial case -function _mul_diag!(out, A, B, _add) +function _mul_diag!(out, A, B, alpha, beta) require_one_based_indexing(out, A, B) _muldiag_size_check(size(out), size(A), size(B)) - alpha, beta = _add.alpha, _add.beta if iszero(alpha) _rmul_or_fill!(out, beta) else - __muldiag_nonzeroalpha!(out, A, B, _add) + __muldiag_nonzeroalpha!(out, A, B, alpha, beta) end return out end -_mul!(out::AbstractVecOrMat, D::Diagonal, V::AbstractVector, _add) = - _mul_diag!(out, D, V, _add) -_mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, _add) = - _mul_diag!(out, D, B, _add) -_mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, _add) = - _mul_diag!(out, A, D, _add) -_mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, _add) = - _mul_diag!(C, Da, Db, _add) -_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, _add) = - _mul_diag!(C, Da, Db, _add) +_mul!(out::AbstractVector, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) = + _mul_diag!(out, D, V, alpha, beta) +_mul!(out::AbstractMatrix, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) = + _mul_diag!(out, D, V, alpha, beta) +for MT in (:AbstractMatrix, :AbstractTriangular) + @eval begin + _mul!(out::AbstractMatrix, D::Diagonal, B::$MT, alpha::Number, beta::Number) = + _mul_diag!(out, D, B, alpha, beta) + _mul!(out::AbstractMatrix, A::$MT, D::Diagonal, alpha::Number, beta::Number) = + _mul_diag!(out, A, D, alpha, beta) + end +end +_mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) = + _mul_diag!(C, Da, Db, alpha, beta) function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal) _muldiag_size_check(size(Da), size(A))