Add support for buffer donation (input/output aliasing) #1733

Closed
hawkinsp opened this issue Nov 21, 2019 · 21 comments
Labels
CPU Issues related to the CPU compiler/runtime enhancement New feature or request P2 (eventual) This ought to be addressed, but has no schedule at the moment. (Assignee optional) XLA

Comments

@hawkinsp
Collaborator

Currently JAX cannot reuse a computation's input buffers for its outputs. This means that a typical neural network training step needs enough memory to hold two copies of the weights simultaneously.

XLA supports input/output aliasing, which would allow JAX to tell XLA that it may reuse the input weight buffers for output weights, but we haven't yet enabled it from JAX.

There are two basic ways we could try to use XLA's support:
a) opportunistically, i.e., if we detect that the reference count of a buffer is 1 at execution time, we could allow XLA to reuse it. This is somewhat problematic in that it's pretty hard to tell whether the reference count is truly 1 during the Execute() call, because the caller may hold references.

One way around this might be to distinguish between the Python Execute() call and the time that execution actually takes place, by which point the Python references may have been dropped.
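The refcount difficulty in option a) shows up even in plain CPython. The pure-Python sketch below (`Buffer` and `refcount_seen_by_execute` are hypothetical names, not JAX APIs) shows that the count a callee observes includes references it does not control. Its absolute value depends on the interpreter, and a single extra alias held by the caller shifts it by one, so no fixed threshold tells the callee that the caller is done with the buffer.

```python
import sys


class Buffer:
    """Hypothetical stand-in for a device buffer (not a real JAX type)."""


def refcount_seen_by_execute(buf):
    # What an opportunistic Execute() could observe. The value includes
    # references held by the caller and by this call itself, so it is never
    # simply "1", and its exact value varies between CPython versions.
    return sys.getrefcount(buf)


buf = Buffer()
count_without_alias = refcount_seen_by_execute(buf)

alias = buf  # the caller keeps a second reference it may still want to use
count_with_alias = refcount_seen_by_execute(buf)
# The callee cannot tell whether the extra reference is dead or still in use.
```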

b) explicitly. Here the user would provide an argument to jit or pmap, something like:

@partial(jit, donated_argnums=(1,2,7))
def f(...):
  ...

This would be a promise by the user that they are done with the buffers passed in certain argument positions, and that the called computation may either reuse them for outputs or free them.

The explicit option seems the best place to start, since it is also the simplest to implement.
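For comparison, the explicit option roughly corresponds to what JAX later exposed as the `donate_argnums` argument of `jax.jit` (spelled differently from the `donated_argnums` proposed above). A minimal sketch, assuming a recent JAX release; `sgd_step` is a hypothetical training-step function used only for illustration.

```python
import jax
import jax.numpy as jnp


def sgd_step(params, grads, lr):
    # Elementwise update: the output has the same shape and dtype as
    # `params`, so XLA can write it into the donated input buffer.
    return params - lr * grads


# donate_argnums=0: the caller promises not to use `params` after the call.
sgd_step_donating = jax.jit(sgd_step, donate_argnums=0)

params = jnp.ones((4,), dtype=jnp.float32)
grads = jnp.full((4,), 0.5, dtype=jnp.float32)
new_params = sgd_step_donating(params, grads, 0.1)

# True if the backend actually consumed the donated buffer; where donation
# is unsupported, JAX warns and leaves `params` alive instead.
params_was_donated = params.is_deleted()
```

Donation only pays off when some output matches the shape and dtype of the donated input; when JAX cannot use a donated buffer it emits a warning rather than failing.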

@romanngg
Contributor

Related: #1273

@hawkinsp hawkinsp added the enhancement New feature or request label Nov 21, 2019
@afrozenator

/sub

@skye
Member

skye commented Nov 21, 2019

Should we add donated_argnums to the call site instead of to jit? That would make it clearer when a DeviceArray is no longer safe to use.

@skye
Member

skye commented Nov 22, 2019

Or my dream API: jit(f)(jax.move(x))
This maybe isn't a great option, though, since it would only make sense at jit/pmap boundaries, yet nothing would stop you from calling jax.move anywhere.

@hawkinsp
Collaborator Author

I think we can make something similar to that dream happen.

One minor modification, perhaps:
jit(f)(x.mark_for_donation())
?

I am imagining that DeviceValue.mark_for_donation() sets a flag on the buffer that says the next computation that receives it as input consumes it. I am not 100% happy with that as an API but it would do the job.
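A toy model of the `mark_for_donation()` idea, assuming the flag lives on a Python-side wrapper. `DeviceValue` here is a hypothetical pure-Python stand-in, not JAX's class of that name, and the hypothetical `run_computation` plays the role of the jit dispatch.

```python
class DeviceValue:
    """Hypothetical toy stand-in for a device buffer holder (not JAX's class)."""

    def __init__(self, data):
        self.data = data
        self.donate_next = False  # set by mark_for_donation()
        self.deleted = False

    def mark_for_donation(self):
        # Flag the buffer: the next computation that receives it consumes it.
        self.donate_next = True
        return self  # returning self allows jit(f)(x.mark_for_donation())


def run_computation(fn, *args):
    """Hypothetical dispatch that honours the donation flag."""
    for a in args:
        if a.deleted:
            raise RuntimeError("buffer was already donated")
    result = fn(*[a.data for a in args])
    for a in args:
        if a.donate_next:
            # Model the output taking over the donated input's storage.
            a.data = None
            a.deleted = True
            a.donate_next = False
    return DeviceValue(result)


x = DeviceValue([1, 2, 3])
y = run_computation(lambda d: [v * 2 for v in d], x.mark_for_donation())
# x is now deleted; passing it to another computation raises RuntimeError.
```

One drawback this sketch makes visible: the flag is hidden state on `x`, so marking a value and then not passing it to a computation leaves a pending donation for whatever call happens to receive it next.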

@mattjj
Collaborator

mattjj commented Nov 22, 2019

We should think through what the common use case looks like. I have a slight aesthetic bias towards jax.move(x), which could just return a wrapper object with logic similar to what Peter described. I suspect we'll want to be able to tree-map this function easily. It's also worth thinking about whether there are any Python refcount games worth playing (e.g., whether jax.move(x) could verify that there are no other references to x), or whether they're all too tricky.

@hawkinsp
Collaborator Author

I'd argue that jax.move (or whatever we call it) should not be in the business of looking at reference counts. The whole point of the API is that it is completely explicit. This does not preclude the existence of a more magical and implicit way to reclaim buffers, but the explicit API should be predictable and not try to play tricks with reference counts.

Even if there are many outstanding references, jax.move should still take ownership of the buffer. Another way to think about it is move deletes the original DeviceArray object (just as if you called .delete()) and returns you a new object with the same backing buffer, tagged in some way so that the jit dispatch logic knows the computation may consume it. One way to do this would be to return a new subclass of DeviceValue (say, DonatedDeviceArray) that is not a DeviceArray (so it can't be confused with one), but is known to the dispatch logic.
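The `jax.move` semantics described above (delete the original, return a distinct donated type that only the dispatch logic understands) can be modelled in a few lines of plain Python. All classes and functions here are hypothetical stand-ins, not JAX's real `DeviceValue`/`DeviceArray` implementation.

```python
class DeviceValue:
    """Toy base class for values backed by a (hypothetical) device buffer."""

    def __init__(self, buffer):
        self._buffer = buffer


class DeviceArray(DeviceValue):
    def delete(self):
        self._buffer = None

    def is_deleted(self):
        return self._buffer is None


class DonatedDeviceArray(DeviceValue):
    """Deliberately not a DeviceArray, so it cannot be confused with one."""


def move(x):
    """Hypothetical jax.move: take ownership even if other references exist."""
    if x.is_deleted():
        raise RuntimeError("cannot move a deleted array")
    donated = DonatedDeviceArray(x._buffer)
    x.delete()  # as if the user had called x.delete()
    return donated


def dispatch(fn, *args):
    """Toy jit dispatch that knows DonatedDeviceArray inputs may be consumed."""
    for a in args:
        if a._buffer is None:
            raise RuntimeError("buffer was deleted or already donated")
    out = DeviceArray(fn(*[a._buffer for a in args]))
    for a in args:
        if isinstance(a, DonatedDeviceArray):
            a._buffer = None  # consumed: a donated value is single-use
    return out


x = DeviceArray([1.0, 2.0])
alias = x  # another reference the caller still holds
y = dispatch(lambda b: [v + 1.0 for v in b], move(x))
```

Because `DonatedDeviceArray` does not subclass `DeviceArray`, code that expects an ordinary array cannot accidentally accept a donated one, and any lingering alias to the original object fails loudly instead of reading a reused buffer.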

@j-towns
Contributor

j-towns commented Nov 22, 2019

This would be useful for fastar. A note on the API: is it possible that most use cases are in the form of a fixed-point iteration? Gradient descent and the fastar use case are. We could provide a high-level API to donation like

def in_place_fixed_point_iteration(fun, n_iter, x):
  x = x.copy()  # Protect the input array
  for _ in range(n_iter):
    x = fun(jax.move(x))
  return x

and even make move private if we think that in_place_fixed_point_iteration covers all likely use cases.
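Since `jax.move` is only a proposal in this thread, the sketch below expresses the same fixed-point loop with `jax.jit`'s `donate_argnums` argument instead. It is a minimal sketch assuming a recent JAX release; the helper name is taken from the comment above and is not a JAX API.

```python
import jax
import jax.numpy as jnp


def in_place_fixed_point_iteration(fun, n_iter, x):
    # Donate argument 0 on every step: each iteration's input buffer may be
    # reused for that iteration's output.
    step = jax.jit(fun, donate_argnums=0)
    x = jnp.copy(x)  # protect the caller's array: only the copy is donated
    for _ in range(n_iter):
        x = step(x)
    return x


x0 = jnp.zeros((3,), dtype=jnp.float32)
result = in_place_fixed_point_iteration(lambda v: 0.5 * v + 1.0, 30, x0)
# result approaches the fixed point 2.0; x0 itself is left intact.
```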

@mattjj
Collaborator

mattjj commented Nov 22, 2019

@hawkinsp I wasn't suggesting anything less explicit. The suggestion (which was tangential to my main point of +1'ing jax.move) was an explicit API with error checking. That is, it's not that jax.move only causes buffer donation when it can; rather, it's that it always causes buffer donation (i.e. the explicit version we all want) but may also be able to catch errors (if we decide that having an alias to x when calling jax.move(x) is actually an error).

@mattjj
Collaborator

mattjj commented Nov 22, 2019

That said, we probably don't want it to be an error in that case, since it might be common to write something like

x = ...
y = f(jax.move(x))
... # more stuff, x still in scope!

To do any meaningful checking against reference counts, the user code would have to look more like

x = jax.move(x)
y = f(x)
...

EDIT: one more variant:

x = jax.move(x)  # basically like x = [x]
y = f(x.donate())  # basically like x.pop()

But that just seems annoying.
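The list analogy in this last variant can be made concrete with a small pure-Python handle (the class name `MovedValue` is hypothetical, not a proposed JAX name). After `donate()` the handle drops its reference, so the computation receiving the value is its only owner unless user code kept another alias. It also shows where the annoyance comes from: every call site needs the extra `.donate()`.

```python
class MovedValue:
    """Hypothetical single-use handle modelling `x = jax.move(x)` above."""

    def __init__(self, value):
        self._slot = [value]  # "basically like x = [x]"

    def donate(self):
        # "basically like x.pop()": hand the value out exactly once, after
        # which the handle no longer references it.
        if not self._slot:
            raise RuntimeError("value was already donated")
        return self._slot.pop()


handle = MovedValue("weights-buffer")
value = handle.donate()
```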

So ignore my suggestion! But perhaps not the fact that it wasn't arguing for an implicit API :)

@mattjj
Collaborator

mattjj commented Nov 22, 2019

@j-towns The fixed-point iteration pattern is probably the most common one, and in that case we have even more information available (not just that these buffers are being donated, but also that it'd be smart to identify the input buffers with the output buffers in a particular way). But in_place_fixed_point_iteration seems too restrictive as an API (basically back to "framework controlled training loops").

I think the ideal is where we have a more flexible API, like jax.move or x.mark_for_donation, that lets user code implement in_place_fixed_point_iteration with the optimal buffer-efficiency.

@mattjj
Collaborator

mattjj commented Nov 22, 2019

I like the @hawkinsp thoughts here:

Even if there are many outstanding references, jax.move should still take ownership of the buffer. Another way to think about it is move deletes the original DeviceArray object (just as if you called .delete()) and returns you a new object with the same backing buffer, tagged in some way so that the jit dispatch logic knows the computation may consume it. One way to do this would be to return a new subclass of DeviceValue (say, DonatedDeviceArray) that is not a DeviceArray (so it can't be confused with one), but is known to the dispatch logic.

@dm-jrae

dm-jrae commented Dec 30, 2019

Several colleagues at DM and I are pretty excited about some jax.move variant being implemented, as the 2x memory overhead for model params plus optimizer stats is quite significant for us. Has there been any further discussion on this feature?

@skye
Member

skye commented Jan 2, 2020

I think at this point someone just needs to do it! I can give it a shot next week (or post here if something else comes up :)).

@ibab
Contributor

ibab commented Feb 7, 2020

@skye: Did you have a chance to look into buffer donation? If the jax team is busy with other things right now, we might be able to help with this one.

@girving

girving commented Feb 7, 2020

@ibab: @tomhennigan is already on this.

@skye
Member

skye commented Feb 7, 2020

And yes, I didn't get a chance to look at this after all. I forgot to ping here, sorry!

@tomhennigan
Collaborator

FYI this was fixed for TPU in #2936, and the XLA team is in the process of adding support on GPU right now.

@wiep

wiep commented Dec 4, 2020

My current understanding is that buffer donation works on TPUs and GPUs (since jax 0.1.73) but not on CPUs. Are there plans to support CPUs as well?

@sudhakarsingh27 sudhakarsingh27 added NVIDIA GPU Issues specific to NVIDIA GPUs P2 (eventual) This ought to be addressed, but has no schedule at the moment. (Assignee optional) labels Aug 10, 2022
@hawkinsp hawkinsp added XLA CPU Issues related to the CPU compiler/runtime and removed NVIDIA GPU Issues specific to NVIDIA GPUs labels Aug 12, 2022
@hawkinsp
Collaborator Author

This issue still applies, but only on CPU.

@hawkinsp
Collaborator Author

This issue is long fixed!
