Add axis parameter to lax.scan #4591
This looks reasonable to me, but it needs some tests (probably in |
Thanks for exploring this! I think thorough tests are the top priority here. There are some things, like the scan batching rule, that assume the scanned-over axis is always the leading one, and so e.g. it's safe to place the batch axis as the second axis. I think some code with those assumptions will need to be revised. The only way to ferret out all such assumptions is to have really thorough tests. Luckily I don't think there are any fundamental blockers to this generalization! |
Hi @mattjj, could you show me some short snippets of user code that should work? That way I can extend them into unit tests and check that everything is working; JAX is a large project and I am having a hard time thinking of tests. I am eager to get this properly tested and merged now that I can focus on it! |
Is there anything blocking here? Would be great to get this one in. Happy to take it on myself. Cheers. |
Thanks for the pings. Sorry for the 6+ month delay time! The blocker is adding thorough tests. I think the ideal tests would cover cases like scanning over different axis indices (maybe making the current test a parametrized test?), autodiff, and vmapping. I'm especially concerned about vmapping, because I suspect the Does that make sense? |
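As a rough sketch of what such tests could look like: the snippet below checks scanning over each axis, then autodiff, then vmapping with the batch axis in every position. Note that `scan_along_axis` is a hypothetical stand-in written with `jnp.moveaxis`, not the API added by this PR; a real test would call `lax.scan` with the new `axis` argument instead, and `cumsum_step` and `check_axis` are likewise illustrative names.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def scan_along_axis(f, init, xs, axis=0):
    """Hypothetical stand-in for the proposed `axis` argument (not a JAX API)."""
    xs_leading = jnp.moveaxis(xs, axis, 0)
    carry, ys = lax.scan(f, init, xs_leading)
    # Put the stacked outputs back on the scanned axis.
    return carry, jnp.moveaxis(ys, 0, axis)


def cumsum_step(carry, x):
    carry = carry + x
    return carry, carry


def check_axis(axis):
    xs = jnp.arange(24.0).reshape(2, 3, 4)
    init = jnp.zeros(xs.shape[:axis] + xs.shape[axis + 1:])
    xs_np = np.asarray(xs)

    # 1. Plain scan over the chosen axis matches a NumPy reference.
    carry, ys = scan_along_axis(cumsum_step, init, xs, axis=axis)
    np.testing.assert_allclose(ys, np.cumsum(xs_np, axis=axis))
    np.testing.assert_allclose(carry, np.sum(xs_np, axis=axis))

    # 2. Autodiff: the gradient of sum(final carry) w.r.t. xs is all ones.
    def loss(x):
        return scan_along_axis(cumsum_step, init, x, axis=axis)[0].sum()

    np.testing.assert_allclose(jax.grad(loss)(xs), np.ones(xs.shape))

    # 3. vmap: try every position for the batch axis of the input.
    def scan_ys(x):
        return scan_along_axis(cumsum_step, init, x, axis=axis)[1]

    batched = jnp.stack([xs, 2.0 * xs])  # batch on a new leading axis
    expected = np.cumsum(np.asarray(batched), axis=axis + 1)
    for batch_axis in range(batched.ndim):
        moved = jnp.moveaxis(batched, 0, batch_axis)
        out = jax.vmap(scan_ys, in_axes=batch_axis)(moved)
        np.testing.assert_allclose(out, expected)


for axis in range(3):
    check_axis(axis)
```

The vmap loop is the part most likely to expose code that assumes the scanned axis is the leading one, since it places the batch axis both before and after the scanned axis.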
Oh, I just realized I completely missed your message. That makes sense! I'll see if I can take a look at this next week. Thanks! |
@hamzamerzic As it has been a while since I opened this PR, I cleaned up the commits a bit and parametrized the current tests as Matthew suggested. Everything currently works for the normal implementation, which is what I needed, but when I tried to parametrize the tests that deal with batching such as You are more than welcome to continue this PR or throw it all away and start from scratch, whichever suits you better 😄 |
I have tried to make the scan function a bit more convenient by letting the caller specify the axis to iterate over. This should remove the need to transpose the input before the scan, without a performance loss.
Hoping it can partially fix #2509. I am still not getting the expected performance on #2491, but I think it should be faster than working with transposes. (It is not really clear why it isn't in my tests.)
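For context, here is a minimal sketch of the transpose pattern this change is meant to avoid. The shapes and the `step` function are made up for illustration, and the `axis=1` call in the final comment is only how I would expect the proposed argument to be used, not an existing signature.

```python
import jax.numpy as jnp
from jax import lax


def step(h, x_t):
    # Toy recurrent update, purely illustrative.
    h = jnp.tanh(h + x_t)
    return h, h


xs = jnp.ones((8, 5, 3))  # (batch, time, features): time is axis 1
h0 = jnp.zeros((8, 3))

# Current workaround: move time to the front, scan, then move it back.
xs_time_major = jnp.swapaxes(xs, 0, 1)            # (time, batch, features)
h_final, hs_time_major = lax.scan(step, h0, xs_time_major)
hs = jnp.swapaxes(hs_time_major, 0, 1)            # (batch, time, features)

# With the proposed argument, the call would presumably become:
#   h_final, hs = lax.scan(step, h0, xs, axis=1)  # assumed signature
```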
The basic idea behind it is to apply `index_array`/`dynamic_index_array` along the appropriate axis instead of axis=0. The current implementation is:
So there does not seem to be any overhead from using axis > 0, and the results I have checked are correct. There should still be unit tests to cover it, though.
Please let me know if this is something you are interested in merging and if you have any ideas to improve it.
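To illustrate the indexing idea mentioned above without transposing anything, here is a small sketch that slices the input along an arbitrary axis at every step using the public `lax.dynamic_index_in_dim`. It is not the code from this PR: `scan_by_indexing` is a hypothetical helper, and unlike the proposed change it stacks the outputs along the leading axis.

```python
import jax.numpy as jnp
from jax import lax


def scan_by_indexing(f, init, xs, axis=0):
    """Illustrative helper (hypothetical name): slice `xs` along `axis` per step."""

    def body(carry, i):
        # Take the i-th slice along `axis` and drop that dimension.
        x_i = lax.dynamic_index_in_dim(xs, i, axis=axis, keepdims=False)
        return f(carry, x_i)

    # Scan over the step indices instead of over the data itself.
    carry, ys = lax.scan(body, init, jnp.arange(xs.shape[axis]))
    return carry, ys  # ys is stacked along the leading axis


def cumsum_step(carry, x):
    carry = carry + x
    return carry, carry


xs = jnp.arange(6.0).reshape(2, 3)
total, partial_sums = scan_by_indexing(cumsum_step, jnp.zeros(2), xs, axis=1)
```

Here `total` holds the row sums of `xs`, and `partial_sums` has shape `(3, 2)`, one row per step along axis 1.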