Skip to content

Commit

Permalink
Add OS-independent #4652 regression tests
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege committed May 17, 2021
1 parent b9695f4 commit 60347c0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pymc3/tests/test_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import aesara
import numpy as np
Expand Down Expand Up @@ -168,6 +169,9 @@ def ode_func_5(y, t, p):
np.testing.assert_array_equal(np.ravel(model5_sens_ic), model5._sens_ic)


@pytest.mark.xfail(
condition=sys.platform == "win32", reason="https://github.com/pymc-devs/aesara/issues/390"
)
def test_logp_scalar_ode():
"""Test the computation of the log probability for these models"""

Expand Down
30 changes: 30 additions & 0 deletions pymc3/tests/test_shape_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import aesara
import numpy as np
import pytest

Expand Down Expand Up @@ -219,3 +220,32 @@ def test_sample_generate_values(fixture_model, fixture_sizes):
prior = pm.sample_prior_predictive(samples=fixture_sizes)
for rv in RVs:
assert prior[rv.name].shape == size + tuple(rv.distribution.shape)


@pytest.mark.xfail(reason="https://github.com/pymc-devs/aesara/issues/390")
def test_size32_doesnt_break_broadcasting():
size32 = at.constant([1, 10], dtype="int32")
rv = pm.Normal.dist(0, 1, size=size32)
assert rv.broadcastable == (True, False)


def test_observed_with_column_vector():
"""This test is related to https://github.com/pymc-devs/aesara/issues/390 which breaks
broadcastability of column-vector RVs. This unexpected change in type can lead to
incompatibilities during graph rewriting for model.logp evaluation.
"""
with pm.Model() as model:
# The `observed` is a broadcastable column vector
obs = at.as_tensor_variable(np.ones((3, 1), dtype=aesara.config.floatX))
assert obs.broadcastable == (False, True)

# Both shapes describe broadcastable volumn vectors
size64 = at.constant([3, 1], dtype="int64")
# But the second shape is upcasted from an int32 vector
cast64 = at.cast(at.constant([3, 1], dtype="int32"), dtype="int64")

pm.Normal("x_size64", mu=0, sd=1, size=size64, observed=obs)
model.logp()

pm.Normal("x_cast64", mu=0, sd=1, size=cast64, observed=obs)
model.logp()

0 comments on commit 60347c0

Please sign in to comment.