Infer shape of advanced boolean indexing #329
Conversation
Codecov Report
Additional details and impacted files

@@            Coverage Diff             @@
##             main     #329      +/-   ##
==========================================
+ Coverage   80.44%   80.46%   +0.02%
==========================================
  Files         156      156
  Lines       45470    45515      +45
  Branches    11136    11149      +13
==========================================
+ Hits        36578    36625      +47
- Misses       6687     6688       +1
+ Partials     2205     2202       -3
Needs a rebase, but otherwise this looks good to me
We have static type shapes now
Rebased and added more tests, but did not change any functionality.
tests/tensor/test_subtensor.py (Outdated)
@@ -2439,25 +2446,86 @@ def test_AdvancedSubtensor_bool(self):
n = dmatrix()
n_val = np.arange(6).reshape((2, 3))

# infer_shape is not implemented, but it should not crash
# This case triggers a runtime broadcasting between the nonzero() shapes
with pytest.raises(AssertionError, match="hello"):
@aseyboldt It also showed up here. I think we should allow runtime broadcasting for the shape graph.
The Op itself is happy to broadcast the indexes and won't raise, so the shape graph shouldn't raise either. The gradient issue doesn't apply here because those inputs are Disconnected.
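For concreteness, a small NumPy illustration (my own example, not from the PR) of indexes whose nonzero() results must broadcast:

import numpy as np

x = np.arange(12).reshape(4, 3)
# A boolean index is converted to integer indexes via nonzero(); when it is
# combined with another advanced index, those integer arrays broadcast
mask = np.array([True, False, True, False])  # nonzero() gives rows of shape (2,)
cols = np.array([[0], [2]])                  # shape (2, 1)
print(x[mask, cols].shape)                   # (2, 2): (2,) broadcasts with (2, 1)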
I'm not sure I really understand yet, but...
We are talking about cases like this, right?
np.ones((3, 3))[np.zeros((1, 5), dtype=int), np.zeros((7, 1), dtype=int)]
which is the same as
np.ones((3, 3))[np.zeros((7, 5), dtype=int), np.zeros((7, 5), dtype=int)]
Right now, I guess, this broadcasts dynamically.
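A quick NumPy check (my addition) that the two expressions above agree:

import numpy as np

x = np.ones((3, 3))
a = x[np.zeros((1, 5), dtype=int), np.zeros((7, 1), dtype=int)]
b = x[np.zeros((7, 5), dtype=int), np.zeros((7, 5), dtype=int)]
# The (1, 5) and (7, 1) index arrays broadcast to (7, 5) at runtime
assert a.shape == b.shape == (7, 5)
np.testing.assert_array_equal(a, b)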
And the gradients aren't really a problem here, because the corresponding grad op is AdvancedIndexInc, which broadcasts the same way.
So I think I agree that gradient-wise this doesn't really seem to be an issue. And not allowing dynamic broadcasting here would (I think) be a breaking change.
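As a NumPy analogue (my own sketch, using np.add.at as a stand-in for the increment op), the indexes broadcast the same way on the increment side:

import numpy as np

g = np.zeros((3, 3))
# np.add.at follows the same advanced-indexing rules, so the (1, 5) and
# (7, 1) indexes broadcast to (7, 5) and all 35 increments land on g[0, 0]
np.add.at(g, (np.zeros((1, 5), dtype=int), np.zeros((7, 1), dtype=int)), 1.0)
print(g[0, 0])  # 35.0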
So far so good, I'm just a bit scared that this dynamic broadcasting might make it really tricky to write a good implementation of AdvancedSubtensor et al. We never had one in the C backend I think, the numba backend doesn't support this (because numba itself doesn't), and the jax backend doesn't really support any dynamic shapes either.
Maybe the reason that nobody ever wrote a good implementation of this is that this is hard to do with dynamic broadcasting?
Yeah, I think the reason is that all these flavors of advanced indexing are just too complex, broadcasting or not. Theano (and we) only ever had a Python implementation for the general case.
We can always restrict efficient implementations to a subset of cases, like AdvancedSubtensor1 does. In this case we probably won't specialize boolean indexes combined with other advanced indexes that require broadcasting the nonzero() results with those.
For the non-boolean case it's easier to know whether broadcasting is going to be needed, I think, because it depends on the input shapes and not on their values.
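For example (my own illustration of the distinction):

import numpy as np

x = np.arange(9).reshape(3, 3)
# Integer indexing: the output shape follows from the index shapes alone
idx = np.array([[0, 1], [2, 0]])
assert x[idx].shape == (2, 2, 3)
# Boolean indexing: the output shape is (mask.sum(),), which depends on
# the mask's values and cannot be inferred from shapes alone
mask = x > 4
assert x[mask].shape == (4,)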
Also numpy itself is pivoting towards a more reasonable "outer indexing" by default: https://numpy.org/neps/nep-0021-advanced-indexing.html
I've seen some libraries (xarray, numba) just skip ahead to that instead of implementing all these edge cases (like the advanced index groups being pushed to the front when they are non-contiguous).
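A small illustration (my own, using np.ix_ to emulate outer indexing) of the difference:

import numpy as np

x = np.arange(9).reshape(3, 3)
rows = np.array([0, 2])
cols = np.array([1, 2])
# Vectorized (current default) indexing pairs the index arrays elementwise
print(x[rows, cols])          # [1 8], shape (2,)
# Outer indexing takes the cross product instead, which NEP 21 argues is
# the more intuitive default; np.ix_ expresses it today
print(x[np.ix_(rows, cols)])  # [[1 2], [7 8]], shape (2, 2)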
On a separate note, I checked numba and @kc611 has implemented support for some forms of advanced indexing, so we may be able to support more in Numba than we have so far: numba/numba#8616
I reintroduced the ability to do runtime_broadcast in broadcast_shape_iter as an optional flag.
This is only used for computing the shape of AdvancedIndexing. If we disallow dynamic broadcasting of the multidimensional indexes in the future, then we won't need to worry about it for shape inference either. I think that decision should be made in a separate PR.
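A rough sketch (hypothetical, not the actual broadcast_shape_iter code) of the pairwise rule such a flag relaxes, with plain ints standing in for symbolic dimensions:

def broadcast_dim(d1, d2, runtime_broadcast=False):
    # Standard static broadcasting: a 1 yields to the other dimension,
    # and equal dimensions pass through
    if d1 == 1:
        return d2
    if d2 == 1:
        return d1
    if d1 == d2:
        return d1
    if runtime_broadcast:
        # Defer the check: in the symbolic shape graph this becomes
        # maximum(d1, d2), which is correct whenever one side turns out
        # to be 1 at runtime
        return max(d1, d2)
    raise ValueError(f"Cannot broadcast dimensions {d1} and {d2}")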
Also excludes the following cases: 1. expand_dims via broadcasting 2. multi-dimensional integer indexing (could lead to duplicates, which is inconsistent with the lifted RV graph)
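To illustrate the duplicates concern (my own NumPy sketch, not code from the PR): indexing a draw with duplicate indexes repeats values, while indexing the parameters first yields independent draws, so the lifted graph would not be equivalent:

import numpy as np

rng = np.random.default_rng(0)
mu = np.zeros(3)
idx = np.array([0, 0])   # duplicate integer index

x = rng.normal(mu)[idx]  # index the draw: the value is repeated
y = rng.normal(mu[idx])  # lifted form: two independent draws
assert x[0] == x[1]
assert y[0] != y[1]      # almost surely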
This facilitates the logic in local_subtensor_rv_lift (where some bugs were also fixed).