Skip to content

Commit

Permalink
Make sure new size values are int64
Browse files Browse the repository at this point in the history
Closes #4652.
  • Loading branch information
brandonwillard authored and twiecki committed Jun 5, 2021
1 parent 26a5787 commit 623d6a3
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
4 changes: 4 additions & 0 deletions pymc3/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def change_rv_size(
size = rv_node.op._infer_shape(size, dist_params)
new_size = tuple(np.atleast_1d(new_size)) + tuple(size)

# Make sure the new size is int64 so that it doesn't unnecessarily pick
# up a `Cast` in some cases
new_size = at.as_tensor(new_size, ndim=1, dtype="int64")

new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)
rv_var = new_rv_node.outputs[-1]
rv_var.name = name
Expand Down
11 changes: 10 additions & 1 deletion pymc3/tests/test_aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import pytest
import scipy.sparse as sps

from aesara.graph.basic import Variable, ancestors
from aesara.graph.basic import Constant, Variable, ancestors
from aesara.tensor.random.basic import normal, uniform
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
Expand Down Expand Up @@ -67,6 +67,15 @@ def test_change_rv_size():
assert rv_newer.ndim == 3
assert rv_newer.eval().shape == (4, 3, 2)

# Make sure we avoid introducing a `Cast` by converting the new size before
# constructing the new `RandomVariable`
rv = normal(0, 1)
new_size = np.array([4, 3], dtype="int32")
rv_newer = change_rv_size(rv, new_size=new_size, expand=False)
assert rv_newer.ndim == 2
assert isinstance(rv_newer.owner.inputs[1], Constant)
assert rv_newer.eval().shape == (4, 3)


class TestBroadcasting:
def test_make_shared_replacements(self):
Expand Down

0 comments on commit 623d6a3

Please sign in to comment.