Current rules for sqrt produce NaN for zero primal and (co)tangents #576

Open
sethaxen opened this issue Jan 19, 2022 · 2 comments · May be fixed by #599
Labels
bug Something isn't working


@sethaxen
Member

This only happens when the (co)tangent is 0.

```julia
julia> using ChainRules

julia> ChainRules.frule((ChainRules.ZeroTangent(), 0.0), sqrt, 0.0)
(0.0, NaN)

julia> ChainRules.rrule(sqrt, 0.0)[2](0.0)
(ChainRulesCore.NoTangent(), NaN)
```
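The NaN is ordinary IEEE 754 arithmetic. Assuming the rule computes the partial derivative 1/(2√x) from the output and then multiplies it by the incoming (co)tangent (as @scalar_rule-style rules do), a zero primal makes the partial infinite, and 0 · Inf is NaN. A plain-Julia check, with no ChainRules involved:

```julia
x = 0.0
Ω = sqrt(x)            # primal output: 0.0
partial = inv(2Ω)      # d sqrt/dx = 1 / (2√x), which is Inf at x = 0
Δx = 0.0               # zero (co)tangent
ΔΩ = Δx * partial      # 0.0 * Inf is NaN under IEEE 754

println((Ω, partial, ΔΩ))  # prints (0.0, Inf, NaN)
```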

I suggest we adopt the convention that the (co)tangent produced in this case should also be 0. This is supported by finite differences:

```julia
julia> using FiniteDifferences

julia> jvp(central_fdm(5, 1), sqrt, (0.0, 0.0))
0.0

julia> j′vp(central_fdm(5, 1), x -> sqrt(clamp(x, 0, Inf)), 0.0, 0.0)
(0.0,)

julia> j′vp(central_fdm(5, 1), sqrt ∘ abs, 0.0, 0.0)
(0.0,)
```

So instead of using @scalar_rule we would explicitly define the frule and rrule.
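A minimal sketch of what such explicit rules could look like, written against ChainRulesCore's `frule`/`rrule` interface. `safe_sqrt` is a hypothetical stand-in for `sqrt`, used so the sketch does not clash with the rule ChainRules already defines; it is not necessarily what #599 implements. The only change from the usual derivative Δx / (2√x) is that a zero (co)tangent short-circuits to 0 before the division:

```julia
using ChainRulesCore

# Hypothetical stand-in for Base.sqrt, so this sketch does not overwrite
# the rule that ChainRules already defines for sqrt itself.
safe_sqrt(x::Real) = sqrt(x)

# Forward mode: a zero input tangent is treated as a strong zero.
function ChainRulesCore.frule((_, Δx), ::typeof(safe_sqrt), x::Real)
    Ω = safe_sqrt(x)
    ΔΩ = iszero(Δx) ? zero(Ω) : Δx / (2Ω)
    return Ω, ΔΩ
end

# Reverse mode: a zero output cotangent is treated as a strong zero.
function ChainRulesCore.rrule(::typeof(safe_sqrt), x::Real)
    Ω = safe_sqrt(x)
    function safe_sqrt_pullback(ȳ)
        Δ = unthunk(ȳ)
        x̄ = iszero(Δ) ? zero(Ω) : Δ / (2Ω)
        return NoTangent(), x̄
    end
    return Ω, safe_sqrt_pullback
end
```

With these, `frule((NoTangent(), 0.0), safe_sqrt, 0.0)` returns `(0.0, 0.0)`, while a nonzero tangent at `x = 0` still yields `Inf`, so only the 0 · Inf case changes.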

oxinabox added the bug label on Jan 19, 2022
@mcabbott
Member

mcabbott commented Jan 19, 2022

So the proposal is to always treat a zero tangent (or cotangent) as a strong zero, i.e. one that yields 0 even when multiplied by an infinite or NaN partial derivative. I think that makes sense.

The rule for x^p already treats a zero Δp as a strong zero, here:
https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/fastmath_able.jl#L171
The reasoning was that p is quite often really a constant, and its zero tangent would otherwise turn some correct infinite derivatives with respect to x into NaNs.
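A hypothetical plain-Julia sketch of that reasoning (not the code at the link above): the forward-mode tangent of y = x^p, using ∂y/∂x = p·x^(p-1) and ∂y/∂p = x^p·log(x). At x = 0 the p-term is 0 · 0 · (-Inf) = NaN, which poisons the otherwise correct infinite x-term unless a zero Δp is treated strongly:

```julia
# Hypothetical helpers for illustration only; not ChainRules' implementation.
# Naive version: always adds the p-term, even when Δp == 0.
function pow_tangent_naive(x, p, Δx, Δp)
    return Δx * p * x^(p - 1) + Δp * x^p * log(x)
end

# Strong-zero version: skip the p-term entirely when Δp is zero, so a
# constant exponent cannot inject 0 * (-Inf) = NaN at x = 0.
function pow_tangent_strong(x, p, Δx, Δp)
    dx_term = Δx * p * x^(p - 1)
    dp_term = iszero(Δp) ? zero(dx_term) : Δp * x^p * log(x)
    return dx_term + dp_term
end

println(pow_tangent_naive(0.0, 0.5, 0.01, 0.0))   # Inf + NaN
println(pow_tangent_strong(0.0, 0.5, 0.01, 0.0))  # Inf + 0.0
```

Note that `pow_tangent_strong(0.0, 0.5, 0.0, 0.0)` is still NaN, because only Δp is treated strongly here; that remaining Δx case is exactly what this issue is about.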

x^0.5 behaves just like sqrt(x) with respect to Δx:

```julia
julia> ChainRules.frule((ChainRules.ZeroTangent(), 0.01, 0.0), ^, 0.0, 0.5)
(0.0, Inf)

julia> ChainRules.frule((ChainRules.ZeroTangent(), 0.0, 0.0), ^, 0.0, 0.5)
(0.0, NaN)

julia> ChainRules.rrule(^, 0.0, 0.5)[2](0.01)[2]
Inf

julia> ChainRules.rrule(^, 0.0, 0.5)[2](0.0)[2]
NaN
```

Treating a zero tangent as strong could be done globally in @scalar_rule. That would mean one more ifelse in every rule; I wonder whether that's expensive.
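A sketch of what "one more ifelse" per rule might amount to, assuming the rule multiplies each partial by the tangent. `strong_mul` is a hypothetical helper, not something @scalar_rule currently generates. `ifelse` evaluates both arguments and then selects one, so it avoids a branch at the cost of always forming the product; whether that is cheap enough would still need benchmarking:

```julia
# Hypothetical helper: multiply a tangent by a partial derivative, but return
# an exact zero whenever the tangent itself is zero (a "strong zero").
function strong_mul(Δ, partial)
    y = Δ * partial                       # may be NaN when Δ == 0 and partial is Inf
    return ifelse(iszero(Δ), zero(y), y)  # branch-free select
end

partial = inv(2 * sqrt(0.0))       # sqrt's partial at x = 0: Inf
println(0.0 * partial)             # plain multiplication: NaN
println(strong_mul(0.0, partial))  # zero tangent treated strongly: 0.0
println(strong_mul(0.01, partial)) # nonzero tangent unchanged: Inf
```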

Edit: as David points out in JuliaDiff/ForwardDiff.jl#547 (comment), this is very close to re-inventing ForwardDiff's NaN-safe mode. That mode had some speed penalty; I also see more branches than seem necessary.

Edit': This also affects anything using derivatives_given_output, which is still marked experimental. That might be a reason to rethink it: perhaps the multiplication by the (co)tangent ought to happen inside the function, so that functions like sqrt can handle this case carefully.
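A schematic contrast of the two designs in plain Julia. Neither function has the actual `derivatives_given_output` signature; all names here are hypothetical. In the first style the rule only returns the partial derivative and generic code multiplies afterwards, so a zero tangent meets an infinite partial; in the second the rule also sees the tangent and can special-case it:

```julia
# Style 1 (schematic): the rule returns only sqrt's partial derivative,
# computed from the output Ω, and generic code multiplies by the tangent.
sqrt_partial_given_output(Ω) = inv(2Ω)   # Inf when Ω == 0
pushforward_via_partial(Ω, Δx) = Δx * sqrt_partial_given_output(Ω)

# Style 2 (hypothetical): the rule receives the tangent as well, so it can
# return an exact zero before ever forming 0 * Inf.
pushforward_tangent_aware(Ω, Δx) = iszero(Δx) ? zero(Ω) : Δx / (2Ω)

for (Ω, Δx) in ((0.0, 0.0), (0.0, 0.01), (2.0, 1.0))
    println((Ω, Δx), " => ", pushforward_via_partial(Ω, Δx),
            " vs ", pushforward_tangent_aware(Ω, Δx))
end
```

The two styles agree everywhere except the zero-tangent, zero-output case, which is the only place the extra information is needed.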

@sethaxen
Member Author

sethaxen commented Mar 5, 2022

Since several users have noticed this particular issue in multiple ADs over the last few months, it would be good to push a fix soon. While it might take more discussion (and benchmarking) to decide on an equivalent of a NaN-safe mode for ChainRules, I think we agree that we need something like this specifically for sqrt, and that we can do it now.
