Merge #1089
1089: Fix broadcasting `-` with booleans  r=mcabbott a=mcabbott

Closes #1086

The issue is that `unbroadcast` collapses the gradient of a non-differentiable argument (here a Boolean array) to `nothing`, Zygote's marker for such things, and the rule for `-` then tried to negate that `nothing`. (This is why ChainRulesCore (CRC) defines its own zero types for this purpose instead of using `nothing`.) I presume the original reason for writing `-unbroadcast(y, Δ)` rather than `unbroadcast(y, -Δ)` is to save an allocation when, for example, `y` is a scalar.

I did not see any other cases where further operations are done after `unbroadcast`. 
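
The failure and the fix can be reproduced without Zygote. The sketch below is only an illustration: `ybar` is a hypothetical stand-in for the value that `unbroadcast(y, Δ)` returns when `y` is a Boolean array, while `_minus` is copied from the diff in this commit.

```julia
# Hypothetical stand-in for the result of `unbroadcast(y, Δ)` for a Bool array `y`.
ybar = nothing

# Unary minus has no method for `Nothing`, so the old rule's `-unbroadcast(y, Δ)` throws.
threw = try
    -ybar
    false
catch err
    err isa MethodError
end

# The two helper methods added in this commit: negate real gradients, pass `nothing` through.
_minus(Δ) = -Δ
_minus(::Nothing) = nothing

@assert threw
@assert _minus(ybar) === nothing
@assert _minus([1.0, 2.0]) == [-1.0, -2.0]
```

Dispatching on `::Nothing` keeps the negation after `unbroadcast`, so the presumed allocation saving for scalar `y` is preserved.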

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
bors[bot] and mcabbott authored Oct 2, 2021
2 parents ade9b1e + 4f7d5d1 commit 9fdc055
Showing 2 changed files with 8 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/lib/broadcast.jl
@@ -72,7 +72,9 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
   broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...)

 @adjoint broadcasted(::typeof(-), x::Numeric, y::Numeric) = x .- y,
-  Δ -> (nothing, unbroadcast(x, Δ), -unbroadcast(y, Δ))
+  Δ -> (nothing, unbroadcast(x, Δ), _minus(unbroadcast(y, Δ)))
+_minus(Δ) = -Δ
+_minus(::Nothing) = nothing

 @adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y,
   Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x)))
5 changes: 5 additions & 0 deletions test/features.jl
@@ -570,6 +570,11 @@ end
   @test gradient(x -> sum(_f.(x)), [1,2,3]) == ([0.5, 0.5, 0.5],)
   @test gradient(x -> sum(map(_f, x)), [1,2,3]) == ([0.5, 0.5, 0.5],)

+  # with Bool
+  @test gradient(x -> sum(1 .- (x .> 0)), randn(5)) == (nothing,)
+  @test gradient(x -> sum((y->1-y).(x .> 0)), randn(5)) == (nothing,)
+  @test gradient(x -> sum(x .- (x .> 0)), randn(5)) == ([1,1,1,1,1],)
+
   @test gradient(x -> sum(x ./ [1,2,4]), [1,2,pi]) == ([1.0, 0.5, 0.25],)
   @test gradient(x -> sum(map(/, x, [1,2,4])), [1,2,pi]) == ([1.0, 0.5, 0.25],)
