Add axis parameter to lax.scan #4591

Closed

alonfnt
Contributor

@alonfnt alonfnt commented Oct 15, 2020

I have tried to make the scan function a bit more convenient by adding a parameter that specifies the axis to iterate over. It should remove the need to transpose the input before scanning, without a performance loss.

I hope this can partially fix #2509. I am still not getting the expected performance on #2491, but I think it should be faster than working with transposes. (It is not clear to me why it isn't in my tests.)

The basic idea behind it is to get the correct index_array/dynamic_index_array for the appropriate axis instead of axis=0.
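A minimal sketch of that idea, written against public JAX functions rather than the PR's internals. The helper name `scan_over_axis` is hypothetical, it handles only a single array input, and it is not the code in this PR; it only illustrates slicing along a chosen axis at each step instead of along axis 0.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_over_axis(f, init, xs, axis=0):
    """Hypothetical helper: scan `f` over `axis` of a single array `xs`."""
    length = xs.shape[axis]
    # Evaluate one step abstractly to learn the shape/dtype of the per-step output.
    x0 = lax.index_in_dim(xs, 0, axis=axis, keepdims=False)
    _, y_struct = jax.eval_shape(f, init, x0)
    ys = jnp.zeros((length,) + y_struct.shape, y_struct.dtype)

    def body(i, state):
        carry, ys = state
        # Take the i-th slice along `axis` rather than along axis 0.
        x = lax.dynamic_index_in_dim(xs, i, axis=axis, keepdims=False)
        carry, y = f(carry, x)
        return carry, ys.at[i].set(y)

    return lax.fori_loop(0, length, body, (init, ys))


f = lambda carry, x: (carry + x.max(), x.max())
carry, ys = scan_over_axis(f, jnp.zeros(()), jnp.ones((5, 10)), axis=1)
# carry is 10.0 and ys has shape (10,)
```

No transpose of `xs` is materialized here; each step reads one slice along the requested axis.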

The current implementation behaves as follows:

In [1]: import jax

In [2]: xs = jax.numpy.full((100, 100, 100), 42)

In [3]: f = lambda carry, x: (carry + x.max(), x.max())

In [4]: %timeit jax.lax.scan(f, 0.0, xs)
86.3 ms ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [5]: %timeit jax.lax.scan(f, 0.0, xs, axis=1)
86.2 ms ± 388 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [6]: %timeit jax.lax.scan(f, 0.0, xs, axis=2)
88.3 ms ± 753 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

So there does not seem to be any overhead in using axis > 0, and the results I have checked are correct. There should be some unit tests to cover it, though.

In [14]: xs2 = jax.numpy.ones((5,10))

In [15]: jax.lax.scan(f, 0.0, xs2, axis=0)
Out[15]:
(DeviceArray(5., dtype=float32),
 DeviceArray([1., 1., 1., 1., 1.], dtype=float32))

In [16]: jax.lax.scan(f, 0.0, xs2, axis=1)
Out[16]:
(DeviceArray(10., dtype=float32),
 DeviceArray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32))
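For comparison, the transpose-based workaround that the `axis` argument is meant to replace can be sketched on stock JAX as below. The wrapper name `scan_with_moveaxis` is an assumption made for illustration; it is not part of JAX or of this PR.

```python
import jax.numpy as jnp
from jax import lax


def scan_with_moveaxis(f, init, xs, axis=0):
    # Hypothetical wrapper: bring `axis` to the front, then scan over axis 0 as usual.
    return lax.scan(f, init, jnp.moveaxis(xs, axis, 0))


f = lambda carry, x: (carry + x.max(), x.max())
xs2 = jnp.ones((5, 10))

carry, ys = scan_with_moveaxis(f, 0.0, xs2, axis=1)
# carry is 10.0 and ys has shape (10,), matching Out[16] above
```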

Please let me know whether this is something you are interested in merging and whether you have any ideas for improving it.

@google-cla google-cla bot added the cla: yes label Oct 15, 2020
@hawkinsp
Collaborator

hawkinsp commented Oct 15, 2020

This looks reasonable to me, but it needs some tests (probably in lax_control_flow_test.py).

@mattjj
Collaborator

mattjj commented Oct 15, 2020

Thanks for exploring this!

I think thorough tests are the top priority here. There are some things, like the scan batching rule, that assume the scanned-over axis is always the leading one, and so e.g. it's safe to place the batch axis as the second axis. I think some code with those assumptions will need to be revised. The only way to ferret out all such assumptions is to have really thorough tests.
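The leading-axis assumption can be illustrated with stock JAX. This is a sketch of the idea only, not the actual batching rule code: because the scan always runs over axis 0, a batched scan can be written by hand by moving the batch axis to position 1 and vmapping the step function. The name `batched_by_hand` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

f = lambda carry, x: (carry + x.max(), x.max())


def scan_leading(xs):
    # Stock lax.scan: iterates over axis 0 of its input.
    return lax.scan(f, jnp.zeros(()), xs)


def batched_by_hand(xs, batch_axis):
    # Put the batch axis second, giving layout (time, batch, ...), then scan
    # over axis 0 with a step function that is itself vmapped over the batch.
    xs_tb = jnp.moveaxis(xs, batch_axis, 1)
    init = jnp.zeros(xs_tb.shape[1])
    carry, ys = lax.scan(jax.vmap(f), init, xs_tb)
    return carry, ys.T  # vmap places the batch axis first in its outputs


xs = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
for batch_axis in range(xs.ndim):
    c_vmap, ys_vmap = jax.vmap(scan_leading, in_axes=batch_axis)(xs)
    c_hand, ys_hand = batched_by_hand(xs, batch_axis)
    assert jnp.allclose(c_vmap, c_hand) and jnp.allclose(ys_vmap, ys_hand)
```

If the scanned axis could be anywhere, position 1 would no longer be guaranteed free for the batch axis, which is the kind of assumption that would need revisiting.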

Luckily I don't think there are any fundamental blockers to this generalization!

@alonfnt alonfnt changed the title Added axis parameter to lax.scan Add axis parameter to lax.scan Oct 30, 2020
@alonfnt
Contributor Author

alonfnt commented Nov 2, 2020

Hi @mattjj, could you show me some short snippets of user code that should work? That way I can extend them into unit tests and check that everything is working. JAX is a large project and I am having a hard time thinking of tests.

I am eager to get this properly tested and merged now that I can focus on it!

@mattjj mattjj mentioned this pull request Dec 19, 2020
@hamzamerzic
Contributor

Just checking in on the latest status of this PR, @mattjj. I have an open TODO to include this in Haiku (link). Thanks!

@hamzamerzic
Contributor

Is there anything blocking here? Would be great to get this one in. Happy to take it on myself. Cheers.

@mattjj
Collaborator

mattjj commented Jun 17, 2021

Thanks for the pings. Sorry for the 6+ month delay!

The blocker is adding thorough tests. I think the ideal tests would cover cases like scanning over different axis indices (maybe making the current test a parametrized test?), autodiff, and vmapping. I'm especially concerned about vmapping, because I suspect the _scan_batching_rule function may still assume that the scanned-over axis is the leading axis (though that's just a suspicion!). The file lax_vmap_test.py has some examples for how to write thorough vmap tests, which cover lots of combinations of axes for different arguments.
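A rough sketch of what such checks could look like, under stated assumptions: `scan_axis` is a hypothetical stand-in for the proposed `lax.scan(..., axis=axis)` (built on `jnp.moveaxis` so that it runs on stock JAX), and `reference_scan` is a plain Python loop used as ground truth. The real tests would call the new `axis` argument directly and follow the conventions of the existing test files.

```python
import itertools

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax


def scan_axis(f, init, xs, axis):
    # Hypothetical stand-in for the proposed lax.scan(f, init, xs, axis=axis).
    return lax.scan(f, init, jnp.moveaxis(xs, axis, 0))


def reference_scan(f, init, xs, axis):
    # Plain Python loop over slices along `axis`, used as ground truth.
    carry, ys = init, []
    for i in range(xs.shape[axis]):
        carry, y = f(carry, jnp.take(xs, i, axis=axis))
        ys.append(y)
    return carry, jnp.stack(ys)


def step(carry, x):
    new_carry = carry + jnp.sum(jnp.sin(x))
    return new_carry, new_carry * 2.0


def loss(scan_impl, xs, axis):
    carry, ys = scan_impl(step, jnp.zeros(()), xs, axis)
    return carry + jnp.sum(ys)


def check_forward_and_grad(axis, xs):
    got = scan_axis(step, jnp.zeros(()), xs, axis)
    want = reference_scan(step, jnp.zeros(()), xs, axis)
    for g, w in zip(got, want):
        np.testing.assert_allclose(g, w, rtol=1e-4, atol=1e-4)
    # Reverse-mode gradients with respect to the scanned-over input.
    g_got = jax.grad(lambda x: loss(scan_axis, x, axis))(xs)
    g_want = jax.grad(lambda x: loss(reference_scan, x, axis))(xs)
    np.testing.assert_allclose(g_got, g_want, rtol=1e-4, atol=1e-4)
    return True


def check_vmap(batch_axis, axis, xs_b):
    run = lambda x: scan_axis(step, jnp.zeros(()), x, axis)
    got = jax.vmap(run, in_axes=batch_axis)(xs_b)
    # Ground truth: run each batch element separately and stack the results.
    per_example = [run(jnp.take(xs_b, b, axis=batch_axis))
                   for b in range(xs_b.shape[batch_axis])]
    want = (jnp.stack([c for c, _ in per_example]),
            jnp.stack([y for _, y in per_example]))
    for g, w in zip(got, want):
        np.testing.assert_allclose(g, w, rtol=1e-4, atol=1e-4)
    return True


rng = np.random.RandomState(0)
xs = jnp.asarray(rng.randn(3, 4, 5), dtype=jnp.float32)
xs_b = jnp.asarray(rng.randn(2, 3, 4, 5), dtype=jnp.float32)

for axis in range(xs.ndim):
    check_forward_and_grad(axis, xs)
# Every combination of batch axis (of the 4D input) and scan axis (of each 3D example).
for batch_axis, axis in itertools.product(range(4), range(3)):
    check_vmap(batch_axis, axis, xs_b)
```

The loops cover each scan axis for values and gradients, then every pairing of batch axis and scan axis under vmap, which is where a leading-axis assumption would be expected to surface.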

Does that make sense?

@hamzamerzic
Contributor

Oh, I just realized I completely missed your message. That makes sense! I'll see if I can take a look at this next week. Thanks!

@alonfnt
Contributor Author

alonfnt commented Aug 7, 2021

@hamzamerzic As it has been a while since I opened this PR, I cleaned up the commits a bit and parametrized the current tests as Matthew suggested. Everything currently works for the normal implementation, which is what I needed. However, when I tried to parametrize the tests that deal with batching, such as testScanLinearize and testScanGrad, I got failures that none of my attempts were able to fix. The reason seems to be, as Matthew correctly guessed, the _scan_batching_rule function and its assumption of a leading axis.

You are more than welcome to continue this PR or throw it all away and start from scratch, whichever suits you better 😄

@alonfnt alonfnt closed this Aug 13, 2022
@google-cla google-cla bot added cla: no and removed cla: yes labels Aug 13, 2022
Successfully merging this pull request may close these issues.

Expand scan to take in_axes and out_axes parameters
4 participants