lax.map #1113
Comments
Great idea. Thinking about it only for a moment, the only thing this implementation might miss out on compared to writing our own … Does that sound right? Want to make a PR adding this to `lax_control_flow.py`?
Yes, totally agreed! I will put together a PR.
See #1118
Hi all, probably a silly question, but is there a way in `lax.scan` to specify which input array axes to map over? My function takes two arguments as inputs, and I would like to vectorize over the second argument's leading dimension only. With `vmap`, I would do something like this:
Thanks for helping!
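A minimal sketch of the usual workaround, assuming a hypothetical two-argument function `f(a, b)` in which only `b` is mapped: `jax.vmap` expresses this with `in_axes=(None, 0)`, while `lax.scan` (and `lax.map`, which maps a single argument over its leading axis) can get the same effect by closing over the unmapped argument.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(a, b):
    # Hypothetical example function: `a` is shared, `b` is one slice of the batch.
    return jnp.dot(a, b) + jnp.sum(a)


a = jnp.arange(3.0)                   # shared argument, not mapped
bs = jnp.arange(12.0).reshape(4, 3)   # map over the leading axis of `bs`

# vmap: in_axes=(None, 0) broadcasts `a` and maps over axis 0 of `bs`.
out_vmap = jax.vmap(f, in_axes=(None, 0))(a, bs)

# lax.map maps a single argument, so close over `a` instead of mapping it.
out_map = lax.map(lambda b: f(a, b), bs)


# The same loop written directly with lax.scan; the carry is unused.
def body(carry, b):
    return carry, f(a, b)


_, out_scan = lax.scan(body, None, bs)
```

To map over an axis other than the leading one, one option is to move it to the front first, e.g. with `jnp.moveaxis(bs, axis, 0)`.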
A great question! (By the way, it can be hard to notice questions on already-closed issues.) There's an open feature request for this in #2509, and even a PR in #4591 that we haven't had a chance to review yet. If that sounds like the feature you want, maybe try out the PR branch, and even see if you can help contribute tests.
Thanks for the prompt reply! Sorry, I am not used to discussions on GitHub yet. I will definitely give the PR in #4591 a try.
Sometimes it makes more sense to repeatedly apply a function the old-fashioned way, in a loop, rather than the auto-batching `vmap` way. Use cases that come to mind: … `cond` and `while_loop`. This can be done with `lax.scan`, but it's a little awkward. I propose adding a `lax.map` primitive that works like Python's builtin `map`. The implementation could look something like: