Fix ReverseDiff bug #278

Merged: 3 commits into JuliaArrays:master on Jul 17, 2023


devmotion
Contributor

@devmotion devmotion commented Jul 11, 2023

Unfortunately, the last commit in #273 (which fixed type inference issues) broke a ReverseDiff example (a slightly modified version of the MWE in #252). On the master branch:

julia> using ReverseDiff, FillArrays

julia> ReverseDiff.gradient(x -> sum(abs2.((zeros(5) .- Zeros{eltype(x)}(5)) ./ x)), rand(5))
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{ReverseDiff.var"#109#111"{Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#11#13", Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}, FillArrays.Constructor{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(abs2)}, Tuple{}, Val{(1, 2)}}, Float64}, Float64, 2})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number}
   @ Base char.jl:50
  ...

Stacktrace:
  [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{ReverseDiff.var"#109#111"{Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#11#13", Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}, FillArrays.Constructor{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(abs2)}, Tuple{}, Val{(1, 2)}}, Float64}, Float64, 2})
    @ Base ./number.jl:7
  [2] ReverseDiff.TrackedReal{Float64, Float64, Nothing}(value::ForwardDiff.Dual{ForwardDiff.Tag{ReverseDiff.var"#109#111"{Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#11#13", Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}, FillArrays.Constructor{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(abs2)}, Tuple{}, Val{(1, 2)}}, Float64}, Float64, 2})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/wIfrd/src/tracked.jl:56
  [3] (::FillArrays.Constructor{ReverseDiff.TrackedReal{Float64, Float64, Nothing}})(x::ForwardDiff.Dual{ForwardDiff.Tag{ReverseDiff.var"#109#111"{Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#11#13", Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}, FillArrays.Constructor{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(abs2)}, Tuple{}, Val{(1, 2)}}, Float64}, Float64, 2})
    @ FillArrays ~/.julia/packages/FillArrays/eFtCC/src/fillbroadcast.jl:203
  [4] #18
    @ ./broadcast.jl:396 [inlined]
  [5] #18
    @ ./broadcast.jl:394 [inlined]
  [6] #12
    @ ./broadcast.jl:342 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/ReverseDiff/wIfrd/src/derivatives/broadcast.jl:126 [inlined]
  [8] splatcall
    @ ~/.julia/packages/ReverseDiff/wIfrd/src/derivatives/broadcast.jl:111 [inlined]
  [9] (::ReverseDiff.var"#109#111"{Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#11#13", Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}, FillArrays.Constructor{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(abs2)}, Tuple{}, Val{(1, 2)}})(s::StaticArraysCore.SVector{2, ForwardDiff.Dual{ForwardDiff.Tag{ReverseDiff.var"#109#111"{Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#11#13", Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}, FillArrays.Constructor{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(abs2)}, Tuple{}, Val{(1, 2)}}, Float64}, Float64, 2}})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/wIfrd/src/derivatives/broadcast.jl:150
 [10] static_dual_eval
    @ ~/.julia/packages/ForwardDiff/vXysl/ext/ForwardDiffStaticArraysExt.jl:24 [inlined]
 [11] vector_mode_gradient!(result::DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}, f::Function, x::StaticArraysCore.SVector{2, Float64})
    @ ForwardDiffStaticArraysExt ~/.julia/packages/ForwardDiff/vXysl/ext/ForwardDiffStaticArraysExt.jl:64
 [12] gradient!(result::DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}, f::Function, x::StaticArraysCore.SVector{2, Float64})
    @ ForwardDiffStaticArraysExt ~/.julia/packages/ForwardDiff/vXysl/ext/ForwardDiffStaticArraysExt.jl:44
 [13] (::ReverseDiff.var"#df#110"{Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#11#13", Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}, FillArrays.Constructor{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(abs2)}, DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}, Tuple{}, Val{(1, 2)}})(::Float64, ::Vararg{Float64})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/wIfrd/src/derivatives/broadcast.jl:148
 [14] _broadcast_getindex_evalf
    @ ./broadcast.jl:683 [inlined]
 [15] _broadcast_getindex
    @ ./broadcast.jl:656 [inlined]
 [16] getindex
    @ ./broadcast.jl:610 [inlined]
 [17] copy
    @ ./broadcast.jl:912 [inlined]
 [18] materialize
    @ ./broadcast.jl:873 [inlined]
 [19] broadcast(::ReverseDiff.var"#df#110"{Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#11#13", Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}, FillArrays.Constructor{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(abs2)}, DiffResults.ImmutableDiffResult{1, Float64, Tuple{StaticArraysCore.SVector{2, Float64}}}, Tuple{}, Val{(1, 2)}}, ::Vector{Float64}, ::Vector{Float64})
    @ Base.Broadcast ./broadcast.jl:811
 [20] ∇broadcast(::Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#11#13", Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}, FillArrays.Constructor{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(abs2)}, ::Vector{Float64}, ::Vararg{Any})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/wIfrd/src/derivatives/broadcast.jl:154
 [21] copy(_bc::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle, Tuple{Base.OneTo{Int64}}, typeof(abs2), Tuple{Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, FillArrays.Constructor{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, Tuple{Vector{Float64}}}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/wIfrd/src/derivatives/broadcast.jl:94
 [22] materialize
    @ ./broadcast.jl:873 [inlined]
 [23] (::var"#5#6")(x::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}})
    @ Main ./REPL[5]:1
 [24] ReverseDiff.GradientTape(f::var"#5#6", input::Vector{Float64}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/wIfrd/src/api/tape.jl:199
 [25] gradient(f::Function, input::Vector{Float64}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/wIfrd/src/api/gradients.jl:22
 [26] gradient(f::Function, input::Vector{Float64})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/wIfrd/src/api/gradients.jl:22
 [27] top-level scope
    @ REPL[5]:1

Put simply, the general problem with broadcasting and AD is that most backends fall back to computing derivatives of the applied function with ForwardDiff, but here the output type is fixed to one that is incompatible with the Dual numbers used internally. This PR seems to fix the example above in an arguably simple way.
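
To illustrate the mechanism in isolation, here is a minimal sketch that needs no packages. It is only an analogy under stated assumptions: `MiniDual`, `pinned`, and `flexible` are hypothetical names, and `MiniDual` is a stand-in for a dual number rather than ForwardDiff's actual `Dual` type. It shows why a conversion pinned to `Float64` fails as soon as a dual-like number is pushed through it, while a conversion that follows its input type does not.

```julia
# Hypothetical stand-in for a dual number; this is NOT ForwardDiff.Dual.
# It only mimics a `Real` subtype that has no conversion to Float64.
struct MiniDual <: Real
    value::Float64
    partial::Float64
end

# Element constructor with a hard-coded output type, analogous in spirit to
# a constructor whose element type was fixed before the forward-mode pass.
pinned(x) = convert(Float64, x)

# Element constructor whose output type follows its input.
flexible(x) = convert(typeof(x), x)

# Plain floats work with both variants.
@assert pinned(1.0) === 1.0
@assert flexible(1.0) === 1.0

# The dual-like number passes through the flexible variant unchanged ...
d = MiniDual(1.0, 1.0)
@assert flexible(d) === d

# ... but the pinned variant throws, because no Float64(::MiniDual) method exists.
err = try
    pinned(d)
catch e
    e
end
@assert err isa MethodError
```

The failing call goes through `convert(Float64, x)`, which is the same kind of call shown in frame [1] of the stack trace above.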

In an ideal world, FillArrays would not have to care about these AD-specific details. AD packages should be improved and their broadcasting should be more stable. But I don't see this happening anytime soon, given the number of contributors to packages such as ReverseDiff or Tracker and the many ways such changes could break downstream code that is working fine right now. I'm not a maintainer of FillArrays, but since all these FillArrays-specific issues did not exist prior to FillArrays 1.0.1, it seems reasonable to me to apply these AD fixes directly in FillArrays.

Unfortunately, I don't see a good way of testing these ReverseDiff issues without adding AD tests or downstream tests on DistributionsAD.

I'll rerun the DistributionsAD tests with this PR in the DistributionsAD repo: TuringLang/DistributionsAD.jl#250

@codecov

codecov bot commented Jul 11, 2023

Codecov Report

Merging #278 (08f23a0) into master (8d4d190) will decrease coverage by 4.36%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #278      +/-   ##
==========================================
- Coverage   99.88%   95.52%   -4.36%     
==========================================
  Files           5        5              
  Lines         857      849       -8     
==========================================
- Hits          856      811      -45     
- Misses          1       38      +37     
Impacted Files          Coverage Δ
src/FillArrays.jl       96.51% <100.00%> (-3.22%) ⬇️
src/fillbroadcast.jl    97.65% <100.00%> (-2.35%) ⬇️

... and 3 files with indirect coverage changes


@devmotion
Contributor Author

The InfiniteLinearAlgebra downstream test failure seems unrelated? https://github.com/JuliaArrays/FillArrays.jl/actions/runs/5525483793/jobs/10079152899?pr=278

@jishnub
Member

jishnub commented Jul 12, 2023

Yes, that should be fixed by JuliaArrays/LazyArrays.jl#257

@devmotion
Contributor Author

What's your preference regarding tests? Should we add DistributionsAD (possibly only the ReverseDiff tests) to the downstream tests? Or would you be OK with adding the ReverseDiff examples in this PR and #273 to the FillArrays tests?

@devmotion devmotion closed this Jul 13, 2023
@devmotion devmotion reopened this Jul 13, 2023
@dlfivefifty dlfivefifty requested a review from jishnub July 14, 2023 09:08
@dlfivefifty
Member

@jishnub is there any issue with adding ReverseDiff as a test dependency? I checked and it doesn't depend on FillArrays.jl, so I think it's fine.

@jishnub
Member

jishnub commented Jul 14, 2023

There's perhaps no fundamental problem with ReverseDiff. I had refrained from adding heavy packages, but we may make an exception for this, given that it's fairly fundamental. However, are we considering the ReverseDiff test suite, or the DistributionsAD tests that use ReverseDiff?

It would be ideal if we could test a small subsection that depends on FillArrays, although I realize that this might be difficult.

@devmotion
Contributor Author

So far I've only added the two MWEs to the tests. Neither the ReverseDiff nor the DistributionsAD-ReverseDiff downstream tests are run (DistributionsAD also has tests for Zygote, Tracker, and ForwardDiff...). The main reason is that the MWEs take only a few seconds, whereas either downstream test suite would take around 45-60 minutes (and would also install more dependencies, I assume).

Breaking down the DistributionsAD tests further, not only by AD backend but also by use of FillArrays, seems challenging: FillArrays is used a lot in the multivariate normal distributions but also in some other distributions, sometimes only internally in Distributions.

@jishnub
Member

jishnub commented Jul 17, 2023

I think the present approach in this PR makes sense. Running the entire test suite of a large package might increase the maintenance burden of this package in case there are failures originating elsewhere.

@jishnub
Member

jishnub commented Jul 17, 2023

Could you also add some AD tests for +, along with the tests for - that are added in this PR?
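
A rough sketch of what such paired tests could look like, based on the MWE at the top of this PR. This is not necessarily the test code that was merged: the testset name and the fixed input `x` are assumptions made for illustration, and the checks are only expected to pass with a FillArrays version that includes this fix.

```julia
using Test, ReverseDiff, FillArrays

@testset "ReverseDiff with Zeros (sketch)" begin
    # Fixed, nonzero input so that the division is well defined.
    x = [1.0, 2.0, 3.0, 4.0, 5.0]

    # Subtraction, following the MWE in this PR.
    @test ReverseDiff.gradient(x -> sum(abs2.((zeros(5) .- Zeros{eltype(x)}(5)) ./ x)), x) == zeros(5)

    # Addition counterpart.
    @test ReverseDiff.gradient(x -> sum(abs2.((zeros(5) .+ Zeros{eltype(x)}(5)) ./ x)), x) == zeros(5)
end
```

Because the numerator is identically zero, the function is constant in `x`, so the expected gradient is a vector of zeros.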

@devmotion
Contributor Author

Done!

@dlfivefifty dlfivefifty merged commit be9386c into JuliaArrays:master Jul 17, 2023
21 of 22 checks passed
@devmotion devmotion deleted the dw/fix_broadcast_try2 branch July 17, 2023 13:35