Decreasing autograd memory usage #219
Are you sure there's a leak here? Neither version seems to end at exactly zero net memory usage in the plots (and indeed the manual ones go negative at first), and we would expect autograd to require more memory because it has to trace the forward computation (not just store the values, but box each of them, etc.). This doesn't seem out of line with my expectations, and Python memory deallocation and measurement can be pretty complicated [1] [2].
Sorry, I don't understand. Doesn't autograd end with less net memory usage than PyTorch in that plot? Where is the memory leak shown?
Here's a gist with a more "real-world" benchmark, taking the gradients of an RNN: https://gist.github.com/alexbw/7b2a0682f65dd1bcb7120ca2d47a2823 Here's the memory usage: [plot omitted]
For the RNN benchmark, here's the memory usage when separating the forward and backward passes of autograd, versus just calling grad(f)(*args): [plot omitted] (Apologies, the benchmark code is getting a little hairy.)
It's probably from the outgrads being accumulated in a list and summed all at once rather than incrementally summed.
I just pushed a commit that does in-place gradient accumulation; maybe check how the memory looks with that (though I might not have done it right).
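To make the trade-off concrete, here is a minimal, hypothetical sketch (not autograd's actual code) of the difference between collecting a node's incoming gradients in a list and summing them at the end, versus accumulating them in place as they arrive; the function names are illustrative only.

```python
# Hypothetical sketch, not autograd's implementation: two ways to combine
# the gradient contributions ("outgrads") flowing into a single node.
import numpy as np

def accumulate_list(outgrads):
    # Every contribution stays alive in the list until the final sum,
    # so peak memory grows with the node's fan-out.
    return sum(outgrads)

def accumulate_inplace(outgrads):
    # Each contribution is added into one buffer as soon as it arrives,
    # so peak memory stays roughly constant regardless of fan-out.
    total = None
    for g in outgrads:
        if total is None:
            total = np.array(g, copy=True)  # own the buffer before mutating it
        else:
            total += g                      # in-place add, no new allocation
    return total
```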
Btw, thanks for these benchmarks! The latter ones are quite stark. (At first I was just focusing on the 'leak' idea, but not enough on the total memory usage issue.)
Is there a potential factor of 2 because of float32 vs float64, or are all these propagating the same dtypes throughout? I think in-place accumulation makes a lot of sense for computation graphs with high fan-out of some nodes (like an RNN that reuses the same parameters many times; great test!). Fortunately @dougalm's vector space type system made the initial implementation a breeze!
FYI (just updating the thread with a conversation Matt and I had): PyTorch does NOT do anything special with gradients of indexing operations. They materialize the full-sized, mostly-zero gradient and add it directly into their accumulated gradient. It seems this is the common case for their users (and perhaps autograd's users!), and it hasn't given rise to any serious performance complaints. They're aware it's suboptimal, but it doesn't seem to be anywhere near the top of their todo list.
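For illustration, here is a hedged numpy sketch of that dense strategy; the function name, shapes, and example are made up, and this is neither PyTorch's nor autograd's actual code.

```python
# Hypothetical sketch of the dense indexing-gradient strategy described above:
# the gradient of x[idx] is materialized as a full-sized, mostly-zero array.
import numpy as np

def index_vjp_dense(g, x_shape, idx):
    # g is the gradient w.r.t. x[idx]; scatter it into a full-sized zero
    # array, even though most entries remain zero.
    full = np.zeros(x_shape, dtype=g.dtype)
    np.add.at(full, idx, g)   # unbuffered add handles repeated indices
    return full

# Example: gradient of y = x[[0, 2]] flowing back to a length-5 x.
print(index_vjp_dense(np.array([1.0, 3.0]), (5,), [0, 2]))
# -> [ 1.  0.  3.  0.  0.]
```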
With Matt's change vs. current master (3d018e8): [memory plots omitted] I added the benchmark file I'm using as a branch from master, benchmark-rnn.
I am fairly confident I've figured out the memory issue. I compared the results of backward_pass from master against backward_pass with this addition, and the results were identical in my use case.
I don't see any change in peak memory usage in my RNN benchmark. To replicate, you can do two things: [steps omitted]
Could you post a version of your benchmark you're comfortable sharing?
These are two separate issues. @Etragas's change in #221 ensures that the values in an outgrads list associated with a node are garbage collected after they're finished being used. The issue that @alexbw is raising is that it's expensive to accumulate the outgrads in a list to begin with (in common use cases with high fan-out). That is, @Etragas saves memory across multiple nodes' outgrads lists, whereas @alexbw is concerned about memory usage for a single node's outgrad list. I believe @dougalm is implementing a different accumulation strategy for the outgrads that will primarily address the issue @alexbw raised, though the change @Etragas suggested (explicitly deleting the entry from the outgrads dict) should also be included, since we don't need to keep those around.
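To make that distinction concrete, here is a hedged sketch of a backward pass that applies both ideas. It is not autograd's actual backward_pass, and the graph representation (a reverse-topological node list plus a node-to-(parent, VJP) map) is invented for the example.

```python
# Hypothetical sketch, not autograd's backward_pass: pop each node's
# outgrad once it has been used (freeing that entry), and accumulate
# parent contributions in place instead of keeping them in a per-node list.
import numpy as np

def backward_pass_sketch(order, parent_vjps, g):
    """order: nodes in reverse topological order, order[0] is the output.
    parent_vjps: dict mapping node -> list of (parent, vjp_function)."""
    outgrads = {order[0]: np.asarray(g, dtype=float)}
    for node in order:
        outgrad = outgrads.pop(node)              # free this node's entry
        for parent, vjp in parent_vjps.get(node, []):
            contribution = vjp(outgrad)
            if parent in outgrads:
                outgrads[parent] += contribution  # in-place accumulation
            else:                                 # (real code must also guard
                outgrads[parent] = contribution   #  against aliased buffers)
    return outgrad                                # gradient w.r.t. the input

# Tiny example: y = 3*x + 2*x, so dy/dx = 5.
order = ["y", "a", "b", "x"]
parent_vjps = {
    "y": [("a", lambda g: g), ("b", lambda g: g)],  # y = a + b
    "a": [("x", lambda g: 3.0 * g)],                # a = 3*x
    "b": [("x", lambda g: 2.0 * g)],                # b = 2*x
}
print(backward_pass_sketch(order, parent_vjps, 1.0))  # 5.0
```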
Old vs. new (bb8c0bc): [memory plots omitted]
As far as I can tell, the memory consumption during the backward pass is now essentially flat. Thoughts?
I think the remaining difference in memory in these plots is from float64 vs float32. @alexbw showed me some other plots in which he forced PyTorch to use float64, and in those plots autograd and PyTorch were essentially identical in terms of total memory usage. To make autograd maintain float32s in computations (given float32 inputs), we need to go through all our VJPs systematically and avoid any up-casting. However, I think this fan-out memory usage issue is now solved with bb8c0bc, so I'm going to close this issue.
On the issue of float32 (and float16) support: I believe autograd currently handles float32 correctly for array inputs, but not for scalar inputs, where the gradient gets up-cast to float64 (see the session below). This is a reflection of numpy's type system, which respects the dtype of arrays and seems to largely ignore the dtype of scalar values (even if they are wrapped as an ndarray with shape ()). I think this needs to be thoroughly tested, but assuming the above is correct, we could say in the docs that float32 is supported, but only for arrays.

In [1]: from autograd import grad

In [2]: import autograd.numpy as np

In [3]: def f(x):
   ...:     return x**2
   ...:

In [4]: grad(f)(np.array([3], dtype=np.float32))
Out[4]: array([ 6.], dtype=float32)

In [5]: grad(f)(np.array(3, dtype=np.float32))
AssertionError:
Grad of power returned unexpected vector space
Vector space is ArrayVSpace_{'shape': (), 'size': 1, 'dtype': dtype('float64'), 'scalartype': <class 'float'>}
Expected ArrayVSpace_{'shape': (), 'size': 1, 'dtype': dtype('float32'), 'scalartype': <class 'float'>}
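For reference, here is a small demo (not from the thread) of the numpy promotion behavior described above. It assumes numpy's legacy value-based casting rules (numpy < 2.0, pre-NEP 50), under which shape-() arrays are treated like scalars for promotion.

```python
# Demo of the dtype behavior discussed above, assuming legacy numpy
# promotion rules (numpy < 2.0 / pre-NEP 50).
import numpy as np

a = np.array([3.0], dtype=np.float32)   # shape-(1,) array
s = np.array(3.0, dtype=np.float32)     # shape-() "scalar" array

print((a * 2.0).dtype)   # float32 -- the array's dtype is respected
print((s * 2.0).dtype)   # float64 -- the scalar's dtype is ignored
```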
I don't mean "memory leak" in terms of unreachable memory after the Python process quits, I mean memory that is being allocated in the backwards pass, when it should be being freed. I mentioned this problem in #199 , but I think it should be opened as an issue.
For a simple function, a procedure to measure memory usage, and a manual gradient of the same function, I get the following memory usage profile: [plot omitted]
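The original code isn't reproduced above, so here is a hypothetical reconstruction of that kind of setup; the function f, the array sizes, and the use of psutil for measuring resident memory are all assumptions, not the issue's actual benchmark.

```python
# Hypothetical reconstruction of the kind of benchmark described above.
import os
import numpy as np
import psutil                      # third-party: pip install psutil
import autograd.numpy as anp
from autograd import grad

def f(W, x):
    # A simple function built around a dot product.
    return anp.sum(anp.tanh(anp.dot(W, x)))

def manual_grad_f(W, x):
    # Hand-written gradient of f with respect to W.
    y = np.dot(W, x)
    return np.outer(1.0 - np.tanh(y) ** 2, x)

def rss_mb():
    # Resident set size of the current process, in MB.
    return psutil.Process(os.getpid()).memory_info().rss / 1e6

W, x = np.random.randn(2000, 2000), np.random.randn(2000)

before = rss_mb()
g_auto = grad(f)(W, x)             # autograd's gradient w.r.t. W
print("autograd gradient:", rss_mb() - before, "MB")

before = rss_mb()
g_manual = manual_grad_f(W, x)     # manual gradient for comparison
print("manual gradient:  ", rss_mb() - before, "MB")
```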
If I replace the dot gradient with the ones used in the manual code, I get the same memory profile; nothing improves.
If I replace the dot product with an element-wise multiply, I get a different memory profile, but still not what I would expect: [plot omitted]
I would love to help figure this out, but I'm not sure where to start. First thing is of course to document the problem.