lax.map #1113

Closed
shoyer opened this issue Aug 5, 2019 · 6 comments · Fixed by #1118

Comments

@shoyer
Collaborator

shoyer commented Aug 5, 2019

Sometimes it makes more sense to apply a function repeatedly the old-fashioned way, in a loop, rather than the auto-batching vmap way. Use cases that come to mind:

  • Reduced memory usage.
  • Heterogeneous computation, if using primitives like cond and while_loop.

This can be done with lax.scan, but it's a little awkward. I propose adding a lax.map primitive that works like Python's built-in map. The implementation could look something like:

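```python
from jax import lax

def map(f, xs):
  # Thread a trivial () carry through the scan; keep only the per-step outputs.
  g = lambda _, x: ((), f(x))
  _, ys = lax.scan(g, (), xs)
  return ys
```
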
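
A minimal sketch of the proposal in use, assuming a recent JAX release. The helper is renamed to the hypothetical `scan_map` so it doesn't shadow Python's built-in `map`, and `f` is an illustrative per-row function with a data-dependent `lax.cond` branch (the "heterogeneous computation" case above). It compares the scan-based loop against `jax.lax.map` (the function #1118 added) and against `jax.vmap`, which should all agree on the values.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_map(f, xs):
    # Hypothetical name for the proposal: thread a trivial () carry through
    # lax.scan and keep only the stacked per-step outputs.
    g = lambda _, x: ((), f(x))
    _, ys = lax.scan(g, (), xs)
    return ys


def f(x):
    # Illustrative per-row function with a data-dependent branch:
    # double rows whose sum is positive, negate the others.
    return lax.cond(jnp.sum(x) > 0, lambda v: v * 2.0, lambda v: -v, x)


xs = jnp.array([[1.0, 2.0], [-3.0, 1.0], [0.5, -4.0]])

ys_scan = scan_map(f, xs)  # sequential loop, one row per iteration
ys_lax = lax.map(f, xs)    # jax.lax.map
ys_vmap = jax.vmap(f)(xs)  # auto-batched; a batched cond becomes a select
```

The loop versions run `f` one row at a time, so only the branch each row needs is executed, whereas under `vmap` the batched predicate turns `cond` into a `select` that evaluates both branches for every row.
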
@mattjj
Collaborator

mattjj commented Aug 5, 2019

Great idea. Thinking about it only for a moment, the only thing this implementation might miss out on compared to writing our own lax.map primitive is that XLA might not notice that the loop is parallelizable (since we'll build a computation that threads through a trivial carry of () and ignores it, so XLA would need to notice that we're ignoring the carry). But maybe XLA is smart enough to prune that away, and in any case we can always improve the implementation later.

Does that sound right?

Want to make a PR adding this to lax_control_flow.py?
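
To see the trivial carry described above, one can trace `lax.map` to a jaxpr and look at the `scan` equation it produces. This is a sketch assuming a recent JAX release, where `lax.map` (called without extra options) still lowers to `lax.scan` with an empty carry; `num_carry` and `length` are internal parameters of the `scan` primitive and may change between versions. It only shows what JAX hands to XLA, not whether XLA then parallelizes the loop.

```python
import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.ones((4, 3))

# Trace lax.map without running it and inspect the resulting jaxpr.
closed_jaxpr = jax.make_jaxpr(lambda xs: lax.map(jnp.sin, xs))(xs)

# Collect the top-level scan equations and read their parameters.
scan_eqns = [e for e in closed_jaxpr.jaxpr.eqns if e.primitive.name == "scan"]

print(closed_jaxpr)
```

The printed `scan` equation should show `num_carry=0`: the loop threads no state between iterations, which is exactly the property XLA would have to recognize to run the iterations in parallel.
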

@shoyer
Collaborator Author

shoyer commented Aug 5, 2019

> Thinking about it only for a moment, the only thing this implementation might miss out on compared to writing our own lax.map primitive is that XLA might not notice that the loop is parallelizable (since we'll build a computation that threads through a trivial carry of () and ignores it, so XLA would need to notice that we're ignoring the carry). But maybe XLA is smart enough to prune that away, and in any case we can always improve the implementation later.

Yes, totally agreed!

I will put together a PR.

shoyer added a commit to shoyer/jax that referenced this issue Aug 5, 2019
@shoyer
Collaborator Author

shoyer commented Aug 5, 2019

See #1118

@alelovato

Hi all,

Probably a silly question, but is there a way in lax.scan to specify which input array axes to map over? My function takes two arguments as inputs, and I would like to vectorize the second argument's leading dimension only. With vmap, I would do something like this:

```python
from jax import vmap
ke = vmap(kinetic_energy, in_axes=(None, 0), out_axes=0)(params, inputs)
```

Thanks for helping
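
`lax.map` has no `in_axes`: it always maps over the leading axis of every array in `xs`. A common workaround, sketched here as a suggestion rather than an answer given in this thread, is to close over the unmapped argument so that only the mapped one is passed as `xs`. `kinetic_energy`, `params`, and `inputs` below are hypothetical stand-ins for the poster's code (a toy `0.5 * m * |v|^2`), assuming a recent JAX release.

```python
import jax
import jax.numpy as jnp
from jax import lax


def kinetic_energy(params, x):
    # Hypothetical stand-in: 0.5 * mass * |x|^2 for a single input row.
    return 0.5 * params["mass"] * jnp.sum(x ** 2)


params = {"mass": 2.0}
inputs = jnp.array([[1.0, 0.0], [1.0, 1.0], [0.0, 3.0]])

# vmap version from the question: params unmapped, inputs mapped on axis 0.
ke_vmap = jax.vmap(kinetic_energy, in_axes=(None, 0), out_axes=0)(params, inputs)

# lax.map version: params is captured by the closure, so it acts like in_axes=None.
ke_map = lax.map(lambda x: kinetic_energy(params, x), inputs)
```

Both should give one energy per row of `inputs`. If several arguments need to be mapped together, they can be passed to `lax.map` as a tuple `xs`, since it maps over the leading axis of every leaf.
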

@mattjj
Collaborator

mattjj commented Dec 19, 2020

A great question! (By the way, it can be hard to notice questions on already-closed issues.)

There's an open feature request for this in #2509, and even a PR in #4591 that we haven't had a chance to review yet. If that sounds like the feature you want, maybe try out the PR branch, and even see if you can help contribute tests.

@alelovato

Thanks for the prompt reply!

Sorry, I am not used to discussions on GitHub yet. I will definitely give PR #4591 a try.


3 participants