Failure to handle nothing output from Zygote.jl #662

Closed
MilesCranmer opened this issue Dec 8, 2024 · 8 comments · Fixed by #667

Comments

@MilesCranmer

MilesCranmer commented Dec 8, 2024

julia> using Zygote, DifferentiationInterface

julia> safe_log(x) = x > zero(x) ? log(x) : convert(typeof(x), NaN)
safe_log (generic function with 1 method)

julia> derivative(safe_log, AutoZygote(), 0.0)
ERROR: MethodError: no method matching iterate(::Nothing)
The function `iterate` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  iterate(::Combinatorics.Combinations)
   @ Combinatorics ~/.julia/packages/Combinatorics/Udg6X/src/combinations.jl:13
  iterate(::Combinatorics.Combinations, ::Any)
   @ Combinatorics ~/.julia/packages/Combinatorics/Udg6X/src/combinations.jl:13
  iterate(::Tables.DictRowTable)
   @ Tables ~/.julia/packages/Tables/8p03y/src/dicts.jl:122
  ...

Stacktrace:
 [1] dot(x::Float64, y::Nothing)
   @ LinearAlgebra ~/.julia/juliaup/julia-1.11.2+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:855
 [2] 
   @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/gjT8p/src/first_order/pushforward.jl:161
 [3] #6
   @ ~/.julia/packages/DifferentiationInterface/gjT8p/src/first_order/pushforward.jl:192 [inlined]
 [4] ntuple
   @ ./ntuple.jl:48 [inlined]
 [5] value_and_pushforward
   @ ~/.julia/packages/DifferentiationInterface/gjT8p/src/first_order/pushforward.jl:191 [inlined]
 [6] pushforward
   @ ~/.julia/packages/DifferentiationInterface/gjT8p/src/first_order/pushforward.jl:220 [inlined]
 [7] derivative
   @ ~/.julia/packages/DifferentiationInterface/gjT8p/src/first_order/derivative.jl:124 [inlined]
 [8] derivative(::typeof(safe_log), ::AutoZygote, ::Float64)
   @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/gjT8p/src/fallbacks/no_prep.jl:49
 [9] top-level scope
   @ REPL[13]:1
Some type information was truncated. Use `show(err)` to see complete types.

Zygote.jl returns nothing given an invalid input, so I think this case should be handled correctly. Returning NaN seems like one option?

For AutoForwardDiff, this gives us:

julia> derivative(safe_log, AutoForwardDiff(), 0.0)
0.0
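Calling the two backends directly, without DifferentiationInterface, shows where the results diverge. This is a minimal sketch that assumes Zygote.jl and ForwardDiff.jl are installed:

```julia
using Zygote, ForwardDiff

safe_log(x) = x > zero(x) ? log(x) : convert(typeof(x), NaN)

# Zygote: on the NaN branch the output does not depend on x,
# so the gradient slot comes back as `nothing` instead of a number.
zygote_grad = Zygote.gradient(safe_log, 0.0)

# ForwardDiff: x is a Dual number here, and convert(typeof(x), NaN)
# builds a Dual with value NaN and zero partials, so the derivative is 0.0.
fd_deriv = ForwardDiff.derivative(safe_log, 0.0)

println(zygote_grad)  # (nothing,)
println(fd_deriv)     # 0.0
```

The nothing from Zygote is the value that ends up as the second argument of dot in the stack trace above.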
@gdalle
Member

gdalle commented Dec 30, 2024

Hi @MilesCranmer, sorry for the delay. I am aware of the issue; it is more or less a duplicate of #604.
The trouble with replacing nothing with something like NaN is that it would make a missing rule look like an existing one. If the function happened to filter out NaNs later on, Zygote's failure to differentiate part of the function would be completely hidden. Therefore, I don't think we should get rid of these nothings: they denote a missing or faulty rule, which needs to be corrected rather than ignored.
However, I agree that we should add a better error message. Essentially, what happens here is that I call y, pb = Zygote.pullback(f, x), and the pullback pb then returns nothing instead of a number, which makes the downstream computation fail. Here I could throw a custom ZygoteNothingError, for example?
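A minimal sketch of what such a check could look like, assuming Zygote.jl is installed. ZygoteNothingError follows the name suggested above, and the helper checked_gradient is hypothetical; neither exists in DifferentiationInterface or Zygote.

```julia
using Zygote

# Hypothetical exception type named after the suggestion above; not part of DI.
struct ZygoteNothingError <: Exception
    f::Any
    x::Any
end

function Base.showerror(io::IO, e::ZygoteNothingError)
    print(io, "ZygoteNothingError: Zygote returned nothing as the derivative of ",
          e.f, " at x = ", e.x, "; this may indicate a missing or faulty rule.")
end

# Hypothetical helper: differentiate a scalar function with Zygote and
# fail loudly instead of passing `nothing` downstream.
function checked_gradient(f, x)
    g = only(Zygote.gradient(f, x))
    g === nothing && throw(ZygoteNothingError(f, x))
    return g
end

safe_log(x) = x > zero(x) ? log(x) : convert(typeof(x), NaN)

checked_gradient(safe_log, 2.0)  # 0.5
# checked_gradient(safe_log, 0.0) would throw ZygoteNothingError
```

Throwing at the point where the nothing first appears keeps the original diagnosis intact while replacing the opaque iterate/dot MethodError with a message that names the function and input.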

@MilesCranmer
Author

Yeah, that sounds good.

@MilesCranmer
Author

MilesCranmer commented Jan 1, 2025

Actually, maybe I could suggest another workaround – could we have an opt-in setting in the AutoZygote backend that does replace nothing with NaN or 0.0? For my use case (SymbolicRegression.jl), this is basically the only way I can use Zygote via the DifferentiationInterface API. My current workaround is to handle the nothings explicitly: https://github.com/SymbolicML/DynamicExpressions.jl/blob/15849184f38636631a7a5aee7a04c02bcb7dde33/ext/DynamicExpressionsZygoteExt.jl#L6-L30

@MilesCranmer
Author

MilesCranmer commented Jan 1, 2025

Actually, a 0.0 output would be preferable here – maybe the setting could look like AutoZygote(replace_nothing_with=(x, dx) -> dx) (the default, which just leaves nothing as-is), and it could also be set to

AutoZygote(replace_nothing_with=(x, dx) -> zero(x))

or

AutoZygote(replace_nothing_with=(x, dx) -> convert(typeof(x), NaN))

so that I am fully aware this is a potential footgun.
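To make the proposal concrete, here is a minimal sketch of the suggested semantics as a standalone wrapper around Zygote.gradient, assuming Zygote.jl is installed. The function zygote_derivative is hypothetical, and its keyword only mirrors the proposal; it is not an existing AutoZygote option.

```julia
using Zygote

# Hypothetical function sketching the proposed `replace_nothing_with` option;
# AutoZygote has no such keyword. The callback receives the input x and the
# raw Zygote result dx, and is only consulted when dx is `nothing`.
function zygote_derivative(f, x; replace_nothing_with = (x, dx) -> dx)
    dx = only(Zygote.gradient(f, x))
    return dx === nothing ? replace_nothing_with(x, dx) : dx
end

safe_log(x) = x > zero(x) ? log(x) : convert(typeof(x), NaN)

# Default: the `nothing` is passed through unchanged.
d_default = zygote_derivative(safe_log, 0.0)

# Opt-in: replace `nothing` with zero(x), i.e. 0.0 here.
d_zero = zygote_derivative(safe_log, 0.0; replace_nothing_with = (x, dx) -> zero(x))

# Opt-in: replace `nothing` with a NaN of the same type as x.
d_nan = zygote_derivative(safe_log, 0.0; replace_nothing_with = (x, dx) -> convert(typeof(x), NaN))
```

Because the replacement is an explicit callback, the default behavior stays unchanged and the nothing only disappears when the caller asks for it.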

@gdalle
Member

gdalle commented Jan 1, 2025

I get where you're coming from, but I'm very skeptical about this. A nothing with Zygote denotes a missing rule, and it's not DI's place to go around and tinker with the output of missing rules. Otherwise we'd have to offer the same for every backend: "your function doesn't support dual numbers, do you want me to output NaN instead?". The proper fix for a missing rule will always be in the AD backend itself, not in DI.

@MilesCranmer
Author

Wait, so does ForwardDiff output 0.0 for missing rules?

@gdalle
Member

gdalle commented Jan 1, 2025

No, that's my point. Missing rules throw errors in ForwardDiff, and in Enzyme too. I don't see why we should introduce a special shortcut to zero them out for Zygote only.
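For comparison, a minimal sketch of what an unsupported function looks like with ForwardDiff, assuming ForwardDiff.jl is installed. The made-up function float_only only accepts Float64, so it cannot receive ForwardDiff's Dual numbers, and differentiation fails with a MethodError instead of returning a number.

```julia
using ForwardDiff

# Made-up example: restricting the argument to Float64 means the
# function has no method for ForwardDiff's Dual numbers.
float_only(x::Float64) = 2x

# ForwardDiff calls float_only with a Dual, which has no matching method,
# so we capture the exception instead of a derivative.
outcome = try
    ForwardDiff.derivative(float_only, 1.0)
catch err
    err
end

println(outcome isa MethodError)  # true
```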

@MilesCranmer
Author

Oh I see. I thought ForwardDiff was outputting 0.0 for missing rules. I'm surprised that it covers so many functions!
