Split reverse mode for Tapir #115

Closed
gdalle opened this issue Apr 5, 2024 · 7 comments

gdalle commented Apr 5, 2024

We've discussed this on Slack with @willtebbutt, but I wanted to confirm where we stand on split reverse mode, i.e. separating the forward sweep from the reverse sweep.
The idea is to be able to perform multiple reverse sweeps with different seeds after just one forward sweep. A typical example is computing a Jacobian (where there is one seed per basis vector of the output space).
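
To make the Jacobian use case concrete, here is a minimal sketch in plain Julia. It does not use Tapir: `value_and_pullback_demo` is a hypothetical hand-written stand-in for a split reverse-mode API whose pullback happens to be safe to call repeatedly, and `jacobian_from_pullback` is an illustrative name as well.

```julia
# Hypothetical stand-in for a split reverse-mode API (not Tapir):
# one forward sweep returns the output and a reusable pullback closure.
function value_and_pullback_demo(x::Vector{Float64})
    y = x .^ 2                      # forward sweep
    pullback(dy) = 2 .* x .* dy     # reverse sweep: J' * dy for elementwise square
    return y, pullback
end

# Build the Jacobian row by row: one forward sweep, one reverse sweep per seed.
function jacobian_from_pullback(x::Vector{Float64})
    y, pullback = value_and_pullback_demo(x)
    J = zeros(length(y), length(x))
    for i in eachindex(y)
        seed = zeros(length(y))
        seed[i] = 1.0               # i-th basis vector of the output space
        J[i, :] = pullback(seed)    # row i of the Jacobian
    end
    return J
end

J = jacobian_from_pullback([1.0, 2.0, 3.0])
```

The loop only works because this toy pullback does not mutate any shared state between calls, which is exactly the property in question below.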

My question is the following: is that currently possible with Tapir's rrule? IIUC, the answer is no for functions that mutate their argument (gdalle/DifferentiationInterface.jl#142), but what about simple allocating functions?

I took inspiration from https://github.com/withbayes/Tapir.jl/blob/f5e2b90cd17fd3127dd0fd8dfa617bc112275626/src/interface.jl#L9-L15
to try to write what I called value_and_pullback_split in DifferentiationInterface:

function value_and_pullback_split(f, x)
    rule = build_rrule(f, x)
    tf = zero_tangent(f)
    tx = zero_tangent(x)
    out, pb!! = rule(CoDual(f, tf), CoDual(x, tx))
    y = copy(primal(out))
    function pullback(dy)
        dy_righttype = convert(tangent_type(typeof(y)), copy(dy))
        ty = increment!!(tangent(out), dy_righttype)
        new_df, new_dx = pb!!(ty, tf, tx)
        return new_dx
    end
    return y, pullback
end

But the behavior of the resulting closure changes at each call.
For some functions it gives different results:

julia> y, pullback = value_and_pullback_split(copy, [1.0])
([1.0], pullback)

julia> pullback([1.0])
1-element Vector{Float64}:
 1.0

julia> pullback([1.0])
1-element Vector{Float64}:
 3.0

julia> pullback([1.0])
1-element Vector{Float64}:
 6.0

For others it downright errors:

julia> y, pullback = value_and_pullback_split(x -> x .^ 2, [1.0])
([1.0], pullback)

julia> pullback([1.0])
1-element Vector{Float64}:
 2.0

julia> pullback([1.0])
ERROR: BoundsError: attempt to access 1-element Vector{Vector{Float64}} at index [0]
Stacktrace:
      internal @ Unknown
 [4] Pullback
   @ ~/.julia/packages/Tapir/BqxEi/src/interpreter/s2s_reverse_mode_ad.jl:632 [inlined]
 [5] (::var"#pullback#10"{Vector{}, Tapir.Pullback{}, CoDual{}, Vector{}, NoTangent})(dy::Vector{Float64})
   @ Main ./REPL[16]:10
Use `err` to retrieve the full stack trace.
Some type information was truncated. Use `show(err)` to see complete types.

What should I copy to allow for independent pullback calls? Probably out, tf and tx?

gdalle commented Apr 5, 2024

Side note: it would be interesting to compare this with the way Enzyme works in split reverse mode

willtebbutt commented Apr 5, 2024

Right, yes, so: it's not just the arguments which get modified on the reverse-pass -- any intermediate mutable data structures that get modified during the forwards-pass will also get modified during the reverse-pass. For multiple reverse passes to be safe given only a single forwards-pass, I think it would have to be the case that no mutation occurs on the forwards-pass.

I've not thought at all about how you might go about checking this, so I really do think it's the case that you'll have to do a single forwards pass per reverse-pass.

It would definitely be possible to do multiple reverse-passes at the same time if I modified the package to explicitly handle "chunked" reverse-mode, where you pass multiple cotangents back at the same time. This would involve a complete overhaul of the cotangent system and all of the rules though, so it's certainly not happening in the foreseeable future.
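
As a rough illustration of what "chunked" reverse-mode means, here is a hand-written toy in plain Julia. It is not Tapir code and says nothing about how such a feature would be implemented there; `value_and_chunked_pullback` is a hypothetical name, and the function being differentiated is just an elementwise square.

```julia
# Toy illustration only (not Tapir): the reverse-pass accepts a matrix whose
# columns are cotangents and pulls all of them back in a single call.
function value_and_chunked_pullback(x::Vector{Float64})
    y = x .^ 2
    # Row i of DY is scaled by dy_i/dx_i = 2 * x[i]; columns stay independent.
    chunked_pullback(DY::Matrix{Float64}) = (2 .* x) .* DY
    return y, chunked_pullback
end

x = [1.0, 2.0, 3.0]
y, pb = value_and_chunked_pullback(x)

# One seed per basis vector of the output space, stacked as columns.
DY = [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0]
DX = pb(DY)              # column j is the pullback of seed j
J = permutedims(DX)      # row j of the Jacobian is the pullback of seed j
```

With this shape of API, a full Jacobian needs one forwards-pass and one reverse-pass, at the cost of every rule having to handle batched cotangents.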

gdalle commented Apr 5, 2024

These intermediate data structures are stored in the rule object?

willtebbutt commented Apr 5, 2024

Exactly. For example, if you have a function:

function f(x::Vector{Float64})
    y = map(sin, x)
    z = map(cos, y)
    return z
end

there will, somewhere in the rule derived to differentiate this function, be memory for y once you've run the forwards-pass. This memory will be reverted to its initial state (which is non-deterministic) on the reverse-pass.
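
The following toy model, in plain Julia, sketches why a second reverse-pass can fail after a single forwards-pass. It is an assumption-laden illustration and not Tapir's actual implementation: `ToyRule`, `forwards!` and `reverse!` are hypothetical names, and the rule simply keeps its intermediates on a stack that the reverse-pass consumes.

```julia
# Toy model only (not Tapir's implementation): a rule that owns the memory
# for intermediates and gives it up again on the reverse-pass.
struct ToyRule
    saved::Vector{Vector{Float64}}   # intermediates stored by the forwards-pass
end
ToyRule() = ToyRule(Vector{Float64}[])

# Forwards-pass for z = map(cos, map(sin, x)); saves what the reverse-pass needs.
function forwards!(rule::ToyRule, x::Vector{Float64})
    y = map(sin, x)
    push!(rule.saved, y)             # memory for y now lives inside the rule
    push!(rule.saved, x)
    return map(cos, y)
end

# Reverse-pass: pops the saved values, so the rule returns to its empty state.
function reverse!(rule::ToyRule, dz::Vector{Float64})
    x = pop!(rule.saved)
    y = pop!(rule.saved)
    dy = -sin.(y) .* dz              # pull back through cos
    return cos.(x) .* dy             # pull back through sin
end

rule = ToyRule()
z = forwards!(rule, [1.0])
dx1 = reverse!(rule, [1.0])          # first reverse-pass succeeds

# A second reverse-pass finds nothing left to pop and throws.
second_call_failed = try
    reverse!(rule, [1.0])
    false
catch
    true
end
```

In this toy, the only way to get a second pullback is to run `forwards!` again, which mirrors the "one forwards-pass per reverse-pass" advice above.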

yebai commented Apr 5, 2024

To expand on @willtebbutt's answer, the primary objective of this project is to produce a rewrite of Zygote/ReverseDiff with high performance and rigorous testing. We want to keep the internals transparent and extensible, but generally want to avoid building a one-hammer-for-all-nails tool.

yebai closed this as not planned on Apr 5, 2024

gdalle commented Apr 5, 2024

And hypothetically, if we want to compute many many pullbacks with the same forward sweep, would it make sense to deepcopy the rule?

willtebbutt commented

Possibly, but most likely we would find out there are some intricacies that are annoying for one reason or another.
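
For what the deepcopy idea could look like in principle, here is a small sketch in plain Julia. It uses a hypothetical `ToyPullback` type rather than a real Tapir rule, so it does not show that `deepcopy` is safe or efficient for Tapir's rule objects; as noted above, real rules may well have intricacies this toy does not capture.

```julia
# Toy stand-in (not a Tapir rule): a pullback for y = x .^ 2 that consumes
# the state saved by the forwards-pass when it is called.
struct ToyPullback
    saved::Vector{Vector{Float64}}
end

function (pb::ToyPullback)(dy::Vector{Float64})
    x = pop!(pb.saved)               # single-use: the saved input is consumed
    return 2 .* x .* dy
end

x = [1.0, 2.0]
pb = ToyPullback([copy(x)])          # state as it would be after one forwards-pass

# Each seed gets its own deep copy, so the original state is never consumed.
seeds = ([1.0, 0.0], [0.0, 1.0])
rows = [deepcopy(pb)(seed) for seed in seeds]
```

The cost is one full copy of the saved state per pullback, which may or may not be cheaper than rerunning the forwards-pass.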
