Skip to content

Commit

Permalink
Set tests to error on uncaught warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard authored and rlouf committed Nov 21, 2022
1 parent eb55106 commit 9a1db19
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 24 deletions.
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ convention = numpy
[tool:pytest]
python_files=test*.py
testpaths=tests
filterwarnings =
error
ignore:::numdifftools

[coverage:run]
omit =
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_bernoulli_cumsum(size, axis):
def test_destructive_cumsum_fails():
"""Test that a cumsum that mixes dimensions fails"""
x_rv = at.random.normal(size=(2, 2, 2)).cumsum()
with pytest.raises(RuntimeError, match="could not be derived"):
with pytest.raises(UserWarning, match="Found a random variable that is not"):
joint_logprob(x_rv)


Expand Down
7 changes: 5 additions & 2 deletions tests/test_logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,8 +497,11 @@ def test_vonmises_logprob(dist_params, obs, size, error):
def scipy_logprob(obs, mu, kappa):
return stats.vonmises.logpdf(obs, kappa, loc=mu)

with cm:
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)
with pytest.raises(
UserWarning, match="The Op i0 does not provide a C implementation"
):
with cm:
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)


@pytest.mark.parametrize(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def create_mix_model(size, axis):
M_rv = env["M_rv"]

with pytest.raises(
RuntimeError,
match="The logprob terms of the following random variables could not be derived: {M}",
UserWarning,
match="Found a random variable that is not",
):
conditional_logprob(M_rv, I_rv, X_rv)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,5 +265,5 @@ def test_unmeargeable_dimshuffles():
w = z.dimshuffle((1, 0, 2))

# TODO: Check that logp is correct if this type of graphs is ever supported
with pytest.raises(RuntimeError, match="could not be derived"):
with pytest.raises(UserWarning, match="Found a random variable that is not"):
joint_logprob(w)
32 changes: 14 additions & 18 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,16 +463,12 @@ def test_mixture_transform():

transform_rewrite = TransformValuesRewrite({Y_rv: LogOddsTransform()})

with pytest.warns(None) as record:
# This shouldn't raise any warnings
logp_trans, (y_vv_trans, i_vv_trans) = joint_logprob(
Y_rv,
I_rv,
extra_rewrites=transform_rewrite,
use_jacobian=False,
)

assert not record.list
logp_trans, (y_vv_trans, i_vv_trans) = joint_logprob(
Y_rv,
I_rv,
extra_rewrites=transform_rewrite,
use_jacobian=False,
)

logp_fn = aesara.function((i_vv, y_vv), logp)
logp_trans_fn = aesara.function((i_vv_trans, y_vv_trans), logp_trans)
Expand Down Expand Up @@ -588,8 +584,8 @@ def test_loc_transform_rv(rv_size, loc_type):
assert_no_rvs(logp)
logp_fn = aesara.function([loc, y_vv], logp)

loc_test_val = np.full(rv_size, 4.0)
y_test_val = np.full(rv_size, 1.0)
loc_test_val = np.full(rv_size or (), 4.0)
y_test_val = np.full(rv_size or (), 1.0)

np.testing.assert_allclose(
logp_fn(loc_test_val, y_test_val),
Expand All @@ -616,8 +612,8 @@ def test_scale_transform_rv(rv_size, scale_type):
assert_no_rvs(logp)
logp_fn = aesara.function([scale, y_vv], logp)

scale_test_val = np.full(rv_size, 4.0)
y_test_val = np.full(rv_size, 1.0)
scale_test_val = np.full(rv_size or (), 4.0)
y_test_val = np.full(rv_size or (), 1.0)

np.testing.assert_allclose(
logp_fn(scale_test_val, y_test_val),
Expand Down Expand Up @@ -648,7 +644,7 @@ def test_loc_transform_multiple_rvs_fails1():
x_rv2 = at.random.normal(name="x_rv2")
y_rv = x_rv1 + x_rv2

with pytest.raises(RuntimeError, match="could not be derived"):
with pytest.raises(UserWarning, match="Found a random variable that is not"):
joint_logprob(y_rv)


Expand All @@ -657,19 +653,19 @@ def test_nested_loc_transform_multiple_rvs_fails2():
x_rv2 = at.cos(at.random.normal(name="x_rv2"))
y_rv = x_rv1 + x_rv2

with pytest.raises(RuntimeError, match="could not be derived"):
with pytest.raises(UserWarning, match="Found a random variable that is not"):
joint_logprob(y_rv)


def test_discrete_rv_unary_transform_fails():
y_rv = at.exp(at.random.poisson(1))
with pytest.raises(RuntimeError, match="could not be derived"):
with pytest.raises(UserWarning, match="Found a random variable that is not"):
joint_logprob(y_rv)


def test_discrete_rv_multinary_transform_fails():
y_rv = 5 + at.random.poisson(1)
with pytest.raises(RuntimeError, match="could not be derived"):
with pytest.raises(UserWarning, match="Found a random variable that is not"):
joint_logprob(y_rv)


Expand Down

0 comments on commit 9a1db19

Please sign in to comment.