From 0f0badefca94b8106cde244f9f248c875a736415 Mon Sep 17 00:00:00 2001
From: "Brandon T. Willard"
Date: Mon, 15 Nov 2021 13:37:59 -0600
Subject: [PATCH] Use SpecifyShape to track length of RandomVariable's size

---
 aesara/link/numba/dispatch/random.py |  2 +-
 aesara/tensor/random/utils.py        |  5 +++++
 tests/tensor/random/test_op.py       | 22 ++++++++++++++++++++++
 tests/tensor/random/test_opt.py      |  6 +++++-
 4 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/aesara/link/numba/dispatch/random.py b/aesara/link/numba/dispatch/random.py
index 743dafbe1c..6b932df5d0 100644
--- a/aesara/link/numba/dispatch/random.py
+++ b/aesara/link/numba/dispatch/random.py
@@ -37,7 +37,7 @@ def make_numba_random_fn(node, np_random_func):
     argument to the Numba-supported scalar ``np.random`` functions.

     """
-    tuple_size = get_vector_length(node.inputs[1])
+    tuple_size = int(get_vector_length(node.inputs[1]))
     size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])

     # Make a broadcast-capable version of the Numba supported scalar sampling
diff --git a/aesara/tensor/random/utils.py b/aesara/tensor/random/utils.py
index fa462be6c5..45b10bca67 100644
--- a/aesara/tensor/random/utils.py
+++ b/aesara/tensor/random/utils.py
@@ -6,9 +6,11 @@

 from aesara.compile.sharedvalue import shared
 from aesara.graph.basic import Variable
+from aesara.tensor import get_vector_length
 from aesara.tensor.basic import as_tensor_variable, cast, constant
 from aesara.tensor.extra_ops import broadcast_to
 from aesara.tensor.math import maximum
+from aesara.tensor.shape import specify_shape
 from aesara.tensor.type import int_dtypes


@@ -121,6 +123,9 @@ def normalize_size_param(size):
         )
     else:
         size = cast(as_tensor_variable(size, ndim=1), "int64")
+        # This should help ensure that the length of `size` will be available
+        # after certain types of cloning (e.g. the kind `Scan` performs)
+        size = specify_shape(size, (get_vector_length(size),))

     assert size.dtype in int_dtypes

diff --git a/tests/tensor/random/test_op.py b/tests/tensor/random/test_op.py
index 8337c2cb87..236045def9 100644
--- a/tests/tensor/random/test_op.py
+++ b/tests/tensor/random/test_op.py
@@ -7,6 +7,7 @@
 from aesara.gradient import NullTypeGradError, grad
 from aesara.tensor.math import eq
 from aesara.tensor.random.op import RandomVariable, default_shape_from_params
+from aesara.tensor.shape import specify_shape
 from aesara.tensor.type import all_dtypes, iscalar, tensor


@@ -139,6 +140,27 @@ def test_RandomVariable_bcast():
     assert res.broadcastable == (True,)


+def test_RandomVariable_bcast_specify_shape():
+    rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
+
+    s1 = aet.as_tensor(1, dtype=np.int64)
+    s2 = iscalar()
+    s2.tag.test_value = 2
+    s3 = iscalar()
+    s3.tag.test_value = 3
+    s3 = Assert("testing")(s3, eq(s1, 1))
+
+    size = specify_shape(aet.as_tensor([s1, s3, s2, s2, s1]), (5,))
+    mu = tensor(config.floatX, [False, False, True])
+    mu.tag.test_value = np.random.normal(size=(2, 2, 1)).astype(config.floatX)
+
+    std = tensor(config.floatX, [False, True, True])
+    std.tag.test_value = np.ones((2, 1, 1)).astype(config.floatX)
+
+    res = rv(mu, std, size=size)
+    assert res.broadcastable == (True, False, False, False, True)
+
+
 def test_RandomVariable_floatX():
     test_rv_op = RandomVariable(
         "normal",
diff --git a/tests/tensor/random/test_opt.py b/tests/tensor/random/test_opt.py
index dd2ee9e8ce..5f468131bd 100644
--- a/tests/tensor/random/test_opt.py
+++ b/tests/tensor/random/test_opt.py
@@ -23,6 +23,7 @@
     local_rv_size_lift,
     local_subtensor_rv_lift,
 )
+from aesara.tensor.shape import SpecifyShape
 from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
 from aesara.tensor.type import iscalar, vector

@@ -81,8 +82,11 @@ def test_inplace_optimization():
     assert new_out.owner.op.inplace is True
     assert all(
         np.array_equal(a.data, b.data)
-        for a, b in zip(new_out.owner.inputs[1:], out.owner.inputs[1:])
+        for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
     )
+    # A `SpecifyShape` is added
+    assert isinstance(new_out.owner.inputs[1].owner.op, SpecifyShape)
+    assert new_out.owner.inputs[1].owner.inputs[0].equals(out.owner.inputs[1])


 @config.change_flags(compute_test_value="raise")
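
Note: the core of this change is the `specify_shape` wrapper in
`normalize_size_param`. As a minimal sketch of the invariant it relies on
(illustrative names, not part of the patch, and assuming — as the patch
does — that `get_vector_length` can read the static length off a
`SpecifyShape` node): wrapping `size` in `specify_shape` records its length
in the graph itself, so the length stays recoverable even after the graph
structure `get_vector_length` would otherwise inspect is replaced by
cloning, e.g. the kind `Scan` performs.

    import aesara.tensor as aet
    from aesara.tensor import get_vector_length
    from aesara.tensor.shape import specify_shape

    # A symbolic int64 vector whose static length is otherwise unknown;
    # `get_vector_length(size)` would raise a ValueError at this point.
    size = aet.lvector("size")

    # `specify_shape` encodes the length in the graph, so it survives
    # rewrites and cloning that preserve the `SpecifyShape` node.
    size = specify_shape(size, (5,))
    assert get_vector_length(size) == 5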