Add support for buffer donation (input/output aliasing) #1733
Related: #1273
Should we add …
Or my dream API: …
I think we can make something similar to that dream happen. One minor modification, perhaps: I am imagining that …
We should think through what the common use case looks like. I have a slight aesthetic bias towards …
I'd argue that even if there are many outstanding references, …
This would be useful for fastar. A note on the API: is it possible that most use cases are in the form of a fixed-point iteration? Gradient descent and the fastar use case are. We could provide a high-level API to donation like

```python
def in_place_fixed_point_iteration(fun, n_iter, x):
    x = x.copy()  # Protect the input array
    for _ in range(n_iter):
        x = fun(jax.move(x))
    return x
```

and even make …
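To make that concrete, here is a hedged sketch of gradient descent written in that fixed-point shape. Since `jax.move` does not exist, the sketch stands it in with the `donate_argnums` parameter that `jit` eventually gained; the quadratic loss and step size are purely illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(w ** 2)  # illustrative quadratic loss

# One gradient-descent step. Donating argument 0 lets XLA reuse w's
# buffer for the output (backends without support just warn).
@partial(jax.jit, donate_argnums=(0,))
def step(w):
    return w - 0.1 * jax.grad(loss)(w)

def in_place_fixed_point_iteration(fun, n_iter, x):
    x = x.copy()  # protect the caller's input array
    for _ in range(n_iter):
        x = fun(x)  # donation happens inside fun via donate_argnums
    return x

w = in_place_fixed_point_iteration(step, 100, jnp.ones(1000))
```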
@hawkinsp I wasn't suggesting anything less explicit. The suggestion (which was tangential to my main point of +1'ing …)
That said, we probably don't want it to be an error in that case, since it might be common to write something like

```python
x = ...
y = f(jax.move(x))
...  # more stuff, x still in scope!
```

To do any meaningful checking against reference counts, the user code would have to look more like

```python
x = jax.move(x)
y = f(x)
...
```

EDIT: one more variant:

```python
x = jax.move(x)    # basically like x = [x]
y = f(x.donate())  # basically like x.pop()
```

But that just seems annoying. So ignore my suggestion! But perhaps not the fact that it wasn't arguing for an implicit API :)
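To make the box analogy in that last variant concrete, here is a pure-Python sketch of its semantics. `Moved` is a hypothetical wrapper for illustration, not a real JAX class:

```python
class Moved:
    """Hypothetical single-use box: holds a value until it is donated."""

    def __init__(self, value):
        self._value = value

    def donate(self):
        if self._value is None:
            raise ValueError("value was already donated")
        value, self._value = self._value, None  # like list.pop()
        return value

x = Moved([1.0, 2.0, 3.0])  # x = jax.move(x): basically x = [x]
y = x.donate()              # f(x.donate()): basically x.pop()
# A second x.donate() would now raise ValueError: the value is gone.
```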
@j-towns The fixed-point iteration pattern is probably the most common one, and in that case we have even more information available (not just that these buffers are being donated, but also that it'd be smart to identify the input buffers with the output buffers in a particular way). But I think the ideal is a more flexible API, like …
I like @hawkinsp's thoughts here: …
Several other colleagues at DM and I are pretty excited about some jax.move variant being implemented, since the 2x memory overhead for model params + optimizer state is quite significant for us. Has there been any further discussion of this feature?
I think at this point someone just needs to do it! I can give it a shot next week (or post here if something else comes up :)).
@skye: Did you have a chance to look into buffer donation? If the jax team is busy with other things right now, we might be able to help with this one.
@ibab: @tomhennigan is already on this.
And yes, I didn't get a chance to look at this after all. I forgot to ping here, sorry!
FYI this was fixed for TPU in #2936, and the XLA team is in the process of supporting this on GPU right now.
My current understanding is that buffer donation works on TPUs and GPUs (since jax 0.1.73) but not on CPUs. Are there plans to support CPUs as well?
This issue still applies, but only on CPU.
This issue is long fixed! |
Currently JAX cannot reuse input buffers to a computation for outputs. This means that for a typical neural network training step, we require enough space to store 2 copies of the weights simultaneously in memory.
XLA supports input/output aliasing, which would allow JAX to tell XLA that it may reuse the input weight buffers for output weights, but we haven't yet enabled it from JAX.
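A minimal illustration of the overhead (the shapes and update rule are arbitrary): because the output needs a fresh buffer, both copies of the weights are live at the peak.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sgd_step(w, g):
    return w - 0.01 * g  # without aliasing, the result needs a fresh buffer

w = jnp.zeros((8192, 8192))  # ~256 MB in float32
g = jnp.ones_like(w)
w = sgd_step(w, g)  # peak memory briefly holds both the old and new w
```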
There are two basic ways we could try to use XLA's support:

a) Opportunistically, i.e., if we detect that the reference count of a buffer is 1 at execution time, we could allow XLA to reuse it. This is somewhat problematic in that it's pretty hard to tell whether the reference count is truly 1 during the `Execute()` call, because the caller may hold references. One way around this might be to distinguish between (a) the Python `Execute()` call, and (b) the time that execution actually takes place, by which time the Python references may have been dropped.

b) Explicitly. Here the user would provide an argument to `jit` or `pmap`, something like: …

This would be a promise by the user that they are done with the buffers passed in certain argument positions, and that either (a) the called computation may reuse them for outputs, or (b) they will be freed.
The explicit option seems the best place to start, since it is also the simplest to implement.
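For reference, this explicit option is essentially what later landed: `jit` and `pmap` grew a `donate_argnums` parameter. A minimal sketch of the promise it expresses (the update rule and shapes are illustrative):

```python
from functools import partial

import jax
import jax.numpy as jnp

# Donating params promises we are done with its buffers, so XLA may
# alias them to the output; touching params afterwards is an error.
@partial(jax.jit, donate_argnums=(0,))
def apply_grads(params, grads):
    return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

params = {"w": jnp.ones((512, 512)), "b": jnp.zeros(512)}
grads = {"w": jnp.ones((512, 512)), "b": jnp.ones(512)}
params = apply_grads(params, grads)  # old weight buffers may be reused
```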