Skip to content

Commit

Permalink
Use only safe axis types for Broadcast.combine_axes
Browse files Browse the repository at this point in the history
#30074 used the wrong notion of consistency since `OneTo(1)` is
consistent (wrt broadcasting) with any range, but `OneTo` cannot
handle `Slice(-1:1)`.

(cherry picked from commit 1884cb4)
  • Loading branch information
timholy authored and KristofferC committed Dec 30, 2018
1 parent 5b7e8d9 commit 160054e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
11 changes: 4 additions & 7 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -436,19 +436,16 @@ end
_bcs1(a::Integer, b::Integer) = a == 1 ? b : (b == 1 ? a : (a == b ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size"))))
_bcs1(a::Integer, b) = a == 1 ? b : (first(b) == 1 && last(b) == a ? b : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
_bcs1(a, b::Integer) = _bcs1(b, a)
_bcs1(a, b) = _bcsm(b, a) ? _sametype(b, a) : (_bcsm(a, b) ? _sametype(a, b) : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
_bcs1(a, b) = _bcsm(b, a) ? axistype(b, a) : (_bcsm(a, b) ? axistype(a, b) : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
# _bcsm tests whether the second index is consistent with the first
_bcsm(a, b) = a == b || length(b) == 1
_bcsm(a, b::Number) = b == 1
_bcsm(a::Number, b::Number) = a == b || b == 1
# Ensure inferrability when dealing with axes of different AbstractUnitRange types
# (We may not want to define general promotion rules between, say, OneTo and Slice, but if
# we get here we know the axes are at least consistent)
_sametype(a::T, b::T) where T = a
_sametype(a::OneTo, b::OneTo) = OneTo{Int}(a)
_sametype(a::OneTo, b) = OneTo{Int}(a)
_sametype(a, b::OneTo) = OneTo{Int}(a)
_sametype(a, b) = UnitRange{Int}(a)
# we get here we know the axes are at least consistent for the purposes of broadcasting)
axistype(a::T, b::T) where T = a
axistype(a, b) = UnitRange{Int}(a)

## Check that all arguments are broadcast compatible with shape
# comparing one input against a shape
Expand Down
4 changes: 4 additions & 0 deletions test/offsetarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,10 @@ A = OffsetArray(view(rand(4,4), 1:4, 4:-1:1), (-3,5))
a = [1]
b = OffsetArray(a, (0,))
@test @inferred(a .+ b) == [2]
a = OffsetArray([1, -2, 1], (-2,))
@test a .* a' == OffsetArray([ 1 -2 1;
-2 4 -2;
1 -2 1], (-2,-2))

end # let

Expand Down

0 comments on commit 160054e

Please sign in to comment.