-
-
Notifications
You must be signed in to change notification settings - Fork 153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Subtensor
static shape inference in Subtensor.make_node
#935
base: main
Are you sure you want to change the base?
Adding Subtensor
static shape inference in Subtensor.make_node
#935
Conversation
237bda3
to
96f87f2
Compare
1848002
to
14a40b9
Compare
N.B.: Local testing yields |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error in CI says that the rewrite local_subtensor_make_vector
introduced a type error when it tried to replace a variable with another variable that has an incompatible static shape.
Inspection via the debugger shows that the node being rewritten has the following form:
Subtensor{int64::} [id BH] <TensorType(int64, (5,))> '' 4
| MakeVector{dtype='int64'} [id BD] <TensorType(int64, (3,))> '' 3 |
| ScalarConstant{-2} [id P] |
A shape of (5,)
for the output of a Subtensor
Op
like x[-2::]
, where x
is a vector of length 3, doesn't make much sense.
Let's see if we can reproduce the error by creating an equivalent graph.
import aesara
import aesara.tensor as at
# A vector of length three built from a `MakeVector` `Op`
x = at.as_tensor([at.lscalar(), at.lscalar(), at.lscalar()])
z = x[-2:]
aesara.dprint(z, print_type=True, depth=2)
# Subtensor{int64::} [id A] <TensorType(int64, (5,))> ''
# |MakeVector{dtype='int64'} [id B] <TensorType(int64, (3,))> ''
# |ScalarConstant{-2} [id C] <int64>
We've successfully reproduced the problematic graph, so we can now isolate the issue in Subtensor.make_node
.
Also, we now have a good unit test to add to tests.tensor.test_subtensor
. Since all the current tests in tests.tensor.test_subtensor
seem to pass when this bug is present, it's clear that those tests are lacking coverage for some simple and important cases. One of which is covered by the simple example above.
FYI: Simple test additions like these can be just as valuable to Aesara as the feature work from which they are derived.
14a40b9
to
1b91b97
Compare
Thanks for the detailed input. I see where the error is and will try to fix it. Maybe I can reuse Right now I am spending time addressing issues with testing on my Mac. I was running into graph compilation issues in |
1b91b97
to
5f39a37
Compare
Closes #922.
As per @brandonwillard's comment, I renamed
broadcastable
tostatic_shape
, incorporating the other suggested minor edits.I took inspiration from lines 773 to 777 from
Subtensor.infer_shape
for shape inference.No new tests have been added yet as I am creating this PR.