Skip to content

Commit

Permalink
Update several test xfails
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 authored and twiecki committed Jun 5, 2021
1 parent 0e9485f commit d91d649
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
20 changes: 10 additions & 10 deletions pymc3/tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_city_data():
return data.merge(unique, "inner", on="fips")


@pytest.mark.xfail(reason="Bernoulli distribution not refactored")
@pytest.mark.xfail(reason="Bernoulli logitp distribution not refactored")
class TestARM5_4(SeededTest):
def build_model(self):
data = pd.read_csv(
Expand Down Expand Up @@ -194,7 +194,7 @@ def build_disaster_model(masked=False):


@pytest.mark.xfail(
reason="DiscreteUniform hasn't been refactored"
reason="_check_start_shape fails with start dictionary"
# condition=(aesara.config.floatX == "float32"), reason="Fails on float32"
)
class TestDisasterModel(SeededTest):
Expand All @@ -204,9 +204,9 @@ def test_disaster_model(self):
model = build_disaster_model(masked=False)
with model:
# Initial values for stochastic nodes
start = {"early_mean": 2.0, "late_mean": 3.0}
start = {"early_mean": 2, "late_mean": 3.0}
# Use slice sampler for means (other variables auto-selected)
step = pm.Slice([model.early_mean_log__, model.late_mean_log__])
step = pm.Slice([model["early_mean_log__"], model["late_mean_log__"]])
tr = pm.sample(500, tune=50, start=start, step=step, chains=2)
az.summary(tr)

Expand All @@ -217,12 +217,12 @@ def test_disaster_model_missing(self):
# Initial values for stochastic nodes
start = {"early_mean": 2.0, "late_mean": 3.0}
# Use slice sampler for means (other variables auto-selected)
step = pm.Slice([model.early_mean_log__, model.late_mean_log__])
step = pm.Slice([model["early_mean_log__"], model["late_mean_log__"]])
tr = pm.sample(500, tune=50, start=start, step=step, chains=2)
az.summary(tr)


@pytest.mark.xfail(reason="ZeroInflatedPoisson hasn't been refactored for v4")
@pytest.mark.xfail(reason="_check_start_shape fails with start dictionary")
class TestLatentOccupancy(SeededTest):
"""
From the PyMC example list
Expand Down Expand Up @@ -277,14 +277,14 @@ def test_run(self):
"z": (self.y > 0).astype("int16"),
"theta": np.array(5, dtype="f"),
}
step_one = pm.Metropolis([model.theta_interval__, model.psi_logodds__])
step_one = pm.Metropolis([model["theta_interval__"], model["psi_logodds__"]])
step_two = pm.BinaryMetropolis([model.z])
pm.sample(50, step=[step_one, step_two], start=start, chains=1)


@pytest.mark.xfail(
# condition=(aesara.config.floatX == "float32"),
# reason="Fails on float32 due to starting inf at starting logP",
condition=(aesara.config.floatX == "float32"),
reason="Fails on float32 due to starting inf at starting logP",
)
class TestRSV(SeededTest):
"""
Expand Down Expand Up @@ -314,7 +314,7 @@ def build_model(self):
# Prior probability
prev_rsv = pm.Beta("prev_rsv", 1, 5, shape=3)
# RSV in Amman
y_amman = pm.Binomial("y_amman", n_amman, prev_rsv, shape=3, testval=100)
y_amman = pm.Binomial("y_amman", n_amman, prev_rsv, shape=3)
# Likelihood for number with RSV in hospital (assumes Pr(hosp | RSV) = 1)
pm.Binomial("y_hosp", y_amman, market_share, observed=rsv_cases)
return model
Expand Down
2 changes: 1 addition & 1 deletion pymc3/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def test_grad(self):
assert val == 21
npt.assert_allclose(grad, [5, 5, 5, 1, 1, 1, 1, 1, 1])

@pytest.mark.xfail(reason="Lognormal not refactored for v4")
@pytest.mark.xfail(reason="Test not refactored for v4")
def test_edge_case(self):
# Edge case discovered in #2948
ndim = 3
Expand Down
1 change: 0 additions & 1 deletion pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,6 @@ def test_shape_edgecase(self):
prior = pm.sample_prior_predictive(10)
assert prior["mu"].shape == (10, 5)

@pytest.mark.xfail(reason="ZeroInflatedPoisson not refactored for v4")
def test_zeroinflatedpoisson(self):
with pm.Model():
theta = pm.Beta("theta", alpha=1, beta=1)
Expand Down

0 comments on commit d91d649

Please sign in to comment.