From 117b40c9a3b421a7dccfd7c0ee04eab4a6f0716c Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Thu, 28 Oct 2021 19:24:08 -0500 Subject: [PATCH] Add a test for a shape inference issue between Scan and RandomVariable --- tests/scan/test_basic.py | 73 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index c4823678e7..72024e3de7 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -5036,3 +5036,76 @@ def accum(seq_t, prev_sum): _seq = np.arange(20).astype("int32") _sum = f(_seq) assert _sum == 2 + + +def test_inner_get_vector_length(): + """Make sure we can handle/preserve fixed shape terms when cloning the body of a `Scan`.""" + + rng_at = RandomStream() + + s1 = lscalar("s1") + s2 = lscalar("s2") + size_at = aet.as_tensor([s1, s2]) + + def scan_body(size): + # `size` will be cloned and replaced with an ownerless `TensorVariable`. + # This will cause `RandomVariable.infer_shape` to fail, because it expects + # `get_vector_length` to work on all `size` arguments. + return rng_at.normal(0, 1, size=size) + + res, _ = scan( + scan_body, + non_sequences=[size_at], + n_steps=10, + strict=True, + ) + + assert isinstance(res.owner.op, Scan) + + # Make sure the `size` in `scan_body` is a plain `Variable` instance + # carrying no information with which we can derive its length + size_clone = res.owner.op.inputs[1] + assert size_clone.owner is None + + # Make sure the cloned `size` maps to the original `size_at` + inner_outer_map = res.owner.op.get_oinp_iinp_iout_oout_mappings() + outer_input_idx = inner_outer_map["outer_inp_from_inner_inp"][1] + original_size = res.owner.inputs[outer_input_idx] + assert original_size == size_at + + with config.change_flags(on_opt_error="raise", on_shape_error="raise"): + res_fn = function([size_at], res.shape) + + assert np.array_equal(res_fn((1, 2)), (10, 1, 2)) + + # Second case has an empty size non-sequence + size_at = aet.as_tensor([], dtype=np.int64) + + res, _ = scan( + scan_body, + non_sequences=[size_at], + n_steps=10, + strict=True, + ) + + assert isinstance(res.owner.op, Scan) + with config.change_flags(on_opt_error="raise", on_shape_error="raise"): + res_fn = function([], res.shape) + + assert np.array_equal(res_fn(), (10,)) + + # Third case has a constant size non-sequence + size_at = aet.as_tensor([3], dtype=np.int64) + + res, _ = scan( + scan_body, + non_sequences=[size_at], + n_steps=10, + strict=True, + ) + + assert isinstance(res.owner.op, Scan) + with config.change_flags(on_opt_error="raise", on_shape_error="raise"): + res_fn = function([], res.shape) + + assert np.array_equal(res_fn(), (10, 3))