Skip to content

Commit

Permalink
Add a test for a shape inference issue between Scan and RandomVariable
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Nov 18, 2021
1 parent 3d4ef66 commit 117b40c
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions tests/scan/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5036,3 +5036,76 @@ def accum(seq_t, prev_sum):
_seq = np.arange(20).astype("int32")
_sum = f(_seq)
assert _sum == 2


def test_inner_get_vector_length():
"""Make sure we can handle/preserve fixed shape terms when cloning the body of a `Scan`."""

rng_at = RandomStream()

s1 = lscalar("s1")
s2 = lscalar("s2")
size_at = aet.as_tensor([s1, s2])

def scan_body(size):
# `size` will be cloned and replaced with an ownerless `TensorVariable`.
# This will cause `RandomVariable.infer_shape` to fail, because it expects
# `get_vector_length` to work on all `size` arguments.
return rng_at.normal(0, 1, size=size)

res, _ = scan(
scan_body,
non_sequences=[size_at],
n_steps=10,
strict=True,
)

assert isinstance(res.owner.op, Scan)

# Make sure the `size` in `scan_body` is a plain `Variable` instance
# carrying no information with which we can derive its length
size_clone = res.owner.op.inputs[1]
assert size_clone.owner is None

# Make sure the cloned `size` maps to the original `size_at`
inner_outer_map = res.owner.op.get_oinp_iinp_iout_oout_mappings()
outer_input_idx = inner_outer_map["outer_inp_from_inner_inp"][1]
original_size = res.owner.inputs[outer_input_idx]
assert original_size == size_at

with config.change_flags(on_opt_error="raise", on_shape_error="raise"):
res_fn = function([size_at], res.shape)

assert np.array_equal(res_fn((1, 2)), (10, 1, 2))

# Second case has an empty size non-sequence
size_at = aet.as_tensor([], dtype=np.int64)

res, _ = scan(
scan_body,
non_sequences=[size_at],
n_steps=10,
strict=True,
)

assert isinstance(res.owner.op, Scan)
with config.change_flags(on_opt_error="raise", on_shape_error="raise"):
res_fn = function([], res.shape)

assert np.array_equal(res_fn(), (10,))

# Third case has a constant size non-sequence
size_at = aet.as_tensor([3], dtype=np.int64)

res, _ = scan(
scan_body,
non_sequences=[size_at],
n_steps=10,
strict=True,
)

assert isinstance(res.owner.op, Scan)
with config.change_flags(on_opt_error="raise", on_shape_error="raise"):
res_fn = function([], res.shape)

assert np.array_equal(res_fn(), (10, 3))

0 comments on commit 117b40c

Please sign in to comment.