Skip to content

Commit

Permalink
5-arg mul! bug fixes (#32901)
Browse files Browse the repository at this point in the history
* Enforce strong-zero behavior for alpha

* Fix random tests for 5-arg mul! (avoid promotion in vector literal)

* Always test `mul!` against `Matrix`/`Vector`'s `*`

* Fix 5-arg mul! for Bi/Tri/Sym * Diag

* Add tests with adjoint and transpose in random test

* Test strong zero in random test

* Fixing 5-arg mul! for Bi/Tri/Sym * Diag

* Fixing strong zero test; handle UnitLowerTriangular etc.

* Enforcing strong-zero behavior for alpha

* Fixing strong zero test; handle Bidiagonal etc.

* Fixing 5-arg mul! for Bi/Tri/Sym * Diag

* Fixing 5-arg mul! for Bi/Tri/Sym * Diag

* Include values of α and β in testset description

* Handle empty sub-diagonal

* Enforcing strong-zero behavior for alpha

* Short-circuit multiplication by alpha
  • Loading branch information
tkf authored and JeffBezanson committed Aug 16, 2019
1 parent b51ae5b commit b5f4e87
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 89 deletions.
72 changes: 40 additions & 32 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym,
# `_modify!` in the following loop will not update the
# off-diagonal elements for non-zero beta.
_rmul_or_fill!(C, _add.beta)
iszero(_add.alpha) && return C
Al = _diag(A, -1)
Ad = _diag(A, 0)
Au = _diag(A, 1)
Expand Down Expand Up @@ -447,32 +448,33 @@ function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::Diagonal,
_add::MulAddMul = MulAddMul())
check_A_mul_B!_sizes(C, A, B)
n = size(A,1)
n <= 3 && return mul!(C, Array(A), Array(B))
fill!(C, zero(eltype(C)))
n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
_rmul_or_fill!(C, _add.beta) # see the same use above
iszero(_add.alpha) && return C
Al = _diag(A, -1)
Ad = _diag(A, 0)
Au = _diag(A, 1)
Bd = B.diag
@inbounds begin
# first row of C
_modify!(_add, A[1,1]*B[1,1], C, (1,1))
_modify!(_add, A[1,2]*B[2,2], C, (1,2))
C[1,1] += _add(A[1,1]*B[1,1])
C[1,2] += _add(A[1,2]*B[2,2])
# second row of C
_modify!(_add, A[2,1]*B[1,1], C, (2,1))
_modify!(_add, A[2,2]*B[2,2], C, (2,2))
_modify!(_add, A[2,3]*B[3,3], C, (2,3))
C[2,1] += _add(A[2,1]*B[1,1])
C[2,2] += _add(A[2,2]*B[2,2])
C[2,3] += _add(A[2,3]*B[3,3])
for j in 3:n-2
_modify!(_add, Al[j-1]*Bd[j-1], C, (j, j-1))
_modify!(_add, Ad[j ]*Bd[j ], C, (j, j ))
_modify!(_add, Au[j ]*Bd[j+1], C, (j, j+1))
C[j, j-1] += _add(Al[j-1]*Bd[j-1])
C[j, j ] += _add(Ad[j ]*Bd[j ])
C[j, j+1] += _add(Au[j ]*Bd[j+1])
end
# row before last of C
_modify!(_add, A[n-1,n-2]*B[n-2,n-2], C, (n-1,n-2))
_modify!(_add, A[n-1,n-1]*B[n-1,n-1], C, (n-1,n-1))
_modify!(_add, A[n-1, n]*B[n ,n ], C, (n-1,n ))
C[n-1,n-2] += _add(A[n-1,n-2]*B[n-2,n-2])
C[n-1,n-1] += _add(A[n-1,n-1]*B[n-1,n-1])
C[n-1,n ] += _add(A[n-1, n]*B[n ,n ])
# last row of C
_modify!(_add, A[n,n-1]*B[n-1,n-1], C, (n,n-1))
_modify!(_add, A[n,n ]*B[n, n ], C, (n,n ))
C[n,n-1] += _add(A[n,n-1]*B[n-1,n-1])
C[n,n ] += _add(A[n,n ]*B[n, n ])
end # inbounds
C
end
Expand All @@ -489,6 +491,7 @@ function A_mul_B_td!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat,
if size(C,2) != nB
throw(DimensionMismatch("A has second dimension $nA, B has $(size(B,2)), C has $(size(C,2)) but all must match"))
end
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
nA <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
l = _diag(A, -1)
d = _diag(A, 0)
Expand All @@ -510,9 +513,12 @@ end
function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym,
_add::MulAddMul = MulAddMul())
check_A_mul_B!_sizes(C, A, B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
n = size(A,1)
n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
m = size(B,2)
if n <= 3 || m <= 1
return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
end
Bl = _diag(B, -1)
Bd = _diag(B, 0)
Bu = _diag(B, 1)
Expand All @@ -539,36 +545,38 @@ function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym,
C
end

function A_mul_B_td!(C::AbstractMatrix, A::Diagonal, B::BiTriSym)
function A_mul_B_td!(C::AbstractMatrix, A::Diagonal, B::BiTriSym,
_add::MulAddMul = MulAddMul())
check_A_mul_B!_sizes(C, A, B)
n = size(A,1)
n <= 3 && return mul!(C, Array(A), Array(B))
fill!(C, zero(eltype(C)))
n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta)
_rmul_or_fill!(C, _add.beta) # see the same use above
iszero(_add.alpha) && return C
Ad = A.diag
Bl = _diag(B, -1)
Bd = _diag(B, 0)
Bu = _diag(B, 1)
@inbounds begin
# first row of C
C[1,1] = A[1,1]*B[1,1]
C[1,2] = A[1,1]*B[1,2]
C[1,1] += _add(A[1,1]*B[1,1])
C[1,2] += _add(A[1,1]*B[1,2])
# second row of C
C[2,1] = A[2,2]*B[2,1]
C[2,2] = A[2,2]*B[2,2]
C[2,3] = A[2,2]*B[2,3]
C[2,1] += _add(A[2,2]*B[2,1])
C[2,2] += _add(A[2,2]*B[2,2])
C[2,3] += _add(A[2,2]*B[2,3])
for j in 3:n-2
Ajj = Ad[j]
C[j, j-1] = Ajj*Bl[j-1]
C[j, j ] = Ajj*Bd[j]
C[j, j+1] = Ajj*Bu[j]
C[j, j-1] += _add(Ajj*Bl[j-1])
C[j, j ] += _add(Ajj*Bd[j])
C[j, j+1] += _add(Ajj*Bu[j])
end
# row before last of C
C[n-1,n-2] = A[n-1,n-1]*B[n-1,n-2]
C[n-1,n-1] = A[n-1,n-1]*B[n-1,n-1]
C[n-1,n ] = A[n-1,n-1]*B[n-1,n ]
C[n-1,n-2] += _add(A[n-1,n-1]*B[n-1,n-2])
C[n-1,n-1] += _add(A[n-1,n-1]*B[n-1,n-1])
C[n-1,n ] += _add(A[n-1,n-1]*B[n-1,n ])
# last row of C
C[n,n-1] = A[n,n]*B[n,n-1]
C[n,n ] = A[n,n]*B[n,n ]
C[n,n-1] += _add(A[n,n]*B[n,n-1])
C[n,n ] += _add(A[n,n]*B[n,n ])
end # inbounds
C
end
Expand Down
52 changes: 27 additions & 25 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,81 +277,83 @@ function rmul!(A::AbstractMatrix, transB::Transpose{<:Any,<:Diagonal})
end

# Elements of `out` may not be defined (e.g., for `BigFloat`). To make
# `mul!(out, A, B)` work for such cases, `_scaledout` short-circuits
# `mul!(out, A, B)` work for such cases, `out .*ₛ beta` short-circuits
# `out * beta`. Using `broadcasted` to avoid the multiplication
# inside this function.
_scaledout(out, beta) = iszero(beta) ? false : broadcasted(*, out, beta)
function *end
Broadcast.broadcasted(::typeof(*ₛ), out, beta) =
iszero(beta::Number) ? false : broadcasted(*, out, beta)

# Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat
@inline mul!(out::AbstractVector, A::Diagonal, in::AbstractVector,
alpha::Number, beta::Number) =
out .= A.diag .* in .* alpha .+ _scaledout(out, beta)
out .= (A.diag .* in) .* alpha .+ out .*beta
@inline mul!(out::AbstractVector, A::Adjoint{<:Any,<:Diagonal}, in::AbstractVector,
alpha::Number, beta::Number) =
out .= adjoint.(A.parent.diag) .* in .* alpha .+ _scaledout(out, beta)
out .= (adjoint.(A.parent.diag) .* in) .* alpha .+ out .*beta
@inline mul!(out::AbstractVector, A::Transpose{<:Any,<:Diagonal}, in::AbstractVector,
alpha::Number, beta::Number) =
out .= transpose.(A.parent.diag) .* in .* alpha .+ _scaledout(out, beta)
out .= (transpose.(A.parent.diag) .* in) .* alpha .+ out .*beta

@inline mul!(out::AbstractMatrix, A::Diagonal, in::StridedMatrix,
alpha::Number, beta::Number) =
out .= A.diag .* in .* alpha .+ _scaledout(out, beta)
out .= (A.diag .* in) .* alpha .+ out .*beta
@inline mul!(out::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, in::StridedMatrix,
alpha::Number, beta::Number) =
out .= adjoint.(A.parent.diag) .* in .* alpha .+ _scaledout(out, beta)
out .= (adjoint.(A.parent.diag) .* in) .* alpha .+ out .*beta
@inline mul!(out::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, in::StridedMatrix,
alpha::Number, beta::Number) =
out .= transpose.(A.parent.diag) .* in .* alpha .+ _scaledout(out, beta)
out .= (transpose.(A.parent.diag) .* in) .* alpha .+ out .*beta

@inline mul!(out::AbstractMatrix, A::Diagonal, in::Adjoint{<:Any,<:StridedMatrix},
alpha::Number, beta::Number) =
out .= A.diag .* in .* alpha
out .= (A.diag .* in) .* alpha .+ out .*ₛ beta
@inline mul!(out::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, in::Adjoint{<:Any,<:StridedMatrix},
alpha::Number, beta::Number) =
out .= adjoint.(A.parent.diag) .* in .* alpha .+ _scaledout(out, beta)
out .= (adjoint.(A.parent.diag) .* in) .* alpha .+ out .*beta
@inline mul!(out::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, in::Adjoint{<:Any,<:StridedMatrix},
alpha::Number, beta::Number) =
out .= transpose.(A.parent.diag) .* in .* alpha .+ _scaledout(out, beta)
out .= (transpose.(A.parent.diag) .* in) .* alpha .+ out .*beta

@inline mul!(out::AbstractMatrix, A::Diagonal, in::Transpose{<:Any,<:StridedMatrix},
alpha::Number, beta::Number) =
out .= A.diag .* in .* alpha .+ _scaledout(out, beta)
out .= (A.diag .* in) .* alpha .+ out .*beta
@inline mul!(out::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, in::Transpose{<:Any,<:StridedMatrix},
alpha::Number, beta::Number) =
out .= adjoint.(A.parent.diag) .* in .* alpha .+ _scaledout(out, beta)
out .= (adjoint.(A.parent.diag) .* in) .* alpha .+ out .*beta
@inline mul!(out::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, in::Transpose{<:Any,<:StridedMatrix},
alpha::Number, beta::Number) =
out .= transpose.(A.parent.diag) .* in .* alpha .+ _scaledout(out, beta)
out .= (transpose.(A.parent.diag) .* in) .* alpha .+ out .*beta

@inline mul!(out::AbstractMatrix, in::StridedMatrix, A::Diagonal,
alpha::Number, beta::Number) =
out .= in .* permutedims(A.diag) .* alpha .+ _scaledout(out, beta)
out .= (in .* permutedims(A.diag)) .* alpha .+ out .*beta
@inline mul!(out::AbstractMatrix, in::StridedMatrix, A::Adjoint{<:Any,<:Diagonal},
alpha::Number, beta::Number) =
out .= in .* adjoint(A.parent.diag) .* alpha .+ _scaledout(out, beta)
out .= (in .* adjoint(A.parent.diag)) .* alpha .+ out .*beta
@inline mul!(out::AbstractMatrix, in::StridedMatrix, A::Transpose{<:Any,<:Diagonal},
alpha::Number, beta::Number) =
out .= in .* transpose(A.parent.diag) .* alpha .+ _scaledout(out, beta)
out .= (in .* transpose(A.parent.diag)) .* alpha .+ out .*beta

@inline mul!(out::AbstractMatrix, in::Adjoint{<:Any,<:StridedMatrix}, A::Diagonal,
alpha::Number, beta::Number) =
out .= in .* permutedims(A.diag) .* alpha .+ _scaledout(out, beta)
out .= (in .* permutedims(A.diag)) .* alpha .+ out .*beta
@inline mul!(out::AbstractMatrix, in::Adjoint{<:Any,<:StridedMatrix}, A::Adjoint{<:Any,<:Diagonal},
alpha::Number, beta::Number) =
out .= in .* adjoint(A.parent.diag) .* alpha .+ _scaledout(out, beta)
out .= (in .* adjoint(A.parent.diag)) .* alpha .+ out .*beta
@inline mul!(out::AbstractMatrix, in::Adjoint{<:Any,<:StridedMatrix}, A::Transpose{<:Any,<:Diagonal},
alpha::Number, beta::Number) =
out .= in .* transpose(A.parent.diag) .* alpha .+ _scaledout(out, beta)
out .= (in .* transpose(A.parent.diag)) .* alpha .+ out .*beta

@inline mul!(out::AbstractMatrix, in::Transpose{<:Any,<:StridedMatrix}, A::Diagonal,
alpha::Number, beta::Number) =
out .= in .* permutedims(A.diag) .* alpha .+ _scaledout(out, beta)
out .= (in .* permutedims(A.diag)) .* alpha .+ out .*beta
@inline mul!(out::AbstractMatrix, in::Transpose{<:Any,<:StridedMatrix}, A::Adjoint{<:Any,<:Diagonal},
alpha::Number, beta::Number) =
out .= in .* adjoint(A.parent.diag) .* alpha .+ _scaledout(out, beta)
out .= (in .* adjoint(A.parent.diag)) .* alpha .+ out .*beta
@inline mul!(out::AbstractMatrix, in::Transpose{<:Any,<:StridedMatrix}, A::Transpose{<:Any,<:Diagonal},
alpha::Number, beta::Number) =
out .= in .* transpose(A.parent.diag) .* alpha .+ _scaledout(out, beta)
out .= (in .* transpose(A.parent.diag)) .* alpha .+ out .*beta

# ambiguities with Symmetric/Hermitian
# RealHermSymComplex[Sym]/[Herm] only include Number; invariant to [c]transpose
Expand Down Expand Up @@ -382,11 +384,11 @@ mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::Transpose{<:Any,<:Rea
@inline mul!(C::AbstractMatrix,
A::Adjoint{<:Any,<:Diagonal}, B::Adjoint{<:Any,<:RealHermSymComplexSym},
alpha::Number, beta::Number) =
C .= adjoint.(A.parent.diag) .* B .* alpha .+ C .* beta
C .= (adjoint.(A.parent.diag) .* B) .* alpha .+ C .* beta
@inline mul!(C::AbstractMatrix,
A::Transpose{<:Any,<:Diagonal}, B::Transpose{<:Any,<:RealHermSymComplexHerm},
alpha::Number, beta::Number) =
C .= transpose.(A.parent.diag) .* B .* alpha .+ C .* beta
C .= (transpose.(A.parent.diag) .* B) .* alpha .+ C .* beta

(/)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag ./ Db.diag)

Expand Down
9 changes: 6 additions & 3 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat
if nC != mA
throw(DimensionMismatch("output matrix has size: $(nC), but should have size $(mA)"))
end
if mA == 0 || nA == 0
if mA == 0 || nA == 0 || iszero(_add.alpha)
return _rmul_or_fill!(C, _add.beta)
end
if mA == 2 && nA == 2
Expand Down Expand Up @@ -541,7 +541,7 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
throw(ArgumentError("output matrix must not be aliased with input matrix"))
end

if mA == 0 || nA == 0 || nB == 0
if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha)
if size(C) != (mA, nB)
throw(DimensionMismatch("C has dimensions $(size(C)), should have ($mA,$nB)"))
end
Expand Down Expand Up @@ -674,6 +674,9 @@ function generic_matmatmul!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::Abs
mB, nB = lapack_size(tB, B)
mC, nC = size(C)

if iszero(_add.alpha)
return _rmul_or_fill!(C, _add.beta)
end
if mA == nA == mB == nB == mC == nC == 2
return matmul2x2!(C, tA, tB, A, B, _add)
end
Expand All @@ -697,7 +700,7 @@ function _generic_matmatmul!(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVecOrMat
if size(C,1) != mA || size(C,2) != nB
throw(DimensionMismatch("result C has dimensions $(size(C)), needs ($mA,$nB)"))
end
if isempty(A) || isempty(B)
if isempty(A) || isempty(B) || iszero(_add.alpha)
return _rmul_or_fill!(C, _add.beta)
end

Expand Down
8 changes: 8 additions & 0 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ end

@inline function _mul!(A::UpperTriangular, B::UpperTriangular, c::Number, _add::MulAddMul)
n = checksquare(B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
for j = 1:n
for i = 1:j
@inbounds _modify!(_add, B[i,j] * c, A, (i,j))
Expand All @@ -468,6 +469,7 @@ end
end
@inline function _mul!(A::UpperTriangular, c::Number, B::UpperTriangular, _add::MulAddMul)
n = checksquare(B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
for j = 1:n
for i = 1:j
@inbounds _modify!(_add, c * B[i,j], A, (i,j))
Expand All @@ -477,6 +479,7 @@ end
end
@inline function _mul!(A::UpperTriangular, B::UnitUpperTriangular, c::Number, _add::MulAddMul)
n = checksquare(B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
for j = 1:n
@inbounds _modify!(_add, c, A, (j,j))
for i = 1:(j - 1)
Expand All @@ -487,6 +490,7 @@ end
end
@inline function _mul!(A::UpperTriangular, c::Number, B::UnitUpperTriangular, _add::MulAddMul)
n = checksquare(B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
for j = 1:n
@inbounds _modify!(_add, c, A, (j,j))
for i = 1:(j - 1)
Expand All @@ -497,6 +501,7 @@ end
end
@inline function _mul!(A::LowerTriangular, B::LowerTriangular, c::Number, _add::MulAddMul)
n = checksquare(B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
for j = 1:n
for i = j:n
@inbounds _modify!(_add, B[i,j] * c, A, (i,j))
Expand All @@ -506,6 +511,7 @@ end
end
@inline function _mul!(A::LowerTriangular, c::Number, B::LowerTriangular, _add::MulAddMul)
n = checksquare(B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
for j = 1:n
for i = j:n
@inbounds _modify!(_add, c * B[i,j], A, (i,j))
Expand All @@ -515,6 +521,7 @@ end
end
@inline function _mul!(A::LowerTriangular, B::UnitLowerTriangular, c::Number, _add::MulAddMul)
n = checksquare(B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
for j = 1:n
@inbounds _modify!(_add, c, A, (j,j))
for i = (j + 1):n
Expand All @@ -525,6 +532,7 @@ end
end
@inline function _mul!(A::LowerTriangular, c::Number, B::UnitLowerTriangular, _add::MulAddMul)
n = checksquare(B)
iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta)
for j = 1:n
@inbounds _modify!(_add, c, A, (j,j))
for i = (j + 1):n
Expand Down
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ end

if m == 0
return C
elseif iszero(_add.alpha)
return _rmul_or_fill!(C, _add.beta)
end

α = S.dv
Expand Down
Loading

0 comments on commit b5f4e87

Please sign in to comment.