Skip to content

Commit

Permalink
Broadcast binary ops involving strided triangular (#55798)
Browse files Browse the repository at this point in the history
Currently, we evaluate expressions like `(A::UpperTriangular) +
(B::UpperTriangular)` using broadcasting if both `A` and `B` have
strided parents, and forward the summation to the parents otherwise.
This PR changes this to use broadcasting if either of the two has a
strided parent. This avoids accessing the parent corresponding to the
structural zero elements, as the index might not be initialized.

Fixes https://github.com/JuliaLang/julia/issues/55590

This isn't a general fix, as we still sum the parents if neither is
strided. However, it will address common cases.

This also improves performance, as we only need to loop over one half:
```julia
julia> using LinearAlgebra

julia> U = UpperTriangular(zeros(100,100));

julia> B = Bidiagonal(zeros(100), zeros(99), :U);

julia> @Btime $U + $B;
  35.530 μs (4 allocations: 78.22 KiB) # nightly
  13.441 μs (4 allocations: 78.22 KiB) # This PR
```
  • Loading branch information
jishnub authored Sep 19, 2024
1 parent bf0c690 commit 1aff771
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 30 deletions.
8 changes: 4 additions & 4 deletions src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -687,10 +687,10 @@ for f in (:+, :-)
@eval begin
$f(A::Hermitian, B::Symmetric{<:Real}) = $f(A, Hermitian(parent(B), sym_uplo(B.uplo)))
$f(A::Symmetric{<:Real}, B::Hermitian) = $f(Hermitian(parent(A), sym_uplo(A.uplo)), B)
$f(A::SymTridiagonal, B::Symmetric) = Symmetric($f(A, B.data), sym_uplo(B.uplo))
$f(A::Symmetric, B::SymTridiagonal) = Symmetric($f(A.data, B), sym_uplo(A.uplo))
$f(A::SymTridiagonal{<:Real}, B::Hermitian) = Hermitian($f(A, B.data), sym_uplo(B.uplo))
$f(A::Hermitian, B::SymTridiagonal{<:Real}) = Hermitian($f(A.data, B), sym_uplo(A.uplo))
$f(A::SymTridiagonal, B::Symmetric) = $f(Symmetric(A, sym_uplo(B.uplo)), B)
$f(A::Symmetric, B::SymTridiagonal) = $f(A, Symmetric(B, sym_uplo(A.uplo)))
$f(A::SymTridiagonal{<:Real}, B::Hermitian) = $f(Hermitian(A, sym_uplo(B.uplo)), B)
$f(A::Hermitian, B::SymTridiagonal{<:Real}) = $f(A, Hermitian(B, sym_uplo(A.uplo)))
end
end

Expand Down
91 changes: 65 additions & 26 deletions src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -850,35 +850,74 @@ fillstored!(A::UpperTriangular, x) = (fillband!(A.data, x, 0, size(A,2)-1);
fillstored!(A::UnitUpperTriangular, x) = (fillband!(A.data, x, 1, size(A,2)-1); A)

# Binary operations
+(A::UpperTriangular, B::UpperTriangular) = UpperTriangular(A.data + B.data)
+(A::LowerTriangular, B::LowerTriangular) = LowerTriangular(A.data + B.data)
+(A::UpperTriangular, B::UnitUpperTriangular) = UpperTriangular(A.data + triu(B.data, 1) + I)
+(A::LowerTriangular, B::UnitLowerTriangular) = LowerTriangular(A.data + tril(B.data, -1) + I)
+(A::UnitUpperTriangular, B::UpperTriangular) = UpperTriangular(triu(A.data, 1) + B.data + I)
+(A::UnitLowerTriangular, B::LowerTriangular) = LowerTriangular(tril(A.data, -1) + B.data + I)
+(A::UnitUpperTriangular, B::UnitUpperTriangular) = UpperTriangular(triu(A.data, 1) + triu(B.data, 1) + 2I)
+(A::UnitLowerTriangular, B::UnitLowerTriangular) = LowerTriangular(tril(A.data, -1) + tril(B.data, -1) + 2I)
# use broadcasting if the parents are strided, where we loop only over the triangular part
function +(A::UpperTriangular, B::UpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
UpperTriangular(A.data + B.data)
end
function +(A::LowerTriangular, B::LowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
LowerTriangular(A.data + B.data)
end
function +(A::UpperTriangular, B::UnitUpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
UpperTriangular(A.data + triu(B.data, 1) + I)
end
function +(A::LowerTriangular, B::UnitLowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
LowerTriangular(A.data + tril(B.data, -1) + I)
end
function +(A::UnitUpperTriangular, B::UpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
UpperTriangular(triu(A.data, 1) + B.data + I)
end
function +(A::UnitLowerTriangular, B::LowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
LowerTriangular(tril(A.data, -1) + B.data + I)
end
function +(A::UnitUpperTriangular, B::UnitUpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
UpperTriangular(triu(A.data, 1) + triu(B.data, 1) + 2I)
end
function +(A::UnitLowerTriangular, B::UnitLowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
LowerTriangular(tril(A.data, -1) + tril(B.data, -1) + 2I)
end
+(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) + copyto!(similar(parent(B)), B)

-(A::UpperTriangular, B::UpperTriangular) = UpperTriangular(A.data - B.data)
-(A::LowerTriangular, B::LowerTriangular) = LowerTriangular(A.data - B.data)
-(A::UpperTriangular, B::UnitUpperTriangular) = UpperTriangular(A.data - triu(B.data, 1) - I)
-(A::LowerTriangular, B::UnitLowerTriangular) = LowerTriangular(A.data - tril(B.data, -1) - I)
-(A::UnitUpperTriangular, B::UpperTriangular) = UpperTriangular(triu(A.data, 1) - B.data + I)
-(A::UnitLowerTriangular, B::LowerTriangular) = LowerTriangular(tril(A.data, -1) - B.data + I)
-(A::UnitUpperTriangular, B::UnitUpperTriangular) = UpperTriangular(triu(A.data, 1) - triu(B.data, 1))
-(A::UnitLowerTriangular, B::UnitLowerTriangular) = LowerTriangular(tril(A.data, -1) - tril(B.data, -1))
-(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) - copyto!(similar(parent(B)), B)

# use broadcasting if the parents are strided, where we loop only over the triangular part
for op in (:+, :-)
for TM1 in (:LowerTriangular, :UnitLowerTriangular), TM2 in (:LowerTriangular, :UnitLowerTriangular)
@eval $op(A::$TM1{<:Any, <:StridedMaybeAdjOrTransMat}, B::$TM2{<:Any, <:StridedMaybeAdjOrTransMat}) = broadcast($op, A, B)
end
for TM1 in (:UpperTriangular, :UnitUpperTriangular), TM2 in (:UpperTriangular, :UnitUpperTriangular)
@eval $op(A::$TM1{<:Any, <:StridedMaybeAdjOrTransMat}, B::$TM2{<:Any, <:StridedMaybeAdjOrTransMat}) = broadcast($op, A, B)
end
function -(A::UpperTriangular, B::UpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
UpperTriangular(A.data - B.data)
end
function -(A::LowerTriangular, B::LowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
LowerTriangular(A.data - B.data)
end
function -(A::UpperTriangular, B::UnitUpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
UpperTriangular(A.data - triu(B.data, 1) - I)
end
function -(A::LowerTriangular, B::UnitLowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
LowerTriangular(A.data - tril(B.data, -1) - I)
end
function -(A::UnitUpperTriangular, B::UpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
UpperTriangular(triu(A.data, 1) - B.data + I)
end
function -(A::UnitLowerTriangular, B::LowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
LowerTriangular(tril(A.data, -1) - B.data + I)
end
function -(A::UnitUpperTriangular, B::UnitUpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
UpperTriangular(triu(A.data, 1) - triu(B.data, 1))
end
function -(A::UnitLowerTriangular, B::UnitLowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
LowerTriangular(tril(A.data, -1) - tril(B.data, -1))
end
-(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) - copyto!(similar(parent(B)), B)

function kron(A::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, B::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat})
C = UpperTriangular(Matrix{promote_op(*, eltype(A), eltype(B))}(undef, _kronsize(A, B)))
Expand Down
25 changes: 25 additions & 0 deletions test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1135,4 +1135,29 @@ end
end
end

@testset "partly iniitalized matrices" begin
a = Matrix{BigFloat}(undef, 2,2)
a[1] = 1; a[3] = 1; a[4] = 1
h = Hermitian(a)
s = Symmetric(a)
d = Diagonal([1,1])
symT = SymTridiagonal([1 1;1 1])
@test h+d == Array(h) + Array(d)
@test h+symT == Array(h) + Array(symT)
@test s+d == Array(s) + Array(d)
@test s+symT == Array(s) + Array(symT)
@test h-d == Array(h) - Array(d)
@test h-symT == Array(h) - Array(symT)
@test s-d == Array(s) - Array(d)
@test s-symT == Array(s) - Array(symT)
@test d+h == Array(d) + Array(h)
@test symT+h == Array(symT) + Array(h)
@test d+s == Array(d) + Array(s)
@test symT+s == Array(symT) + Array(s)
@test d-h == Array(d) - Array(h)
@test symT-h == Array(symT) - Array(h)
@test d-s == Array(d) - Array(s)
@test symT-s == Array(symT) - Array(s)
end

end # module TestSymmetric

0 comments on commit 1aff771

Please sign in to comment.