Possible to write rules for methods not collections of methods? #471

Open
willtebbutt opened this issue Sep 24, 2021 · 10 comments

Comments

@willtebbutt
Member

willtebbutt commented Sep 24, 2021

I'm not advocating for anything here. I'm just stating some facts, and wish to ascertain whether a particular design choice is technically feasible or not.

AD Effectively Operates on Individual Methods

First, recall that AD operates at the level of methods: in the absence of a generically-typed rule, AD does not know anything about the semantics of a function; it just sees a collection of bits of code.

For example, running Zygote on

my_sum(x::AbstractMatrix) = sum(x)

will produce something equivalent to

function Zygote._pullback(ctx::AContext, ::typeof(my_sum), x::AbstractMatrix)
    y, sum_pullback = Zygote._pullback(ctx, sum, x)
    function my_sum_pullback(dy)
        _, dx = sum_pullback(dy)
        return nothing, dx
    end
    return y, my_sum_pullback
end

If I now add another method

my_sum(x::Diagonal) = sum(diag(x))

Zygote will automatically specialise and produce something like

function Zygote._pullback(ctx::AContext, ::typeof(my_sum), x::Diagonal)
    tmp, diag_pullback = Zygote._pullback(ctx, diag, x)
    y, sum_pullback = Zygote._pullback(ctx, sum, tmp)
    function my_sum_pullback(dy)
        _, dtmp = sum_pullback(dy)
        _, dx = diag_pullback(dtmp)
        return nothing, dx
    end
    return y, my_sum_pullback
end

Zygote uses Julia's multiple dispatch system to achieve this behaviour via a single loosely-typed generated function, yet the code it produces depends on which method of the function the argument types hit, rather than simply on the argument types themselves. It is able to do this because generated functions have access to the IR associated with a particular method.
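For illustration, the per-method lookup can be reproduced with plain reflection: which returns the Method that dispatch would select for a given tuple of argument types. This is a minimal sketch of that lookup only, not of how Zygote is implemented internally, and it uses nothing beyond Base and the LinearAlgebra standard library.

```julia
using LinearAlgebra

my_sum(x::AbstractMatrix) = sum(x)
my_sum(x::Diagonal) = sum(diag(x))

# `which` returns the Method that dispatch selects for a tuple of argument types.
m_dense = which(my_sum, Tuple{Matrix{Float64}})
m_diag = which(my_sum, Tuple{Diagonal{Float64, Vector{Float64}}})

# Two distinct Methods, each with its own declared (possibly abstract) signature
# and its own lowered code, which is what per-method codegen transforms.
@show m_dense.sig
@show m_diag.sig
@show m_dense === m_diag
```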

ChainRules Operates More Generically

This is well understood, but worth pointing out. As implemented in all existing AD systems which support them, our rules apply to every method of a function whose signature is covered by the rrule's signature. So in the my_sum examples above, if I were to define

function ChainRulesCore.rrule(::typeof(my_sum), x::AbstractMatrix)
    # some code
end

it will apply to both methods, blocking codegen for the more specialised method.
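To see why the rule covers both methods, here is a sketch that uses a hypothetical stand-in, toy_rrule, in place of ChainRulesCore.rrule so that it runs without any packages. The assumption is that toy_rrule mirrors rrule's dispatch pattern: a catch-all fallback returning nothing, plus one method per rule. Dispatch on the rule looks only at argument types, so the AbstractMatrix rule also fires for a Diagonal argument.

```julia
using LinearAlgebra

my_sum(x::AbstractMatrix) = sum(x)
my_sum(x::Diagonal) = sum(diag(x))

# Hypothetical stand-in for ChainRulesCore.rrule: the catch-all returns nothing ("no rule").
toy_rrule(f, args...) = nothing

# A rule written with the my_sum(::AbstractMatrix) method in mind.
function toy_rrule(::typeof(my_sum), x::AbstractMatrix)
    y = my_sum(x)
    # d(sum)/dx is all ones, so the cotangent is dy filled into x's shape (always dense).
    my_sum_pullback(dy) = (nothing, fill(dy, size(x)))
    return y, my_sum_pullback
end

# Dispatch on the rule sees only argument types, so both calls hit the same rule,
# even though a Diagonal argument hits a different method of my_sum itself.
dense_result = toy_rrule(my_sum, randn(3, 3))
diag_result = toy_rrule(my_sum, Diagonal(randn(3)))
@show dense_result !== nothing
@show diag_result !== nothing
```

Here the rule's pullback builds a dense cotangent even for the Diagonal input, whereas Zygote's own codegen for my_sum(::Diagonal) would have differentiated through diag.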

Would it be possible to make rules apply to methods also?

To take Zygote as a concrete example, would it be technically feasible to make Zygote treat rules as equivalent to its own generated code, so that the rrule above is hit only when the my_sum(::AbstractMatrix) method is hit, while codegen proceeds as usual for the my_sum(::Diagonal) method?

Specifically:

using LinearAlgebra, Zygote

# hits rrule because my_sum(::AbstractMatrix) is the most specialised method applicable
# to Matrix{Float64}.
Zygote.pullback(my_sum, randn(5, 5))

# does not hit rrule because my_sum(::Diagonal) is the most specialised method applicable
# to Diagonal{Float64, Vector{Float64}}.
Zygote.pullback(my_sum, Diagonal(randn(5)))

@willtebbutt changed the title from "Possible to write rules for methods not functions?" to "Possible to write rules for methods not collections of methods?" on Sep 24, 2021
@oxinabox
Member

I do not believe it is possible, in the language that is Julia v1.7

@willtebbutt
Member Author

I do not believe it is possible, in the language that is Julia v1.7

Is this to say that it would have been possible in 1.6 but is no longer in 1.7, or that it was never possible?

@willtebbutt
Member Author

willtebbutt commented Sep 24, 2021

Also, why do you believe it not to be possible? Is it that you're not sure how it can be done, or do you have a particular reason to think that it cannot?

@sethaxen
Member

Related: JuliaDiff/ChainRules.jl#237

@oxinabox
Member

Is this to say that it would have been possible in 1.6 but is no longer in 1.7, or that it was never possible?

I mean to say that it has never been possible in any Julia version up to and including 1.7.
But I can't yet say how the language might change in 1.8.

Also, why do you believe it not to be possible? Is it that you're not sure how it can be done, or do you have a particular reason to think that it cannot?

It is that I don't see how it can be done. Thus the "I do not believe".

Possibly something can be done in Zygote by doing the source code transform at a different point in the compilation pipeline from where it is done now.
Right now it is too early: lowered IR comes before types are known, so you can't work out which method is being hit in the first place.
And I suspect typed IR is too late: by that stage it isn't working with Methods but with MethodInstances (specialized on concrete types).

@willtebbutt
Member Author

willtebbutt commented Sep 24, 2021

I mean to say that it has never been possible in any Julia version up to and including 1.7.
But I can't yet say how the language might change in 1.8.

Cool.

It is that I don't see how it can be done. Thus the "I do not believe".

Also cool.

Right now it is too early: lowered IR comes before types are known, so you can't work out which method is being hit in the first place.

Agreed regarding the IR, but Methods have the types in their signature, do they not?

I've just had a play around, and I think something like this might do it. Not sure if I'm really allowed to use which inside a generated function though... possibly I need some backedges.

using ChainRules

# A specialised method. Without this, a rule isn't hit. With this, a rule is hit.
# ChainRules.@non_differentiable sin(::Float64)

@generated function has_a_rule(f, args...)
    # Inside a generated function, `f` and `args` are bound to the argument *types*.
    T = Tuple{f, args...}

    # Find the primal method which would be hit by the types provided.
    primal_method = which(T)

    # Find the rrule method which would be hit by the types provided.
    rrule_method = which(ChainRules.rrule, T)

    # Obtain the signature of the rrule method without reference to the rrule function itself.
    # (Assumes the signature has no `where` clause, i.e. `sig` is not a UnionAll.)
    rrule_sig = Tuple{rrule_method.sig.parameters[2:end]...}

    # Find the method of the original function those arguments would hit.
    rrule_primal_method = try
        which(rrule_sig)
    catch
        nothing
    end

    # Check to see if they're the same method.
    use_chain_rule = rrule_primal_method !== nothing && primal_method === rrule_primal_method
    return use_chain_rule ? :true : :false
end

has_a_rule(sin, 5.0)

Edit: per the example above:

using LinearAlgebra

my_sum(x::AbstractMatrix) = sum(x)
my_sum(x::Diagonal) = sum(diag(x))
ChainRules.@non_differentiable my_sum(::AbstractMatrix)

has_a_rule(my_sum, randn(5, 5)) # returns true
has_a_rule(my_sum, Diagonal(randn(5))) # returns false
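One way to sidestep the question of calling which inside a generated function is to do the same comparison at runtime. The sketch below is a hypothetical helper, rule_matches_method (not part of ChainRules or Zygote), that checks whether the primal method targeted by a rule's signature is the method the arguments actually hit. It reuses a hypothetical toy_rrule stand-in for rrule so it runs without packages, pays for the which lookups on every call, and assumes rule signatures have no where clauses.

```julia
using LinearAlgebra

my_sum(x::AbstractMatrix) = sum(x)
my_sum(x::Diagonal) = sum(diag(x))

# Hypothetical stand-in for ChainRulesCore.rrule: the fallback returns nothing ("no rule").
toy_rrule(f, args...) = nothing
# A rule written for the my_sum(::AbstractMatrix) method (gradient of sum is all ones).
toy_rrule(::typeof(my_sum), x::AbstractMatrix) = (my_sum(x), dy -> (nothing, fill(dy, size(x))))

# Runtime (non-@generated) variant of has_a_rule: no backedges are needed because
# `which` is called on every invocation, at the cost of doing those lookups each time.
# Assumes rule signatures carry no `where` clause, so `.sig` is a plain DataType.
function rule_matches_method(rule, f, args...)
    argtypes = map(typeof, args)
    primal_method = which(f, Tuple{argtypes...})
    rule_method = which(rule, Tuple{typeof(f), argtypes...})

    # Hitting the catch-all fallback means no rule applies at all.
    rule_method === which(rule, Tuple{Any, Vararg{Any}}) && return false

    # Drop the rule's own function type to recover the primal signature it targets.
    target_sig = Tuple{rule_method.sig.parameters[2:end]...}
    target_method = try
        which(target_sig)
    catch
        return false  # e.g. no unique primal method matches the rule's signature
    end
    return target_method === primal_method
end

rule_matches_method(toy_rrule, my_sum, randn(3, 3))         # true
rule_matches_method(toy_rrule, my_sum, Diagonal(randn(3)))  # false
```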

@oxinabox
Member

Not sure if I'm really allowed to use which inside a generated function though... possibly I need some backedges.

Pretty sure you are not.
Zygote basically does, and does put in the backedges.
But note how very unreliable it is at picking up new and updated methods after calls.

@willtebbutt
Member Author

Not ideal, but at least there's a precedent 😂

@oxinabox
Member

In some sense what we have right now is kinda like this but the complement.

If one sticks to a policy that any time you implement a method you either implement an rrule or @opt_out of an rrule with the same signature,
then you basically get this.
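A sketch of that policy, again with a hypothetical toy_rrule standing in for rrule: the Diagonal method of my_sum is paired with an opt-out method that returns nothing for the same signature, which in this toy setting means "no rule, let AD differentiate my_sum(::Diagonal) itself". This is only loosely analogous to ChainRulesCore's @opt_out, whose exact mechanics are not reproduced here.

```julia
using LinearAlgebra

my_sum(x::AbstractMatrix) = sum(x)
my_sum(x::Diagonal) = sum(diag(x))

# Hypothetical stand-in for ChainRulesCore.rrule; `nothing` means "no rule, let AD recurse".
toy_rrule(f, args...) = nothing
toy_rrule(::typeof(my_sum), x::AbstractMatrix) = (my_sum(x), dy -> (nothing, fill(dy, size(x))))

# Policy: every primal method gets either a rule or an opt-out with the same signature.
# The opt-out is mimicked by a more specific method that returns nothing.
toy_rrule(::typeof(my_sum), ::Diagonal) = nothing

toy_rrule(my_sum, randn(3, 3)) === nothing         # false: the AbstractMatrix rule is used
toy_rrule(my_sum, Diagonal(randn(3))) === nothing  # true: AD would handle my_sum(::Diagonal)
```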

@willtebbutt
Copy link
Member Author

Note: @simeonschaub is essentially trying to do this in FluxML/Zygote.jl#909
