Current rules for sqrt produce NaN for zero primal and (co)tangents #576
So the proposal is to always treat a zero tangent (or cotangent) as a strong zero. I think that makes sense. The rules for `^` show the same behavior at a zero base:
```julia
julia> ChainRules.frule((ChainRules.ZeroTangent(), 0.01, 0.0), ^, 0.0, 0.5)
(0.0, Inf)

julia> ChainRules.frule((ChainRules.ZeroTangent(), 0.0, 0.0), ^, 0.0, 0.5)
(0.0, NaN)

julia> ChainRules.rrule(^, 0.0, 0.5)[2](0.01)[2]
Inf

julia> ChainRules.rrule(^, 0.0, 0.5)[2](0.0)[2]
NaN
```

Treating zero tangents as strong zeros could be done globally in …

Edit: As David points out in a comment on JuliaDiff/ForwardDiff.jl#547, this is very close to reinventing ForwardDiff's NaN-safe mode. That mode had some speed penalty; I also see more branches than seem necessary.

Edit': This also affects anything using …
Since several users have noticed this issue in multiple ADs over the last few months, it would be good to push a fix soon. Deciding on an equivalent of a NaN-safe mode for ChainRules may take more discussion (and benchmarking), but I think we agree that we need something like this specifically for …
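A minimal sketch of what "treat a zero tangent as a strong zero" means for the `^` example. The helper `strong_mul` and the functions `pow_dx`, `naive_pow_tangent` and `safe_pow_tangent` are hypothetical names for illustration, not ChainRules API. The sketch perturbs only the base and ignores the exponent's tangent.

```julia
# Hypothetical helper (not part of ChainRules): multiply a tangent t by a
# partial derivative d, treating a zero tangent as a "strong zero" so that
# 0 * Inf and 0 * NaN give 0 instead of NaN.
strong_mul(t::Real, d::Real) = iszero(t) ? zero(promote_type(typeof(t), typeof(d))) : t * d

# Partial derivative of x^p with respect to the base x.
# At x == 0 with p < 1 this is Inf.
pow_dx(x, p) = p * x^(p - 1)

# Forward-mode tangent of x^p when only the base x is perturbed by dx.
naive_pow_tangent(x, p, dx) = dx * pow_dx(x, p)            # 0.0 * Inf -> NaN
safe_pow_tangent(x, p, dx) = strong_mul(dx, pow_dx(x, p))   # zero tangent -> 0.0
```

Applying a branch like this to every tangent product is close in spirit to ForwardDiff's NaN-safe mode discussed above, and that per-product branch is where a speed penalty would come from.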
At a zero primal, the NaN only appears when the (co)tangent is also 0; a nonzero (co)tangent gives `Inf`.
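Why the NaN appears can be sketched with a naive forward rule that multiplies the tangent by the partial derivative `1 / (2 * sqrt(x))`. This assumes the generated rule has roughly that shape; `naive_sqrt_frule` is a hypothetical name, not the actual ChainRules definition.

```julia
# Naive forward rule for sqrt: returns (primal, tangent), where the tangent
# is dx times the partial derivative 1 / (2 * sqrt(x)).
# At x == 0 the partial is inv(0.0) == Inf, so:
#   dx == 0.0  ->  0.0 * Inf == NaN
#   dx != 0.0  ->  dx * Inf  == ±Inf
naive_sqrt_frule(x, dx) = (sqrt(x), dx * inv(2 * sqrt(x)))
```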
I suggest we adopt the convention that the resulting (co)tangent should also be 0 in this case. This is supported by finite differences.
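A small forward-difference check, assuming a one-sided difference (a central difference would evaluate `sqrt` at a negative argument and throw a `DomainError`). `fd_directional` is a hypothetical helper. With a zero tangent the estimate is exactly 0 for every step size, while a nonzero tangent gives an estimate that grows without bound as `h` shrinks, matching `Inf`.

```julia
# One-sided (forward) finite-difference estimate of the directional
# derivative of f at x in direction dx.
fd_directional(f, x, dx; h = 1e-8) = (f(x + h * dx) - f(x)) / h

for h in (1e-2, 1e-4, 1e-8)
    # Zero tangent: x + h * 0.0 == x, so the difference is exactly 0.
    # Nonzero tangent at x == 0: the estimate is sqrt(dx / h), which blows up.
    println("h = $h: zero tangent -> ", fd_directional(sqrt, 0.0, 0.0; h = h),
            ", tangent 0.01 -> ", fd_directional(sqrt, 0.0, 0.01; h = h))
end
```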
So instead of using `@scalar_rule`, we would explicitly define the `frule` and `rrule`.
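A dependency-free sketch of what explicit strong-zero rules for `sqrt` could look like. The names `sqrt_frule_strongzero` and `sqrt_rrule_strongzero` are hypothetical; in ChainRules these would be `frule`/`rrule` methods, and the real pullback would also return a tangent for `sqrt` itself, which this sketch omits.

```julia
# Forward mode: returns (primal, tangent), mirroring the shape of frule.
function sqrt_frule_strongzero(x::Real, dx::Real)
    y = sqrt(x)
    # At x == 0 the partial 1 / (2y) is Inf and 0 * Inf would be NaN,
    # so a zero tangent is short-circuited to zero.
    dy = iszero(dx) ? zero(y) : dx / (2 * y)
    return y, dy
end

# Reverse mode: returns (primal, pullback), mirroring the shape of rrule.
# This pullback returns only the cotangent for x.
function sqrt_rrule_strongzero(x::Real)
    y = sqrt(x)
    sqrt_pullback(ybar) = iszero(ybar) ? zero(y) : ybar / (2 * y)
    return y, sqrt_pullback
end
```

For `x > 0` the branch does not change the result, since `0 / (2y)` is already 0; it only matters at `x == 0`, at the cost of one branch per rule rather than one per multiplication.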