Current rules for sqrt produce NaN for zero primal and (co)tangents #576

Open
@sethaxen

This only happens when the primal is 0 and the (co)tangent is also 0.

julia> using ChainRules

julia> ChainRules.frule((ChainRules.ZeroTangent(), 0.0), sqrt, 0.0)
(0.0, NaN)

julia> ChainRules.rrule(sqrt, 0.0)[2](0.0)
(ChainRulesCore.NoTangent(), NaN)
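
The NaN is what floating-point arithmetic gives for 0 * Inf. A minimal sketch of where it comes from, assuming the current rule scales the (co)tangent by the derivative 1/(2√x); the variable names are only illustrative:

```julia
x = 0.0
Δ = 0.0                      # zero (co)tangent
dsqrt = inv(2 * sqrt(x))     # derivative of sqrt at 0.0 evaluates to Inf
nan_product = Δ * dsqrt      # 0.0 * Inf is NaN under IEEE 754
```

With a nonzero (co)tangent the same product is ±Inf rather than NaN, which is why only the zero case is affected.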

I suggest we adopt the convention that the produced (co)tangent in this case should also be 0. This is supported by finite differences:

julia> using FiniteDifferences

julia> jvp(central_fdm(5, 1), sqrt, (0.0, 0.0))
0.0

julia> j′vp(central_fdm(5, 1), x -> sqrt(clamp(x, 0, Inf)), 0.0, 0.0)
(0.0,)

julia> j′vp(central_fdm(5, 1), sqrt ∘ abs, 0.0, 0.0)
(0.0,)

So instead of using @scalar_rule we would explicitly define the frule and rrule.
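
A rough sketch of what explicit rules could look like. To avoid overwriting the rules ChainRules already defines for sqrt, it uses a hypothetical wrapper `safe_sqrt` and a hypothetical helper `_scale_by_dsqrt`; neither name exists in ChainRules, and the actual fix might be structured differently:

```julia
using ChainRulesCore

# Hypothetical wrapper, used only so this sketch does not redefine rules for `sqrt`.
safe_sqrt(x::Real) = sqrt(x)

# Hypothetical helper: multiply the (co)tangent Δ by d(sqrt)/dx = 1/(2Ω),
# treating a zero (co)tangent as a strong zero so that 0 * Inf gives 0, not NaN.
_scale_by_dsqrt(Ω::Real, Δ::Real) = iszero(Δ) ? zero(Δ / (2Ω)) : Δ / (2Ω)
_scale_by_dsqrt(Ω::Real, Δ::AbstractZero) = Δ   # ZeroTangent/NoTangent pass through

function ChainRulesCore.frule((_, Δx), ::typeof(safe_sqrt), x::Real)
    Ω = sqrt(x)
    return Ω, _scale_by_dsqrt(Ω, Δx)
end

function ChainRulesCore.rrule(::typeof(safe_sqrt), x::Real)
    Ω = sqrt(x)
    safe_sqrt_pullback(ΔΩ) = (NoTangent(), _scale_by_dsqrt(Ω, ΔΩ))
    return Ω, safe_sqrt_pullback
end
```

Under these assumptions, `frule((NoTangent(), 0.0), safe_sqrt, 0.0)` and `rrule(safe_sqrt, 0.0)[2](0.0)` both produce a zero (co)tangent, while a nonzero (co)tangent at a primal of 0 still gives Inf.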

Labels: bug (Something isn't working)
