-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add differential_type
, to test for non-differentiability
#528
base: main
Are you sure you want to change the base?
Conversation
This comment has been minimized.
This comment has been minimized.
🚲 Naming-wise, i wonder if "differentiable" should be reserved for functions (/callables), like in I think "perturb[able]" is a word that's been used in this package when talking about types (e.g. in CRTestUtils here and here). So that could be an option if we wanted to avoid "differentiable" for types. (That said, a quick search suggests that Switft for TF thought "differentiable types" was a good name for "types that can be used as arguments and results of differentiable functions" https://github.com/tensorflow/swift/blob/f0d6c74ef5d016046afc1eac0b07a2f6b74b8fdf/docs/DifferentiableTypes.md) Also, i wonder if we want to flip the sign (for ease of understanding) e.g. |
I don't love the name. One reason for the apparent double negative is that the present implementation fails to |
src/projection.jl
Outdated
# Bool | ||
ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::AbstractZero) above | ||
is_non_differentiable(::Type{Bool}) = true |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason this is seperated from the others?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Long long ago that's where the acorn happened to land in the soil...
src/projection.jl
Outdated
is_non_differentiable(::Type{<:AbstractArray{T}}) where {T} = is_non_differentiable(T) | ||
|
||
function is_non_differentiable(::Type{T}) where {T} # fallback | ||
PT = Base._return_type(ProjectTo, Tuple{T}) # might be Union{} if unstable |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I kinda want to run this the other way around.
Can we make it so that the is_non_differentiable
is the canonical one and ProjectTo
falls back to that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That would make some sense. It's a slightly bigger API change I guess, as this is then what you ought to define (although I imagine 3rd-party non-diff types are pretty rare).
What this way around does do is re-use the logic for e.g. deciding whether all elements of a tuple are non-diff, hence the whole thing is. ProjectTo(::Tuple)
already has to have code for recursing into things. If we reverse it, then is_non_differentiable
will also want to have such code.
src/projection.jl
Outdated
""" | ||
is_non_differentiable(x) == is_non_differentiable(typeof(x)) | ||
|
||
Returns `true` if `x` is known from its type not to have derivatives, else `false`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Returns `true` if `x` is known from its type not to have derivatives, else `false`. | |
Returns `true` if `x` is known from its type that has no tangent space, else `false`. |
Latest commit changes this to return a type, not a function. This means you test In general it returns the |
is_non_differentiable
is_non_differentiable
~~ differential_type
is_non_differentiable
~~ differential_type
differential_type
, to test for non-differentiability
This adds a function to check whether a given type is non-differentiable. The purpose is to let you test whether to take the trivial path for some rule.
It goes by whether
ProjectTo(x)
infers to be a trivial projector. This means it should always be costless, rather than iterating throughx = Any[true, false]
it will give up and returnfalse
.Possibly this should return something whose type is the indication, like
Val(true)
? Or return one ofAbstractZero / Number / Any
, or some trait structIsNonDiff
, seems a bit heavyweight...