Skip to content

Commit

Permalink
fix #31674, error when storing nonzeros into structural zeros with .= (
Browse files Browse the repository at this point in the history
…#31678)

Previously, broadcasted assignment (`.=`) would happily ignore all nonstructured portions of the destination, regardless of whether the broadcasted expression would actually evaluate to zero or not. This changes these in-place methods to use the same infrastructure that out-of-place broadcast uses to determine the result type. If we are unsure of the structural properties of the output, we fall back to the generic implementation, which will attempt to store into every single location of the destination -- including those structural zeros. Thus we now error in cases where we generate nonzeros in those locations.

(cherry picked from commit 6bd3967)
  • Loading branch information
mbauman authored and KristofferC committed May 20, 2019
1 parent f254c2e commit 861f7a2
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
10 changes: 9 additions & 1 deletion stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType})
end

function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
Expand All @@ -111,6 +112,7 @@ function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle})
end

function copyto!(dest::Bidiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
Expand All @@ -129,18 +131,22 @@ function copyto!(dest::Bidiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
end

function copyto!(dest::SymTridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
dest.dv[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i))
end
for i = 1:size(dest, 1)-1
dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1))
v = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1))
v == Broadcast._broadcast_getindex(bc, CartesianIndex(i+1, i)) || throw(ArgumentError("broadcasted assignment breaks symmetry between locations ($i, $(i+1)) and ($(i+1), $i)"))
dest.ev[i] = v
end
return dest
end

function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
Expand All @@ -154,6 +160,7 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle})
end

function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for j in axs[2]
Expand All @@ -165,6 +172,7 @@ function copyto!(dest::LowerTriangular, bc::Broadcasted{<:StructuredMatrixStyle}
end

function copyto!(dest::UpperTriangular, bc::Broadcasted{<:StructuredMatrixStyle})
!isstructurepreserving(bc) && !fzeropreserving(bc) && return copyto!(dest, convert(Broadcasted{Nothing}, bc))
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for j in axs[2]
Expand Down
29 changes: 26 additions & 3 deletions stdlib/LinearAlgebra/test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,37 @@ end
A = rand(N, N)
sA = A + copy(A')
D = Diagonal(rand(N))
B = Bidiagonal(rand(N), rand(N - 1), :U)
Bu = Bidiagonal(rand(N), rand(N - 1), :U)
Bl = Bidiagonal(rand(N), rand(N - 1), :L)
T = Tridiagonal(rand(N - 1), rand(N), rand(N - 1))
= LowerTriangular(rand(N,N))
= UpperTriangular(rand(N,N))

@test broadcast!(sin, copy(D), D) == Diagonal(sin.(D))
@test broadcast!(sin, copy(B), B) == Bidiagonal(sin.(B), :U)
@test broadcast!(sin, copy(Bu), Bu) == Bidiagonal(sin.(Bu), :U)
@test broadcast!(sin, copy(Bl), Bl) == Bidiagonal(sin.(Bl), :L)
@test broadcast!(sin, copy(T), T) == Tridiagonal(sin.(T))
@test broadcast!(sin, copy(◣), ◣) == LowerTriangular(sin.(◣))
@test broadcast!(sin, copy(◥), ◥) == UpperTriangular(sin.(◥))
@test broadcast!(*, copy(D), D, A) == Diagonal(broadcast(*, D, A))
@test broadcast!(*, copy(B), B, A) == Bidiagonal(broadcast(*, B, A), :U)
@test broadcast!(*, copy(Bu), Bu, A) == Bidiagonal(broadcast(*, Bu, A), :U)
@test broadcast!(*, copy(Bl), Bl, A) == Bidiagonal(broadcast(*, Bl, A), :L)
@test broadcast!(*, copy(T), T, A) == Tridiagonal(broadcast(*, T, A))
@test broadcast!(*, copy(◣), ◣, A) == LowerTriangular(broadcast(*, ◣, A))
@test broadcast!(*, copy(◥), ◥, A) == UpperTriangular(broadcast(*, ◥, A))

@test_throws ArgumentError broadcast!(cos, copy(D), D) == Diagonal(sin.(D))
@test_throws ArgumentError broadcast!(cos, copy(Bu), Bu) == Bidiagonal(sin.(Bu), :U)
@test_throws ArgumentError broadcast!(cos, copy(Bl), Bl) == Bidiagonal(sin.(Bl), :L)
@test_throws ArgumentError broadcast!(cos, copy(T), T) == Tridiagonal(sin.(T))
@test_throws ArgumentError broadcast!(cos, copy(◣), ◣) == LowerTriangular(sin.(◣))
@test_throws ArgumentError broadcast!(cos, copy(◥), ◥) == UpperTriangular(sin.(◥))
@test_throws ArgumentError broadcast!(+, copy(D), D, A) == Diagonal(broadcast(*, D, A))
@test_throws ArgumentError broadcast!(+, copy(Bu), Bu, A) == Bidiagonal(broadcast(*, Bu, A), :U)
@test_throws ArgumentError broadcast!(+, copy(Bl), Bl, A) == Bidiagonal(broadcast(*, Bl, A), :L)
@test_throws ArgumentError broadcast!(+, copy(T), T, A) == Tridiagonal(broadcast(*, T, A))
@test_throws ArgumentError broadcast!(+, copy(◣), ◣, A) == LowerTriangular(broadcast(*, ◣, A))
@test_throws ArgumentError broadcast!(+, copy(◥), ◥, A) == UpperTriangular(broadcast(*, ◥, A))
end

@testset "map[!] over combinations of structured matrices" begin
Expand Down

0 comments on commit 861f7a2

Please sign in to comment.