Autodiff: checkpointing strategy #936
Comments
Hi, I'm wondering how the toggle for this strategy should be added into Burn's AD graph? The AD tool …
Hi @AuruTus
If you are discussing checkpointing strategies, it may be worth considering JAX's approach to AD, explained in You Only Linearize Once, since that can shed some light on what is going on with checkpointing. The idea is to break the vector-Jacobian product into two pieces - I'm going to use Haskell type signatures where -o is linear implication and ! means a variable may be reused (e.g. it is a smooth argument).
Now, I said that you want to treat …
Solved in #1358.
This is referenced in the project README. Should it be updated?
In autodiff, we should have a checkpointing strategy for lower memory consumption (see, for instance, https://www-sop.inria.fr/tropics/papers/DauvergneHascoet06.pdf).
Currently, for most operations run in the forward pass, a state is saved for the backward pass. This state often consists of a few tensors, so across a whole graph these saved states accumulate and consume a lot of memory.
A way to use less memory during the backward pass would be to recompute the forward pass of an operation to re-obtain its state just before computing its backward step, instead of keeping that state in memory. This leads to more computation, but less memory consumption.
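As a rough illustration of the two options, here is what saving versus recomputing the state could look like for the backward pass of an element-wise multiplication, whose gradient with respect to each operand needs the other operand. This is a sketch only; the types and names below are invented for illustration and are not Burn's actual autodiff API.

```rust
// Sketch only: "save the state" vs "recompute the state" for the backward
// pass of z = a * b (element-wise). All names here are hypothetical.
struct Tensor(Vec<f32>);

fn mul(a: &Tensor, b: &Tensor) -> Tensor {
    Tensor(a.0.iter().zip(&b.0).map(|(x, y)| x * y).collect())
}

/// Option 1: keep both operands alive until the backward pass (more memory).
struct SavedState {
    a: Tensor,
    b: Tensor,
}

impl SavedState {
    fn backward(&self, grad_out: &Tensor) -> (Tensor, Tensor) {
        // d/da (a * b) = b and d/db (a * b) = a, scaled by the incoming gradient.
        (mul(grad_out, &self.b), mul(grad_out, &self.a))
    }
}

/// Option 2: keep only a way to replay the forward computation (closures here
/// stand in for re-running the relevant part of the graph), and rebuild the
/// operands right before the backward step (more compute, less memory).
struct RecomputedState<F: Fn() -> Tensor, G: Fn() -> Tensor> {
    recompute_a: F,
    recompute_b: G,
}

impl<F: Fn() -> Tensor, G: Fn() -> Tensor> RecomputedState<F, G> {
    fn backward(&self, grad_out: &Tensor) -> (Tensor, Tensor) {
        let (a, b) = ((self.recompute_a)(), (self.recompute_b)());
        (mul(grad_out, &b), mul(grad_out, &a))
    }
}
```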
This leads to a tradeoff between compute and memory. Some operations, like matrix multiplication, are "compute-bound", meaning the bottleneck is generally the actual computations, while some, such as element-wise multiplication, are "memory-bound", meaning the computation is so simple that moving the data is the bottleneck.
For compute-bound operations, it is better to keep the state than to recompute. But for memory-bound operations, we would benefit from recomputing.
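One possible way to encode this decision (purely a sketch of the idea above, not an existing Burn type) is to tag each operation with its bottleneck and checkpoint accordingly:

```rust
/// Sketch of the tagging described above; hypothetical, not an existing Burn type.
#[derive(Clone, Copy, PartialEq, Eq)]
enum Boundedness {
    /// Bottleneck is arithmetic (e.g. matrix multiplication): keep the state.
    ComputeBound,
    /// Bottleneck is data movement (e.g. element-wise multiplication): recompute.
    MemoryBound,
}

/// Whether the forward state of an operation should be saved for the backward pass.
fn should_save_state(op: Boundedness) -> bool {
    matches!(op, Boundedness::ComputeBound)
}
```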
Also, if many operations are tagged as memory-bound, this will greatly help Burn-Fusion, which will be able to fuse kernels transparently during the backward pass.
The current strategy, where every state is saved, would simply become a special case of the new strategy, where everything is considered compute-bound.
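Continuing the sketch above, the current behaviour would be the strategy that never recomputes, i.e. one that tags every operation as compute-bound (again, names are hypothetical):

```rust
/// Sketch: today's "always save" behaviour, expressed as the strategy that
/// classifies every operation as compute-bound.
fn save_everything_strategy(_op: &str) -> Boundedness {
    Boundedness::ComputeBound
}
```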