Skip to content

Commit

Permalink
Use SpecifyShape to track length of RandomVariable's size
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Nov 16, 2021
1 parent 9f7d406 commit 0f0bade
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 2 deletions.
2 changes: 1 addition & 1 deletion aesara/link/numba/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def make_numba_random_fn(node, np_random_func):
argument to the Numba-supported scalar ``np.random`` functions.
"""

tuple_size = get_vector_length(node.inputs[1])
tuple_size = int(get_vector_length(node.inputs[1]))
size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])

# Make a broadcast-capable version of the Numba supported scalar sampling
Expand Down
5 changes: 5 additions & 0 deletions aesara/tensor/random/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

from aesara.compile.sharedvalue import shared
from aesara.graph.basic import Variable
from aesara.tensor import get_vector_length
from aesara.tensor.basic import as_tensor_variable, cast, constant
from aesara.tensor.extra_ops import broadcast_to
from aesara.tensor.math import maximum
from aesara.tensor.shape import specify_shape
from aesara.tensor.type import int_dtypes


Expand Down Expand Up @@ -121,6 +123,9 @@ def normalize_size_param(size):
)
else:
size = cast(as_tensor_variable(size, ndim=1), "int64")
# This should help ensure that the length of `size` will be available
# after certain types of cloning (e.g. the kind `Scan` performs)
size = specify_shape(size, (get_vector_length(size),))

assert size.dtype in int_dtypes

Expand Down
22 changes: 22 additions & 0 deletions tests/tensor/random/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from aesara.gradient import NullTypeGradError, grad
from aesara.tensor.math import eq
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
from aesara.tensor.shape import specify_shape
from aesara.tensor.type import all_dtypes, iscalar, tensor


Expand Down Expand Up @@ -139,6 +140,27 @@ def test_RandomVariable_bcast():
assert res.broadcastable == (True,)


def test_RandomVariable_bcast_specify_shape():
    """Broadcast pattern should be inferred when ``size`` has a `SpecifyShape` length."""
    # A "normal" RV: two scalar parameters (loc/scale), floatX output dtype.
    rv_op = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)

    # Three symbolic size entries: a constant, a plain scalar, and a scalar
    # wrapped in an `Assert` (so its value isn't statically known).
    dim_const = aet.as_tensor(1, dtype=np.int64)
    dim_free = iscalar()
    dim_free.tag.test_value = 2
    dim_checked = iscalar()
    dim_checked.tag.test_value = 3
    dim_checked = Assert("testing")(dim_checked, eq(dim_const, 1))

    # `specify_shape` pins the vector length to 5 even though the entries
    # themselves are symbolic.
    size = specify_shape(
        aet.as_tensor([dim_const, dim_checked, dim_free, dim_free, dim_const]), (5,)
    )

    loc = tensor(config.floatX, [False, False, True])
    loc.tag.test_value = np.random.normal(size=(2, 2, 1)).astype(config.floatX)

    scale = tensor(config.floatX, [False, True, True])
    scale.tag.test_value = np.ones((2, 1, 1)).astype(config.floatX)

    out = rv_op(loc, scale, size=size)
    # Only the dimensions fed by the constant-1 size entries broadcast.
    assert out.broadcastable == (True, False, False, False, True)


def test_RandomVariable_floatX():
test_rv_op = RandomVariable(
"normal",
Expand Down
6 changes: 5 additions & 1 deletion tests/tensor/random/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
local_rv_size_lift,
local_subtensor_rv_lift,
)
from aesara.tensor.shape import SpecifyShape
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
from aesara.tensor.type import iscalar, vector

Expand Down Expand Up @@ -81,8 +82,11 @@ def test_inplace_optimization():
assert new_out.owner.op.inplace is True
assert all(
np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[1:], out.owner.inputs[1:])
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
)
# A `SpecifyShape` is added
assert isinstance(new_out.owner.inputs[1].owner.op, SpecifyShape)
assert new_out.owner.inputs[1].owner.inputs[0].equals(out.owner.inputs[1])


@config.change_flags(compute_test_value="raise")
Expand Down

0 comments on commit 0f0bade

Please sign in to comment.