The current approach to dynamic graphs is to trace the function on each execution and select the matching tape (with diff operations already added) from the cache. This approach has several disadvantages:
Tracing takes time.
Cached tapes occupy memory.
One way to avoid them is to add support for Loop and If operations. Loop looks like the harder one, so this issue is about it.
Representation
Loop operator can be represented with something like this:
When executing, this operator would not only repeat its code, but also record the number of repetitions. This number can then be passed to the derivative of the operator for the reverse pass.
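The snippet this section refers to is not preserved in this text, so the following is only a minimal sketch of what such an operator might look like. The names `Call`, `Loop`, `execute!`, and the flat vector of `slots` are assumptions made for illustration and are not the package's actual tape API.

```julia
# Hypothetical types for illustration; not the package's actual tape API.
struct Call
    fn::Function
    args::Vector{Int}   # slots read by the call
    out::Int            # slot written by the call
end

mutable struct Loop
    cond::Function      # cond(slots) -> Bool, checked before each iteration
    body::Vector{Call}  # subtape executed on every iteration
    niter::Int          # number of iterations recorded by the last run
end

Loop(cond, body) = Loop(cond, body, 0)

# Repeat the body while the condition holds and record the repetition count.
function execute!(op::Loop, slots::Vector{Float64})
    op.niter = 0
    while op.cond(slots)
        for c in op.body
            slots[c.out] = c.fn((slots[i] for i in c.args)...)
        end
        op.niter += 1
    end
    return slots
end

# Usage: keep multiplying slot 1 by slot 2 until slot 1 reaches 100.
slots = [3.0, 2.0]
loop = Loop(s -> s[1] < 100, [Call(*, [1, 2], 1)])
execute!(loop, slots)
```

After the run, `loop.niter` holds the count that a derivative of the operator could consume on the reverse pass.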
Tracing
Currently we trace code by rewriting every call to f(args...) with a call to record_or_recurse!(..., fargs). Every time a primitive function is called, this call is recorded to the tape, leading to a fully unrolled trace.
If we want to move loops into a separate operator, we need to treat them in some other way. Assuming we can detect start and end of the loop (which doesn't seem hard in IRTools's code representation), I see 2 possible approaches so far:
Trace every execution into a subtape as usual, and then simplify this subtape.
Stop tracing after the first execution.
Switching between the outer and inner (sub)tapes can be implemented similarly to switching between execution frames.
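The frame-like switching between tapes could be sketched as a stack of tapes, where the top of the stack is the one currently being recorded to. `Tracer`, `record!`, `enter_loop!`, and `exit_loop!` are hypothetical names, and operations are plain tuples here only to keep the example short.

```julia
# Hypothetical tracer state: a stack of tapes, the innermost one on top.
struct Tracer
    tapes::Vector{Vector{Any}}
end

Tracer() = Tracer([Any[]])

# Record an operation onto whichever tape is currently active.
record!(t::Tracer, op) = push!(t.tapes[end], op)

# Entering a loop pushes a fresh subtape, like pushing an execution frame.
enter_loop!(t::Tracer) = push!(t.tapes, Any[])

# Leaving a loop pops the subtape and records it on the parent tape as one op.
function exit_loop!(t::Tracer)
    subtape = pop!(t.tapes)
    record!(t, (:loop, subtape))
    return subtape
end

t = Tracer()
record!(t, (:call, :sin))   # goes to the outer tape
enter_loop!(t)
record!(t, (:call, :*))     # goes to the loop's subtape
record!(t, (:call, :+))
exit_loop!(t)
record!(t, (:call, :sum))   # back on the outer tape
outer = t.tapes[1]
```

With this layout the outer tape sees the whole loop as a single entry, regardless of how many primitive calls the body contains.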
Reverse pass
The reverse pass should be similar to the generated forward pass code, but updating the derivatives instead of the values.
As an option, we may generate a simple gradient function for each loop right while differentiating the outer tape. This would mix up the differentiation and compilation stages, though: a very Julian, but pretty dangerous, strategy.
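To make the role of the recorded iteration count concrete, here is a hand-written sketch for one specific loop, `while x < limit; x = x * w; end`. The functions `loop_forward` and `loop_reverse` are hypothetical, and storing the per-iteration inputs in `xs` is an assumption about how intermediate values would be kept, not something this issue specifies.

```julia
# Forward pass for the hypothetical loop `while x < limit; x = x * w; end`.
# Returns the result, the iteration count, and the input of each iteration.
function loop_forward(x, w, limit)
    xs = typeof(x)[]        # value of x at the start of every iteration
    n = 0
    while x < limit
        push!(xs, x)
        x = x * w
        n += 1
    end
    return x, n, xs
end

# Reverse pass: replay the body's derivative n times in reverse order.
function loop_reverse(dy, w, n, xs)
    dx, dw = dy, zero(dy)
    for i in n:-1:1
        dw += dx * xs[i]    # d(x*w)/dw = x at iteration i
        dx = dx * w         # d(x*w)/dx = w
    end
    return dx, dw
end

y, n, xs = loop_forward(3.0, 2.0, 20.0)
dx, dw = loop_reverse(1.0, 2.0, n, xs)
```

Here the loop runs three times, so `y == x * w^3`, and the reverse pass reproduces the analytic derivatives `w^3` and `3 * x * w^2`.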
Loop operator outputs
The Loop operator can return a tuple of all changed variables. This tuple can then be destructured in the usual way.
Note that it shouldn't break CSE or any other optimizations.
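As a rough illustration of the output convention, the loop below is written as an ordinary function that returns every variable it modified. `run_loop` is a hypothetical stand-in for the Loop operator, not an existing function.

```julia
# Hypothetical stand-in for a Loop operator that changes `acc` and `i`.
function run_loop(acc, i, n)
    while i <= n
        acc += i
        i += 1
    end
    return (acc, i)     # tuple of all variables changed inside the loop
end

# The caller destructures the tuple like any other multi-value return.
acc, i = run_loop(0, 1, 4)
```

Since the result is an ordinary tuple, later operations on the outer tape can refer to `acc` and `i` as separate values.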