
Primitive to sequentially execute a function inside a vmapped function #7199

Open
pl-fuchs opened this issue Jul 6, 2021 · 7 comments
Labels
enhancement New feature or request P2 (eventual) This ought to be addressed, but has no schedule at the moment. (Assignee optional)

Comments

@pl-fuchs

pl-fuchs commented Jul 6, 2021

I am writing some longer algorithms which I would like to vmap. I stumbled on some problems when combining vmap with the host_callback module and the lax.cond function:

  1. vmap of cond does not work when the branches are not side-effect-free
  2. vmap of cond can be very inefficient if one branch is much longer but only taken rarely
  3. There is no trivial batching rule for the host_callback.call function

I think a simple solution would be to implement a stop_vmap or sequential_vmap decorator. This decorator would define a batching rule, such that

@vmap
@stop_vmap
def some_fun(*args):
  # Some operations ...
  return results

would be the same as writing

def some_fun(*batched_args):
  def body_fun(*args):
    # Some operations...
    return results

  return lax.map(lambda args: body_fun(*args), batched_args)

The advantage of the decorator would be that some_fun could be used inside a much bigger vmapped function.
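To make the proposal concrete, here is a minimal runnable sketch of the requested behavior, with `sequential_map` as a hypothetical stand-in for the proposed `stop_vmap`/`sequential_vmap` decorator (this is just the explicit `lax.map` rewrite from above, not the primitive-based implementation the issue asks for):

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-in for the proposed decorator: run the wrapped
# function one batch element at a time via lax.map, rather than
# tracing it with a batch dimension.
def sequential_map(fun):
    def batched(*batched_args):
        return lax.map(lambda args: fun(*args), batched_args)
    return batched

# A toy body; in the issue's use case this would contain
# host_callback.call or a lax.cond with side-effecting branches.
def some_fun(x, y):
    return jnp.sin(x) + y

xs = jnp.arange(4.0)
ys = jnp.ones(4)

# The sequential version matches plain vmap on pure functions:
seq = sequential_map(some_fun)(xs, ys)
vec = jax.vmap(some_fun)(xs, ys)
assert jnp.allclose(seq, vec)
```

The point of the decorator, as opposed to this explicit rewrite, is that the `lax.map` fallback would kick in automatically when `some_fun` is used inside a much larger vmapped function.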

@pl-fuchs pl-fuchs added the enhancement New feature or request label Jul 6, 2021
@froystig
Member

The "sequential_vmap" concept seems like it would be easy to implement given a more general custom batching interface, which is something that @mattjj and I have considered before.

Short of that, it might be worth considering offering the special case, by writing a single higher-order primitive whose batching rule is essentially a call to lax.map.
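[Editor's note: the more general custom batching interface mentioned above later landed in JAX as `jax.custom_batching`. Assuming that API, a sequential-under-vmap function can be sketched like this; the rule loops over the batch with `lax.map` instead of vectorizing:]

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.custom_batching import custom_vmap

# A function whose behavior under vmap we override. jnp.sin stands in
# for a body with side effects or rarely-taken branches.
@custom_vmap
def f(x):
    return jnp.sin(x)

@f.def_vmap
def f_vmap_rule(axis_size, in_batched, xs):
    # Custom batching rule: iterate over the batch axis sequentially
    # with lax.map rather than tracing a vectorized version.
    out = lax.map(lambda x: jnp.sin(x), xs)
    return out, True  # output, and whether it carries the batch axis

xs = jnp.linspace(0.0, 1.0, 5)
assert jnp.allclose(jax.vmap(f)(xs), jnp.sin(xs))
```

`jax.custom_batching` also ships a ready-made `sequential_vmap` decorator that applies exactly this `lax.map`-based rule, which is essentially what this issue requested.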

@froystig froystig added the P2 (eventual) This ought to be addressed, but has no schedule at the moment. (Assignee optional) label Jul 29, 2021
@pharringtonp19

@froystig Would breaking a vmap into a sequential_vmap reduce memory usage?

@froystig
Member

The original request is about modifying the behavior of a function under vmap. If you're asking about replacing the use of vmap entirely with a sequential map, then yes: we have lax.map, and using it typically requires less memory. Did I read your question correctly?

@pharringtonp19

@froystig Apologies, I see that this was not the right place to ask that question. You did answer my question, though, so thanks!

@froystig
Member

No problem. Glad that helped!

@shoyer
Collaborator

shoyer commented Aug 22, 2021

> The "sequential_vmap" concept seems like it would be easy to implement given a more general custom batching interface, which is something that @mattjj and I have considered before.

More general "custom batching" could be interesting for use cases like host_callback.call, where an external library may support its own parallelism strategies. E.g., I think this could be useful for @ianwilliamson's MEEP wrapper: NanoComp/meep#1569

> Short of that, it might be worth considering offering the special case, by writing a single higher-order primitive whose batching rule is essentially a call to lax.map.

This could definitely be a good place to start, even if only as the first step towards the general solution.

@froystig
Member

> This could definitely be a good place to start, even if only as the first step towards the general solution.

I tend to agree, much as we wrote linear_call as a first step towards custom transposition. Perhaps worth trying this to see what it surfaces.
