Skip to content

Commit

Permalink
Loosen signature in triangular solver from Strided- to AbstractMatrix.
Browse files Browse the repository at this point in the history
Remove many unnecessary type parameters. Fixes #16196
  • Loading branch information
andreasnoack committed May 5, 2016
1 parent a0f4920 commit 69e6e28
Showing 5 changed files with 72 additions and 17 deletions.
4 changes: 3 additions & 1 deletion base/linalg/bidiag.jl
Original file line number Diff line number Diff line change
@@ -225,7 +225,9 @@ end
==(A::Bidiagonal, B::Bidiagonal) = (A.dv==B.dv) && (A.ev==B.ev) && (A.isupper==B.isupper)

SpecialMatrix = Union{Bidiagonal, SymTridiagonal, Tridiagonal, AbstractTriangular}
*(A::SpecialMatrix, B::SpecialMatrix)=full(A)*full(B)
# to avoid ambiguity warning, but shouldn't be necessary
*(A::AbstractTriangular, B::SpecialMatrix) = full(A) * full(B)
*(A::SpecialMatrix, B::SpecialMatrix) = full(A) * full(B)

#Generic multiplication
for func in (:*, :Ac_mul_B, :A_mul_Bc, :/, :A_rdiv_Bc)
16 changes: 16 additions & 0 deletions base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
@@ -106,6 +106,22 @@ end
/{T<:Number}(D::Diagonal, x::T) = Diagonal(D.diag / x)
*(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .* Db.diag)
*(D::Diagonal, V::AbstractVector) = D.diag .* V
# To avoid ambiguity in the definitions below
for uplo in (:LowerTriangular, :UpperTriangular)
@eval begin
(*)(A::$uplo, D::Diagonal) = $uplo(A.data * D)

function (*)(A::$(Symbol(:Unit, uplo)), D::Diagonal)
B = A.data * D
for i = 1:size(A, 1)
B[i,i] = D.diag[i]
end
return $uplo(B)
end
end
end
(*)(A::AbstractTriangular, D::Diagonal) =
error("this method should never get called. Please make a bug report.")
*(A::AbstractMatrix, D::Diagonal) =
scale!(similar(A, promote_op(*, eltype(A), eltype(D.diag))), A, D.diag)
*(D::Diagonal, A::AbstractMatrix) =
62 changes: 48 additions & 14 deletions base/linalg/triangular.jl
Original file line number Diff line number Diff line change
@@ -1287,6 +1287,27 @@ function A_rdiv_Bt!(A::StridedMatrix, B::UnitLowerTriangular)
A
end

for f in (:A_rdiv_B!, :A_rdiv_Bc!, :A_rdiv_Bt!)
for (uplo, fuplo) in ((:Lower, :tril!), (:Upper, :triu!))
mat = Symbol(uplo, :Triangular)
umat = Symbol(:Unit, mat)
@eval begin
$f(A::$mat, B::Union{$mat,$umat}) = ($mat)($f($fuplo(A.data), B))
end
end
@eval $f(A::AbstractTriangular, B::AbstractTriangular) = $f(full!(A), B)
end
for f in (:A_ldiv_B!, :Ac_ldiv_B!, :At_ldiv_B!)
for (uplo, fuplo) in ((:Lower, :tril!), (:Upper, :triu!))
mat = Symbol(uplo, :Triangular)
umat = Symbol(:Unit, mat)
@eval begin
$f(A::Union{$mat,$umat}, B::$mat) = ($mat)($f(A, $fuplo(B.data)))
end
end
@eval $f(A::AbstractTriangular, B::AbstractTriangular) = $f(A, full!(B))
end

# Promotion
## Promotion methods in matmul don't apply to triangular multiplication since it is inplace. Hence we have to make very similar definitions, but without allocation of a result array. For multiplication and unit diagonal division the element type doesn't have to be stable under division whereas that is necessary in the general triangular solve problem.

@@ -1297,72 +1318,85 @@ for t in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriang
end
end

for f in (:*, :Ac_mul_B, :At_mul_B, :\, :Ac_ldiv_B, :At_ldiv_B)
# for f in (:*, :Ac_mul_B, :At_mul_B, :\, :Ac_ldiv_B, :At_ldiv_B)
# @eval begin
# ($f)(A::AbstractTriangular, B::AbstractTriangular) = ($f)(A, full(B))
# end
# end
for f in (:*, :Ac_mul_B, :At_mul_B)
@eval begin
($f)(A::AbstractTriangular, B::AbstractTriangular) = ($f)(A, full(B))
end
end
for f in (:A_mul_Bc, :A_mul_Bt, :Ac_mul_Bc, :At_mul_Bt, :/, :A_rdiv_Bc, :A_rdiv_Bt)
# for f in (:A_mul_Bc, :A_mul_Bt, :Ac_mul_Bc, :At_mul_Bt, :/, :A_rdiv_Bc, :A_rdiv_Bt)
# @eval begin
# ($f)(A::AbstractTriangular, B::AbstractTriangular) = ($f)(full(A), B)
# end
# end
for f in (:A_mul_Bc, :A_mul_Bt, :Ac_mul_Bc, :At_mul_Bt)
@eval begin
($f)(A::AbstractTriangular, B::AbstractTriangular) = ($f)(full(A), B)
end
end

## The general promotion methods
for mat in (:AbstractVector, AbstractMatrix)
### Multiplication with triangle to the left and hence rhs cannot be transposed.
for (f, g) in ((:*, :A_mul_B!), (:Ac_mul_B, :Ac_mul_B!), (:At_mul_B, :At_mul_B!))
@eval begin
function ($f){TA,TB}(A::AbstractTriangular{TA}, B::StridedVecOrMat{TB})
TAB = typeof(zero(TA)*zero(TB) + zero(TA)*zero(TB))
function ($f)(A::AbstractTriangular, B::$mat)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
($g)(convert(AbstractArray{TAB}, A), copy_oftype(B, TAB))
end
end
end
### Left division with triangle to the left hence rhs cannot be transposed. No quotients.
for (f, g) in ((:\, :A_ldiv_B!), (:Ac_ldiv_B, :Ac_ldiv_B!), (:At_ldiv_B, :At_ldiv_B!))
@eval begin
function ($f){TA,TB,S}(A::Union{UnitUpperTriangular{TA,S},UnitLowerTriangular{TA,S}}, B::StridedVecOrMat{TB})
TAB = typeof(zero(TA)*zero(TB) + zero(TA)*zero(TB))
function ($f)(A::Union{UnitUpperTriangular,UnitLowerTriangular}, B::$mat)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
($g)(convert(AbstractArray{TAB}, A), copy_oftype(B, TAB))
end
end
end
### Left division with triangle to the left hence rhs cannot be transposed. Quotients.
for (f, g) in ((:\, :A_ldiv_B!), (:Ac_ldiv_B, :Ac_ldiv_B!), (:At_ldiv_B, :At_ldiv_B!))
@eval begin
function ($f){TA,TB,S}(A::Union{UpperTriangular{TA,S},LowerTriangular{TA,S}}, B::StridedVecOrMat{TB})
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
function ($f)(A::Union{UpperTriangular,LowerTriangular}, B::$mat)
TAB = typeof((zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))/one(eltype(A)))
($g)(convert(AbstractArray{TAB}, A), copy_oftype(B, TAB))
end
end
end
### Multiplication with triangle to the rigth and hence lhs cannot be transposed.
for (f, g) in ((:*, :A_mul_B!), (:A_mul_Bc, :A_mul_Bc!), (:A_mul_Bt, :A_mul_Bt!))
@eval begin
function ($f){TA,TB}(A::StridedVecOrMat{TA}, B::AbstractTriangular{TB})
TAB = typeof(zero(TA)*zero(TB) + zero(TA)*zero(TB))
function ($f)(A::$mat, B::AbstractTriangular)
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
($g)(copy_oftype(A, TAB), convert(AbstractArray{TAB}, B))
end
end
end
### Right division with triangle to the right hence lhs cannot be transposed. No quotients.
for (f, g) in ((:/, :A_rdiv_B!), (:A_rdiv_Bc, :A_rdiv_Bc!), (:A_rdiv_Bt, :A_rdiv_Bt!))
@eval begin
function ($f){TA,TB,S}(A::StridedVecOrMat{TA}, B::Union{UnitUpperTriangular{TB,S},UnitLowerTriangular{TB,S}})
TAB = typeof(zero(TA)*zero(TB) + zero(TA)*zero(TB))
function ($f)(A::$mat, B::Union{UnitUpperTriangular,UnitLowerTriangular})
TAB = typeof(zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))
($g)(copy_oftype(A, TAB), convert(AbstractArray{TAB}, B))
end
end
end

### Right division with triangle to the right hence lhs cannot be transposed. Quotients.
for (f, g) in ((:/, :A_rdiv_B!), (:A_rdiv_Bc, :A_rdiv_Bc!), (:A_rdiv_Bt, :A_rdiv_Bt!))
@eval begin
function ($f){TA,TB,S}(A::StridedVecOrMat{TA}, B::Union{UpperTriangular{TB,S},LowerTriangular{TB,S}})
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
function ($f)(A::$mat, B::Union{UpperTriangular,LowerTriangular})
TAB = typeof((zero(eltype(A))*zero(eltype(B)) + zero(eltype(A))*zero(eltype(B)))/one(eltype(A)))
($g)(copy_oftype(A, TAB), convert(AbstractArray{TAB}, B))
end
end
end
end
### Fallbacks brought in from linalg/bidiag.jl while fixing #14506.
# Eventually the above promotion methods should be generalized as
# was done for bidiagonal matrices in #14506.
4 changes: 2 additions & 2 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
@@ -1525,7 +1525,7 @@ for isunittri in (true, false), islowertri in (true, false)
(true, :(Ac_ldiv_B), :(Ac_ldiv_B!)) )

# broad method where elements are Numbers
@eval function ($func){TA<:Number,Tb<:Number,S}(A::$tritype{TA,S}, b::SparseVector{Tb})
@eval function ($func){TA<:Number,Tb<:Number,S<:AbstractMatrix}(A::$tritype{TA,S}, b::SparseVector{Tb})
TAb = $(isunittri ?
:(typeof(zero(TA)*zero(Tb) + zero(TA)*zero(Tb))) :
:(typeof((zero(TA)*zero(Tb) + zero(TA)*zero(Tb))/one(TA))) )
@@ -1551,7 +1551,7 @@ for isunittri in (true, false), islowertri in (true, false)
end

# fallback where elements are not Numbers
@eval ($func){TA,Tb,S}(A::$tritype{TA,S}, b::SparseVector{Tb}) = ($ipfunc)(A, copy(b))
@eval ($func)(A::$tritype, b::SparseVector) = ($ipfunc)(A, copy(b))
end

# build in-place left-division operations
3 changes: 3 additions & 0 deletions test/linalg/triangular.jl
Original file line number Diff line number Diff line change
@@ -485,3 +485,6 @@ let
@test_throws DimensionMismatch A_rdiv_Bt!(A, UnitLowerTriangular(B))
@test_throws DimensionMismatch A_rdiv_Bt!(A, UnitUpperTriangular(B))
end

# Issue 16196
@test UpperTriangular(eye(3)) \ sub(ones(3), [1,2,3]) == ones(3)

0 comments on commit 69e6e28

Please sign in to comment.