
Infer shape of advanced boolean indexing #329

Merged: 3 commits merged into pymc-devs:main on Jul 13, 2023

Conversation

ricardoV94 (Member)

This facilitates the logic in local_subtensor_rv_lift (where some bugs were also fixed)
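For reference, the shape rule the new infer_shape has to reproduce symbolically, shown here in plain NumPy (this snippet is illustrative only, not PyTensor code):

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
mask = np.array([[True, False, True],
                 [False, True, False]])  # boolean index over the first two axes

# Boolean advanced indexing flattens the masked axes into a single dimension
# whose length is the number of True entries, a value-dependent quantity.
out = x[mask]
assert out.shape == (int(mask.sum()),) + x.shape[mask.ndim:]  # (3, 4)
```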

@ricardoV94 force-pushed the fix_subtensor_rv_lift_bool_bug branch 2 times, most recently from c798cc3 to 1f22144 on June 6, 2023 16:02
@ricardoV94 force-pushed the fix_subtensor_rv_lift_bool_bug branch 2 times, most recently from 22de1e6 to d34dbbc on June 7, 2023 07:20
codecov-commenter commented Jun 7, 2023

Codecov Report

Merging #329 (6ddadf1) into main (7218431) will increase coverage by 0.02%.
The diff coverage is 95.78%.

@@            Coverage Diff             @@
##             main     #329      +/-   ##
==========================================
+ Coverage   80.44%   80.46%   +0.02%     
==========================================
  Files         156      156              
  Lines       45470    45515      +45     
  Branches    11136    11149      +13     
==========================================
+ Hits        36578    36625      +47     
- Misses       6687     6688       +1     
+ Partials     2205     2202       -3     
Impacted Files                               Coverage Δ
pytensor/tensor/rewriting/shape.py           81.01% <ø> (-0.19%) ⬇️
pytensor/tensor/random/rewriting/basic.py    93.58% <92.72%> (-0.31%) ⬇️
pytensor/tensor/extra_ops.py                 89.17% <100.00%> (+0.17%) ⬆️
pytensor/tensor/subtensor.py                 89.69% <100.00%> (+0.09%) ⬆️

... and 2 files with indirect coverage changes

@ricardoV94 force-pushed the fix_subtensor_rv_lift_bool_bug branch from d34dbbc to 43e8946 on June 7, 2023 09:52
aseyboldt (Member) previously approved these changes Jul 12, 2023 and left a comment

Needs a rebase, but otherwise this looks good to me

We have static type shapes now
@ricardoV94 force-pushed the fix_subtensor_rv_lift_bool_bug branch 2 times, most recently from d2d6d27 to 2a30ce6 on July 12, 2023 08:53
ricardoV94 (Member, Author)

Rebased and added more tests, but did not change any functionality.

@ricardoV94 force-pushed the fix_subtensor_rv_lift_bool_bug branch from 2a30ce6 to a3a07b0 on July 12, 2023 15:24
@ricardoV94 marked this pull request as draft July 12, 2023 15:25
@@ -2439,25 +2446,86 @@ def test_AdvancedSubtensor_bool(self):
n = dmatrix()
n_val = np.arange(6).reshape((2, 3))

# infer_shape is not implemented, but it should not crash
# This case triggers a runtime broadcasting between the nonzero() shapes
with pytest.raises(AssertionError, match="hello"):
ricardoV94 (Member, Author) commented Jul 12, 2023

@aseyboldt It also showed up here. I think we should allow runtime broadcasting for the shape graph.

The Op itself is happy to broadcast the indexes and won't raise, so the shape shouldn't either. The gradient issue is not a thing here because those inputs are Disconnected.

aseyboldt (Member)

I'm not sure I really understand yet, but...
We are talking about cases like this, right?

np.ones((3, 3))[np.zeros((1, 5), dtype=int), np.zeros((7, 1), dtype=int)]

which is the same as

np.ones((3, 3))[np.zeros((7, 5), dtype=int), np.zeros((7, 5), dtype=int)]

Right now, I guess, this broadcasts dynamically.
And the gradients aren't really a problem here, because the corresponding grad op is AdvancedIndexInc, which broadcasts the same way.

So I think I agree that gradient-wise this doesn't really seem to be an issue. And not allowing dynamic broadcasting here would (I think) be a breaking change.

So far so good. I'm just a bit scared that this dynamic broadcasting might make it really tricky to write a good implementation of AdvancedSubtensor et al. We never had one in the C backend I think, the Numba backend doesn't support this (because Numba itself doesn't), and the JAX backend doesn't really support any dynamic shapes either.

Maybe the reason that nobody ever wrote a good implementation of this is that this is hard to do with dynamic broadcasting?
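A quick NumPy check of the equivalence described in the comment above (plain NumPy behavior, independent of PyTensor):

```python
import numpy as np

x = np.ones((3, 3))

# Index arrays with shapes (1, 5) and (7, 1) broadcast against each other
# to a common shape (7, 5) before indexing.
a = x[np.zeros((1, 5), dtype=int), np.zeros((7, 1), dtype=int)]
b = x[np.zeros((7, 5), dtype=int), np.zeros((7, 5), dtype=int)]

assert a.shape == (7, 5)
assert np.array_equal(a, b)
```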

ricardoV94 (Member, Author) commented Jul 13, 2023

Yeah, I think the reason is that all these flavors of advanced indexing are just too complex, broadcasting or not. Theano (and we) only have a Python implementation for the general case.

We can always restrict efficient implementations to a subset of cases, like AdvancedSubtensor1 does. In this case we probably won't specialize boolean indexes combined with other advanced indexes, which require broadcasting the nonzero() results with them.

For the non-boolean case it's easier to know whether broadcasting is going to be needed, I think, because it depends on the input shapes and not their values.

Also NumPy itself is pivoting towards a more reasonable "outer indexing" by default: https://numpy.org/neps/nep-0021-advanced-indexing.html

I've seen some libraries (xarray, numba) just skip ahead to that instead of implementing all these edge cases (like the advanced index groups being pushed to the front when they are non-contiguous).

On a separate note, I checked Numba and @kc611 has implemented support for some forms of advanced indexing, so we may be able to support more than we have so far in the Numba backend: numba/numba#8616
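For context, a small NumPy illustration of the edge case mentioned above and of the outer-indexing alternative (plain NumPy semantics, not PyTensor code):

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)

# Vectorized (current NumPy) advanced indexing: when the advanced indexes are
# separated by a slice, the broadcast index dimensions are moved to the front.
print(a[[0, 1], :, [0, 1]].shape)   # (2, 3), not (3, 2)

# Outer indexing (what NEP 21 calls oindex) can be emulated with np.ix_:
rows, cols = np.ix_([0, 1], [1, 3])
print(a[0][rows, cols].shape)       # (2, 2): each index selects along its own axis
```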

ricardoV94 (Member, Author)

I reintroduced the ability to do runtime_broadcast in broadcast_shape_iter as an optional flag.
This is only used for computing the shape of advanced indexing. If we disallow dynamic broadcasting of the multidimensional indexes in the future, then we won't need to worry about it for shape inference either.

I think that decision should be made in a separate PR.
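For intuition, here is why the shape of mixed boolean/integer advanced indexing needs runtime broadcasting, shown at the value level in plain NumPy (illustrative only; this is not PyTensor's broadcast_shape_iter API):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)
mask = np.array([True, False, True])   # boolean index on axis 0
idx = np.array([[0], [1]])             # integer index on axis 1, shape (2, 1)

# The boolean mask is equivalent to its nonzero() indices, whose length (2 here)
# depends on the mask's values.  Those indices then have to broadcast with the
# other advanced index, so the output shape is the broadcast of (2,) and (2, 1).
# In a symbolic graph the nonzero() length is unknown until run time, so the
# broadcast can only be resolved then.
rows = mask.nonzero()[0]               # array([0, 2]), shape (2,)
out = x[mask, idx]
assert out.shape == np.broadcast_shapes(rows.shape, idx.shape)  # (2, 2)
assert np.array_equal(out, x[rows, idx])
```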

Also excludes the following cases:
1. expand_dims via broadcasting
2. multi-dimensional integer indexing (could lead to duplicates, which is inconsistent with the lifted RV graph)
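To make the second exclusion concrete, an illustrative NumPy sketch of the semantics (not the rewrite itself): indexing one draw with duplicate indices repeats the same value, whereas lifting the index past the RV would produce independent draws, so the two graphs would not be equivalent.

```python
import numpy as np

rng = np.random.default_rng(0)

# Indexing a single draw with duplicate indices: the two entries are identical.
draw = rng.normal(size=3)
subset = draw[np.array([0, 0])]
assert subset[0] == subset[1]

# A lifted graph would instead draw only the requested entries, making them
# independent samples, which changes the joint distribution.
lifted = rng.normal(size=2)
```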
@ricardoV94 force-pushed the fix_subtensor_rv_lift_bool_bug branch from a3a07b0 to 6ddadf1 on July 13, 2023 09:50
@ricardoV94 marked this pull request as ready for review July 13, 2023 09:54
pytensor/tensor/extra_ops.py: 3 review conversations, all resolved
@ricardoV94 merged commit 7a82a3f into pymc-devs:main on Jul 13, 2023
Labels: bug, enhancement, graph rewriting, shape problem