Skip to content

Commit

Permalink
Split test_shared into test_model & test_sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
Armavica committed Sep 16, 2022
1 parent 2d9ffa8 commit c5805cd
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 65 deletions.
1 change: 0 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ jobs:
- |
pymc/tests/tuning/test_scaling.py
pymc/tests/tuning/test_starting.py
pymc/tests/test_shared.py
pymc/tests/test_sampling.py
pymc/tests/distributions/test_dist_math.py
pymc/tests/distributions/test_transform.py
Expand Down
11 changes: 11 additions & 0 deletions pymc/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,3 +1373,14 @@ def test_missing_symmetric():
logp_inputs = list(graph_inputs([logp]))
assert x_obs_vv in logp_inputs
assert x_unobs_vv in logp_inputs


class TestShared(SeededTest):
def test_deterministic(self):
with pm.Model() as model:
data_values = np.array([0.5, 0.4, 5, 2])
X = aesara.shared(np.asarray(data_values, dtype=aesara.config.floatX), borrow=True)
pm.Normal("y", 0, 1, observed=X)
assert np.all(
np.isclose(model.compile_logp(sum=False)({}), st.norm().logpdf(data_values))
)
34 changes: 34 additions & 0 deletions pymc/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2302,3 +2302,37 @@ def test_float32(self):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
pm.sample(draws=10, tune=10, chains=1, step=sampler())


class TestShared(SeededTest):
def test_sample(self):
x = np.random.normal(size=100)
y = x + np.random.normal(scale=1e-2, size=100)

x_pred = np.linspace(-3, 3, 200)

x_shared = aesara.shared(x)

with pm.Model() as model:
b = pm.Normal("b", 0.0, 10.0)
pm.Normal("obs", b * x_shared, np.sqrt(1e-2), observed=y, shape=x_shared.shape)
prior_trace0 = pm.sample_prior_predictive(1000)

idata = pm.sample(1000, tune=1000, chains=1)
pp_trace0 = pm.sample_posterior_predictive(idata)

x_shared.set_value(x_pred)
prior_trace1 = pm.sample_prior_predictive(1000)
pp_trace1 = pm.sample_posterior_predictive(idata)

assert prior_trace0.prior["b"].shape == (1, 1000)
assert prior_trace0.prior_predictive["obs"].shape == (1, 1000, 100)
np.testing.assert_allclose(
x, pp_trace0.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
)

assert prior_trace1.prior["b"].shape == (1, 1000)
assert prior_trace1.prior_predictive["obs"].shape == (1, 1000, 200)
np.testing.assert_allclose(
x_pred, pp_trace1.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
)
64 changes: 0 additions & 64 deletions pymc/tests/test_shared.py

This file was deleted.

0 comments on commit c5805cd

Please sign in to comment.