Possible to write rules for methods not collections of methods? #471
Comments
I do not believe it is possible, in the language that is Julia v1.7 |
Is this to say that it would have been possible in 1.6 but no longer is in 1.7, or that it was never possible? |
Also, why do you believe it not to be possible? Is it that you're not sure how it can be done, or do you have a particular reason to think that it cannot? |
Related: JuliaDiff/ChainRules.jl#237 |
I mean to say that it has never been possible in any Julia version less than or equal to 1.7.
It is that I don't see how it can be done. Thus the "I do not believe". Possibly something can be done in Zygote, by doing the source code transform at a different part of the compilation pipeline from where it is done now. |
Cool.
Also cool.
Agreed regarding the IR, but I've just had a play around, and I think something like this might do it. Not sure if I'm really allowed to do this, though:

    using ChainRules

    # A specialised method. Without this, a rule isn't hit. With this, a rule is hit.
    # ChainRules.@non_differentiable sin(::Float64)

    @generated function has_a_rule(f, args...)
        # Inside a generated function, `f` and `args` are bound to types, not values.
        T = Tuple{f, args...}

        # Find the primal method which would be hit by the types provided.
        primal_method = which(T)

        # Find the rrule that it would hit by the types provided.
        rrule_method = which(ChainRules.rrule, T)

        # Obtain the signature of the rrule method without reference to the rrule function itself.
        rrule_sig = Tuple{rrule_method.sig.parameters[2:end]...}

        # Find the method of the original function those arguments would hit.
        rule_primal_method = try
            which(rrule_sig)
        catch
            nothing
        end

        # Check to see if they're the same method.
        use_chain_rule = rule_primal_method !== nothing && primal_method === rule_primal_method
        return use_chain_rule ? :true : :false
    end

    has_a_rule(sin, 5.0)

edit: per the example above:

    using LinearAlgebra

    my_sum(x::AbstractMatrix) = sum(x)
    my_sum(x::Diagonal) = sum(diag(x))

    ChainRules.@non_differentiable my_sum(::AbstractMatrix)

    has_a_rule(my_sum, randn(5, 5)) # returns true
    has_a_rule(my_sum, Diagonal(randn(5))) # returns false |
Pretty sure you are not. |
Not ideal, but at least there's a precedent 😂 |
In some sense what we have right now is kinda like this but the complement. If one sticks to a policy of any time you implement a method you either implement a rrule or or |
Note: @simeonschaub is essentially trying to do this in FluxML/Zygote.jl#909 |
I'm not advocating for anything here. I'm just stating some facts, and wish to ascertain whether a particular design choice is technically feasible or not.
AD Effectively Operates on Individual Methods
First, recall that AD operates on the level of methods -- (in the absence of a generically-typed rule) AD does not know anything about the semantics of a function, it just sees a collection of bits of code.
For example, running Zygote on
will produce something equivalent to
If I now add another method
Zygote will automatically specialise and produce something like
While Zygote uses Julia's multiple dispatch system to achieve this behaviour via a single loosely-typed generated function, it produces different outputs depending upon the method of a function hit by the types of the arguments, rather than simply the type of the arguments. It's able to do this because generated functions have access to the IR associated with a particular method.
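As a rough illustration of the method-level view described above, the sketch below uses only Base reflection and the `my_sum` definitions from the comments. It does not involve Zygote at all; it only shows that dispatch resolves the same function to different `Method` objects depending on the argument types, which is the granularity a source-to-source tool sees. The variable names `dense_method` and `diag_method` are illustrative.

```julia
using LinearAlgebra

# Two methods of one function, as in the thread's my_sum example.
my_sum(x::AbstractMatrix) = sum(x)
my_sum(x::Diagonal) = sum(diag(x))

# `which` reports the method that dispatch selects for a tuple of argument types.
dense_method = which(my_sum, Tuple{Matrix{Float64}})
diag_method = which(my_sum, Tuple{Diagonal{Float64, Vector{Float64}}})

# The two calls resolve to distinct Method objects, each with its own body.
println(dense_method)
println(diag_method)
println(dense_method === diag_method)
```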
ChainRules Operates More Generically
This is well understood, but worth pointing out. As implemented in all existing AD systems which support them, our rules apply to all methods of a function to which the rrule applies. So in the my_sum examples above, if I were to define an rrule for my_sum(::AbstractMatrix), it would apply to both methods, blocking codegen for the more specialised method.
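The effect can be reproduced with plain dispatch, without any AD package. In the sketch below `toy_rule` is a hypothetical stand-in for `rrule` (it is not part of ChainRules), and `hits_rule` is an illustrative helper. Because `Diagonal <: AbstractMatrix`, a rule written against the `AbstractMatrix` signature is selected for both primal methods.

```julia
using LinearAlgebra

my_sum(x::AbstractMatrix) = sum(x)
my_sum(x::Diagonal) = sum(diag(x))

# Hypothetical stand-in for `rrule`: returns `nothing` unless a rule is defined.
toy_rule(f, args...) = nothing

# A rule written against the AbstractMatrix signature: primal value plus a pullback.
function toy_rule(::typeof(my_sum), x::AbstractMatrix)
    pullback(Δ) = fill(Δ, size(x))
    return my_sum(x), pullback
end

# Dispatch selects the rule for dense and Diagonal arguments alike.
hits_rule(x) = toy_rule(my_sum, x) !== nothing
println(hits_rule(randn(3, 3)))
println(hits_rule(Diagonal(randn(3))))
```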
Would it be possible to make rules apply to methods also?
To take Zygote as a concrete example, would it be technically feasible to make Zygote treat rules as being equivalent to its own codegen-ed code, so that if one defines the rrule above, it is only hit when the my_sum(::AbstractMatrix) method is hit, but leaves codegen to proceed as per usual for the my_sum(::Diagonal) method?
Specifically
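One way to make the desired behaviour concrete is a plain-function variant of the `has_a_rule` check explored in the comments. This is only a sketch under stated assumptions: `toy_rule` is a hypothetical stand-in for `rrule`, `rule_matches_method` is an illustrative name, and rule signatures are assumed to have no `where` clauses. It says nothing about whether an AD system can legally perform such a check inside a generated function.

```julia
using LinearAlgebra

my_sum(x::AbstractMatrix) = sum(x)
my_sum(x::Diagonal) = sum(diag(x))

# Hypothetical stand-in for `rrule`, with a catch-all fallback.
toy_rule(f, args...) = nothing
toy_rule(::typeof(my_sum), x::AbstractMatrix) = (my_sum(x), Δ -> fill(Δ, size(x)))

# Does the rule selected for these argument types target exactly the primal
# method that the same argument types would hit?
function rule_matches_method(f, args...)
    arg_types = map(typeof, args)
    primal = which(f, Tuple{arg_types...})
    rule = which(toy_rule, Tuple{typeof(f), arg_types...})
    # Drop the leading typeof(toy_rule) entry from the rule's signature.
    # Assumption: the signature is a plain tuple type (no `where` clause).
    params = rule.sig.parameters[2:end]
    # The catch-all fallback is not written for `f`, so it never counts as a rule.
    params[1] === typeof(f) || return false
    # Find the primal method that the rule's own signature would hit.
    target = try
        which(f, Tuple{params[2:end]...})
    catch
        nothing
    end
    return target === primal
end

println(rule_matches_method(my_sum, randn(3, 3)))
println(rule_matches_method(my_sum, Diagonal(randn(3))))
```

Under this check the rule counts only for the dense call, which is the behaviour the question asks about: the `Diagonal` method would be left to ordinary codegen.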