
The shape of the RandomVariables is hard to reason about #1252

Open
rlouf opened this issue Oct 12, 2022 · 5 comments
Labels: bug (Something isn't working), important, random variables (Involves random variables and/or sampling), shape inference

Comments

@rlouf
Member

rlouf commented Oct 12, 2022

RandomVariables initialized with different values of `size` have a constant shape when the size is broadcastable (every entry equals 1, including the empty `size=()` case), and a symbolic shape otherwise:

```python
import aesara.tensor as at

srng = at.random.RandomStream(0)

a_rv = srng.normal(0, 1)
# breakpoint inside `make_node`, then print(size, shape, bcast):
# TensorConstant{[]} TensorConstant{[]} ()
print(a_rv.shape)
# TensorConstant{[]}

a_rv = srng.normal(0, 1, size=(0,))
# breakpoint inside `make_node`, then print(size, shape, bcast):
# TensorConstant{(1,) of 0} TensorConstant{(1,) of 0} (False,)
print(a_rv.shape)
# Shape.0

a_rv = srng.normal(0, 1, size=(1,))
# breakpoint inside `make_node`, then print(size, shape, bcast):
# TensorConstant{(1,) of 1} TensorConstant{(1,) of 1} (True,)
print(a_rv.shape)
# TensorConstant{(1,) of 1}

a_rv = srng.normal(0, 1, size=(2,))
# breakpoint inside `make_node`, then print(size, shape, bcast):
# TensorConstant{(1,) of 2} TensorConstant{(1,) of 2} (False,)
print(a_rv.shape)
# Shape.0
```

This is problematic for some downstream applications. In AeHMC we create a scalar or vector `inverse_mass_matrix` parameter based on the number of dimensions of the variable, so that parameter does not have a definite static shape. When `a_rv.shape` is a `TensorConstant`, the shape checks performed in `Scan` raise an error.

I would expect all of the above shapes to be `TensorConstant`s, or at the very least all to be symbolic. The shapes are determined when computing the output type with `TensorType(dtype=dtype, shape=bcast)`.
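The split between constant and symbolic shapes above can be modelled in a few lines of plain Python. This is a sketch, not Aesara code: it assumes that each boolean in `bcast` becomes a static dimension of `1` (`True`) or an unknown dimension `None` (`False`) when passed as `shape` to `TensorType`, and that `a_rv.shape` can only be a constant when every static dimension is known. The helper names are hypothetical.

```python
from typing import Optional, Tuple

StaticShape = Tuple[Optional[int], ...]


def bcast_to_static_shape(bcast: Tuple[bool, ...]) -> StaticShape:
    """Hypothetical helper (assumed conversion): True -> 1, False -> None."""
    return tuple(1 if b else None for b in bcast)


def shape_is_constant(static_shape: StaticShape) -> bool:
    """A shape can only be a constant when every static dim is known."""
    return all(dim is not None for dim in static_shape)


results = {}
for size in [(), (0,), (1,), (2,)]:
    # Only "is this dimension equal to 1?" survives the broadcastable encoding.
    bcast = tuple(dim == 1 for dim in size)
    static = bcast_to_static_shape(bcast)
    results[size] = (static, shape_is_constant(static))
    print(size, bcast, static, results[size][1])
```

Only `size=()` and `size=(1,)` end up with a fully known static shape, which matches the `TensorConstant` vs `Shape.0` split in the examples above: the sizes `0` and `2` are known when the node is built, but the boolean encoding throws them away.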

I have marked this as a question since I might also be initializing the inverse_mass_matrix incorrectly (as at.scalar or at.vector without specifying the shape).

@rlouf added the question and random variables labels Oct 12, 2022
@ricardoV94
Contributor

ricardoV94 commented Oct 12, 2022

We can output better static shape info beyond broadcastable/not, but otherwise I don't see a problem.

You should be able to use `at.specify_shape(input, rv.shape)` if you are having issues with replacing RVs by values or something of that sort, or preemptively constant-fold shapes if you need them.

Also you might want to use var.type.shape instead of var.shape depending on what you are trying to achieve.
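To make the `var.type.shape` vs `var.shape` distinction concrete, here is a plain-Python model of static shapes, assuming (as with Aesara's `TensorType.shape`) that a static shape is a tuple of ints with `None` for unknown dimensions. `refine_static_shape` loosely mimics what `at.specify_shape` does to the static type of its output. Both helper names are hypothetical, and this is not how Aesara or `Scan` actually implement their checks.

```python
from typing import Optional, Tuple

StaticShape = Tuple[Optional[int], ...]


def static_shapes_compatible(a: StaticShape, b: StaticShape) -> bool:
    """Same rank, and every pair of known dims agrees (None matches anything)."""
    if len(a) != len(b):
        return False
    return all(x is None or y is None or x == y for x, y in zip(a, b))


def refine_static_shape(static: StaticShape, spec: StaticShape) -> StaticShape:
    """Fill unknown dims from `spec`; raise if a known dim conflicts."""
    if not static_shapes_compatible(static, spec):
        raise ValueError(f"cannot specify shape {spec} for static shape {static}")
    return tuple(x if x is not None else y for x, y in zip(static, spec))


# Assumed static shape of a vector created without a shape, e.g. `at.vector()`.
value_shape: StaticShape = (None,)
# Static shape of a random variable whose size is known to be (2,).
rv_shape: StaticShape = (2,)

print(static_shapes_compatible(value_shape, rv_shape))
print(refine_static_shape(value_shape, rv_shape))
```

Refining the value's static shape from the RV's shape, rather than comparing symbolic `var.shape` graphs, is the kind of step the `specify_shape` suggestion is after.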

@rlouf added the bug label and removed the question label Oct 12, 2022
@rlouf
Member Author

rlouf commented Oct 12, 2022

Turning this into a bug after an offline discussion with @brandonwillard: `a_rv.shape` should return a `TensorConstant` in every example given in the original comment.

@brandonwillard
Member

> Turning this into a bug after an offline discussion with @brandonwillard: `a_rv.shape` should return a `TensorConstant` in every example given in the original comment.

More specifically, it looks like aesara.tensor.basic.infer_broadcastable should be returning any and all constant shapes produced by the relatively costly rewrite_graph call, instead of discarding that information via the deprecated broadcastable encoding.
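The difference between discarding and keeping the folded constants can be sketched in plain Python. `Const` and `Sym` below are hypothetical stand-ins for Aesara constants (which expose their value through a `data` attribute) and for shape entries that could not be folded. `broadcastable_from_folded` reuses the `bcast` expression quoted in this thread, while `static_shape_from_folded` illustrates the kind of output suggested here; it is not the actual fix.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class Const:
    """Hypothetical stand-in for a constant-folded shape entry."""

    data: int


class Sym:
    """Hypothetical stand-in for a shape entry that could not be folded."""


def broadcastable_from_folded(folded_shape) -> Tuple[bool, ...]:
    # The broadcastable encoding: only "== 1" survives; a Sym compares unequal.
    return tuple(getattr(s, "data", s) == 1 for s in folded_shape)


def static_shape_from_folded(folded_shape) -> Tuple[Optional[int], ...]:
    # Keep every folded constant; use None only where the entry is symbolic.
    return tuple(getattr(s, "data", None) for s in folded_shape)


folded = (Const(2), Const(1), Sym())
print(broadcastable_from_folded(folded))
print(static_shape_from_folded(folded))
```

With `folded = (Const(2), Const(1), Sym())`, the first function yields `(False, True, False)` and loses the `2`, while the second yields `(2, 1, None)`.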

@rlouf
Member Author

rlouf commented Oct 13, 2022

Indeed, if I insert a breakpoint right before

```python
bcast = tuple(getattr(s, "data", s) == 1 for s in folded_shape)
```

`folded_shape` contains the correct type information (constant shapes for the cases `size=(0,)` and `size=(2,)` in the original post).

@brandonwillard
Member

> Indeed, if I insert a breakpoint right before
>
> ```python
> bcast = tuple(getattr(s, "data", s) == 1 for s in folded_shape)
> ```
>
> `folded_shape` contains the correct type information (constant shapes for the cases `size=(0,)` and `size=(2,)` in the original post).

Yeah, I have a fix in #1253, but I should split that off from some of the other experimental stuff in that PR.
