Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add differential_type, to test for non-differentiability #528

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

mcabbott
Copy link
Member

@mcabbott mcabbott commented Jan 11, 2022

This adds a function to check whether a given type is non-differentiable. The purpose is to let you test whether to take the trivial path for some rule.

It goes by whether ProjectTo(x) infers to be a trivial projector. This means it should always be costless, rather than iterating through x = Any[true, false] it will give up and return false.

Possibly this should return something whose type is the indication, like Val(true)? Or return one of AbstractZero / Number / Any, or some trait struct IsNonDiff, seems a bit heavyweight...

@codecov-commenter

This comment has been minimized.

@nickrobinson251
Copy link
Contributor

nickrobinson251 commented Jan 11, 2022

🚲 Naming-wise, i wonder if "differentiable" should be reserved for functions (/callables), like in @non_differentiable f(x, y) @non_differentiable T(x)? Partly because of the existence of the @non_differentiable macro; on naming alone, i'd guessed this was related to @non_differentiable (not so much ProjectTo), i.e. is_non_differentiable(f, args...) if and only if we had @non_differentiable f(args...).

I think "perturb[able]" is a word that's been used in this package when talking about types (e.g. in CRTestUtils here and here). So that could be an option if we wanted to avoid "differentiable" for types.

(That said, a quick search suggests that Switft for TF thought "differentiable types" was a good name for "types that can be used as arguments and results of differentiable functions" https://github.com/tensorflow/swift/blob/f0d6c74ef5d016046afc1eac0b07a2f6b74b8fdf/docs/DifferentiableTypes.md)

Also, i wonder if we want to flip the sign (for ease of understanding) e.g. is_differentiable not is_non_differentiable (double-negatives like "is_non_diff(T) == false so T is not non-differentiable" can be hard to follow sometimes)?

@mcabbott
Copy link
Member Author

I don't love the name.

One reason for the apparent double negative is that the present implementation fails to false. It answers the question "are we certain we can take the trivial path here?"

Comment on lines 172 to 174
# Bool
ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::AbstractZero) above
is_non_differentiable(::Type{Bool}) = true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason this is seperated from the others?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Long long ago that's where the acorn happened to land in the soil...

is_non_differentiable(::Type{<:AbstractArray{T}}) where {T} = is_non_differentiable(T)

function is_non_differentiable(::Type{T}) where {T} # fallback
PT = Base._return_type(ProjectTo, Tuple{T}) # might be Union{} if unstable
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kinda want to run this the other way around.
Can we make it so that the is_non_differentiable is the canonical one and ProjectTo falls back to that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would make some sense. It's a slightly bigger API change I guess, as this is then what you ought to define (although I imagine 3rd-party non-diff types are pretty rare).

What this way around does do is re-use the logic for e.g. deciding whether all elements of a tuple are non-diff, hence the whole thing is. ProjectTo(::Tuple) already has to have code for recursing into things. If we reverse it, then is_non_differentiable will also want to have such code.

"""
is_non_differentiable(x) == is_non_differentiable(typeof(x))

Returns `true` if `x` is known from its type not to have derivatives, else `false`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Returns `true` if `x` is known from its type not to have derivatives, else `false`.
Returns `true` if `x` is known from its type that has no tangent space, else `false`.

@mcabbott
Copy link
Member Author

mcabbott commented Jul 8, 2022

Latest commit changes this to return a type, not a function. This means you test differential_type(x) <: AbstractZero to see whether x is known to have no derivative.

In general it returns the T in ProjectTo{T}. I'm not certain that has any other uses. But at least it's more obviously something read out from the projection machinery, rather than being an independent concept.

@mcabbott mcabbott changed the title RFC: add is_non_differentiable Add ~~is_non_differentiable~~ differential_type Jul 14, 2022
@mcabbott mcabbott changed the title Add ~~is_non_differentiable~~ differential_type Add differential_type, to test for non-differentiability Jul 14, 2022
@mcabbott mcabbott marked this pull request as ready for review July 14, 2022 20:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants