Loop operator #76

Open
dfdx opened this issue Jan 26, 2021 · 1 comment

dfdx commented Jan 26, 2021

The current approach to dynamic graphs is to trace the function on each execution and select the matching tape (with diff operations already added) from the cache. This approach has two disadvantages:

  1. Tracing takes time.
  2. Cached tapes occupy memory.

One way to avoid them is to add support for Loop and If operations. Loop looks like the harder one, so this issue is about it.

Representation

The Loop operator can be represented with something like this:

```julia
mutable struct Loop
    id::Int
    subtape::Tape
    exec_count::Int
end
```

When executing, this operator would not only repeat its code, but also record the number of repetitions. This number can then be passed to the derivative of the operator for the reverse pass.
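
A minimal sketch of that execution behaviour, under stated assumptions: `MiniTape`, `MiniLoop` and `execute!` are hypothetical stand-ins invented for illustration (the real `Tape` type is not reproduced here), and the loop condition is passed in as a plain function.

```julia
# Hypothetical stand-in for a tape: an ordered list of functions applied to a state.
struct MiniTape
    ops::Vector{Function}
end

# Same shape as the proposed Loop operator, but over the stand-in tape.
mutable struct MiniLoop
    id::Int
    subtape::MiniTape
    exec_count::Int
end

# Run the subtape while `cond(state)` holds and count the iterations.
function execute!(loop::MiniLoop, cond::Function, state)
    loop.exec_count = 0
    while cond(state)
        for op in loop.subtape.ops
            state = op(state)
        end
        loop.exec_count += 1
    end
    return state
end

loop = MiniLoop(1, MiniTape(Function[x -> 2x, x -> x + 1]), 0)
result = execute!(loop, x -> x < 20, 1)
```

After the call, `loop.exec_count` holds the number of repetitions, which is the value the reverse pass would need.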

Tracing

Currently we trace code by replacing every call to f(args...) with a call to record_or_recurse!(..., fargs). Every time a primitive function is called, the call is recorded to the tape, which leads to a fully unrolled trace.
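
To illustrate why recording every primitive call unrolls loops, here is a toy tracer. It is only a sketch: `ToyTape`, `record!` and `traced_power` are hypothetical names, and the rewriting of calls is done by hand rather than through IR transformation.

```julia
# Hypothetical toy tape: a flat list of recorded (function, args) pairs.
struct ToyTape
    calls::Vector{Any}
end

# Record a primitive call on the tape, then actually execute it.
function record!(tape::ToyTape, f, args...)
    push!(tape.calls, (f, args))
    return f(args...)
end

# A function with a loop; each multiplication is routed through `record!`.
function traced_power(tape::ToyTape, x, n)
    y = one(x)
    for _ in 1:n
        y = record!(tape, *, y, x)
    end
    return y
end

tape = ToyTape(Any[])
y = traced_power(tape, 2.0, 3)
```

The tape ends up with one entry per iteration, so its length depends on `n`. A different `n` yields a different tape, which is why each variant has to be traced and cached separately.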

If we want to move loops into a separate operator, we need to treat them in some other way. Assuming we can detect start and end of the loop (which doesn't seem hard in IRTools's code representation), I see 2 possible approaches so far:

  1. Trace every execution into the subtape as usual and then simplify this subtape.
  2. Stop tracing after the first execution.

Switching between the outer and inner (sub)tapes can be implemented similarly to switching between execution frames.
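
One possible shape for that frame-like switching is a stack of tapes, where the top of the stack is the tape currently being recorded to. This is a sketch only: `SketchTracer`, `enter_loop!` and `exit_loop!` are hypothetical names, and operations are represented by plain symbols.

```julia
# Hypothetical tracer holding a stack of tapes; stack[end] is the active one.
mutable struct SketchTracer
    stack::Vector{Vector{Any}}
end

SketchTracer() = SketchTracer([Any[]])

current(t::SketchTracer) = t.stack[end]
record!(t::SketchTracer, op) = push!(current(t), op)

# Entering a loop: start recording into a fresh subtape.
enter_loop!(t::SketchTracer) = push!(t.stack, Any[])

# Leaving a loop: pop the subtape and record it as a single op on the outer tape.
function exit_loop!(t::SketchTracer)
    subtape = pop!(t.stack)
    record!(t, (:loop, subtape))
    return subtape
end

t = SketchTracer()
record!(t, :a)
enter_loop!(t)
record!(t, :b)
record!(t, :c)
exit_loop!(t)
record!(t, :d)
outer = current(t)
```

Here the outer tape contains three entries, with the loop body collapsed into one `(:loop, subtape)` element. Nested loops would simply push further subtapes onto the same stack.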

Reverse pass

The reverse pass should be similar to the generated forward pass code, but it updates the derivatives instead.

As an option, we may generate a simple gradient function for each loop right while differentiating the outer tape. This would mix up the differentiation and compilation stages, though: very Julian, but a pretty dangerous strategy.
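
As a hand-written illustration of what such a reverse pass might look like, consider a loop whose body is `y = y * w`. This sketch makes assumptions beyond the issue text: `loop_forward` and `loop_reverse` are hypothetical names, and the forward pass saves the input of every iteration, so the length of `saved` plays the role of `exec_count`.

```julia
# Forward pass: repeat y = y * w while y < limit, saving each iteration's input.
function loop_forward(y, w, limit)
    saved = typeof(y)[]
    while y < limit
        push!(saved, y)
        y = y * w
    end
    return y, saved
end

# Reverse pass: walk the iterations backwards, updating the derivatives.
function loop_reverse(dy, w, saved)
    dw = zero(w)
    for y_in in Iterators.reverse(saved)
        dw += dy * y_in    # derivative of y_in * w with respect to w is y_in
        dy = dy * w        # derivative of y_in * w with respect to y_in is w
    end
    return dy, dw
end

y, saved = loop_forward(1.0, 2.0, 10.0)
dy0, dw = loop_reverse(1.0, 2.0, saved)
```

With these inputs the loop runs four times, so the output is `w^4`, and the reverse pass yields `4w^3` for `dw` and `w^4` for `dy0`, matching the analytic derivatives.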

Loop operator outputs

The Loop operator can return a tuple of all changed variables. This tuple can then be destructured in the usual way.

Note that it shouldn't break CSE or any other optimizations.
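
The tuple-returning convention can be pictured with ordinary Julia code. In this sketch `run_loop` is a hypothetical function standing in for an executed Loop operator whose body changes two variables.

```julia
# Hypothetical loop that modifies both `acc` and `i`.
function run_loop(acc, i, n)
    while i <= n
        acc += i
        i += 1
    end
    return (acc, i)    # every variable changed by the loop
end

# The caller destructures the returned tuple as usual.
acc, i = run_loop(0, 1, 4)
```

On a tape, the destructuring would correspond to ordinary indexing operations on the loop's output, so later operations see the loop as a single node with several results.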

dfdx commented Jul 3, 2021

Loop tracing is now supported via Ghost.jl. Differentiating is on the roadmap.
