Skip to content

Commit

Permalink
Fix ReverseDiff bug
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Jul 11, 2023
1 parent 3987311 commit 127bc0a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FillArrays"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "1.4.0"
version = "1.4.1"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
28 changes: 7 additions & 21 deletions src/fillbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,20 +195,18 @@ function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::O
return _range_convert(AbstractVector{TT}, a)
end

# To fix AD issues with `broadcast(T, x)`
# Avoids type inference issues with x -> T(x)
struct Constructor{T} end

function (::Constructor{T})(x) where {T}
return T(x)
end

for op in (:+, :-)
@eval begin
function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractVector, b::ZerosVector)
broadcast_shape(axes(a), axes(b)) == axes(a) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $b to a Vector first."))
TT = typeof($op(zero(eltype(a)), zero(eltype(b))))
eltype(a) === TT ? a : broadcasted(Constructor{TT}(), a)
# Use `TT ∘ (+)` to fix AD issues with `broadcasted(TT, x)`
eltype(a) === TT ? a : broadcasted(TT (+), a)
end
function broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::ZerosVector, b::AbstractVector)
broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $a to a Vector first."))
TT = typeof($op(zero(eltype(a)), zero(eltype(b))))
$op === (+) && eltype(b) === TT ? b : broadcasted(TT ($op), b)
end

broadcasted(::DefaultArrayStyle{1}, ::typeof($op), a::AbstractFillVector, b::ZerosVector) =
Expand All @@ -219,18 +217,6 @@ for op in (:+, :-)
end
end

function broadcasted(::DefaultArrayStyle{1}, ::typeof(+), a::ZerosVector, b::AbstractVector)
broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $a to a Vector first."))
TT = typeof(zero(eltype(a)) + zero(eltype(b)))
eltype(b) === TT ? b : broadcasted(Constructor{TT}(), b)
end

function broadcasted(::DefaultArrayStyle{1}, ::typeof(-), a::ZerosVector, b::AbstractVector)
broadcast_shape(axes(a), axes(b)) == axes(b) || throw(ArgumentError("Cannot broadcast $a and $b. Convert $a to a Vector first."))
TT = typeof(zero(eltype(a)) - zero(eltype(b)))
broadcasted(TT (-), b)
end

# Need to prevent array-valued fills from broadcasting over entry
_broadcast_getindex_value(a::AbstractFill{<:Number}) = getindex_value(a)
_broadcast_getindex_value(a::AbstractFill) = Ref(getindex_value(a))
Expand Down

0 comments on commit 127bc0a

Please sign in to comment.