Skip to content

Commit

Permalink
Ensure finite initial point in test_marginalized_transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 9, 2024
1 parent 1b09f82 commit 427ef18
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/model/marginal/test_marginal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,15 +410,15 @@ def test_not_supported_marginalized_deterministic_and_potential():
(None, does_not_warn()),
(UNSET, does_not_warn()),
(transforms.log, does_not_warn()),
(transforms.Chain([transforms.log, transforms.logodds]), does_not_warn()),
(transforms.Chain([transforms.logodds, transforms.log]), does_not_warn()),
(
transforms.Interval(0, 1),
transforms.Interval(0, 2),
pytest.warns(
UserWarning, match="which depends on the marginalized idx may no longer work"
),
),
(
transforms.Chain([transforms.log, transforms.Interval(0, 1)]),
transforms.Chain([transforms.log, transforms.Interval(-1, 1)]),
pytest.warns(
UserWarning, match="which depends on the marginalized idx may no longer work"
),
Expand All @@ -428,7 +428,7 @@ def test_not_supported_marginalized_deterministic_and_potential():
def test_marginalized_transforms(transform, expected_warning):
w = [0.1, 0.3, 0.6]
data = [0, 5, 10]
initval = 0.5 # Value that will be negative on the unconstrained space
initval = 0.7 # Value that will be negative on the unconstrained space

with pm.Model() as m_ref:
sigma = pm.Mixture(
Expand Down Expand Up @@ -467,7 +467,7 @@ def test_marginalized_transforms(transform, expected_warning):
transform_name = "log"
else:
transform_name = transform.name
assert f"sigma_{transform_name}__" in ip
assert -np.inf < ip[f"sigma_{transform_name}__"] < 0.0
np.testing.assert_allclose(m.compile_logp()(ip), m_ref.compile_logp()(ip))


Expand Down

0 comments on commit 427ef18

Please sign in to comment.