Skip to content

Commit

Permalink
Add test to compare draws when var_names is used in pm.sample
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Apr 28, 2024
1 parent a74c03f commit 512cb01
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,42 @@ def test_sample_var_names():
assert "b" not in idata.posterior


def test_sample_var_names_draws():
# Generate data
seed = 1234
rng = np.random.default_rng(seed)

group = rng.choice(list("ABCD"), size=100)
x = rng.normal(size=100)
y = rng.normal(size=100)

group_values, group_idx = np.unique(group, return_inverse=True)

coords = {"group": group_values}

# Create model
with pm.Model(coords=coords) as model:
b_group = pm.Normal("b_group", dims="group")
b_x = pm.Normal("b_x")
mu = pm.Deterministic("mu", b_group[group_idx] + b_x * x)
sigma = pm.HalfNormal("sigma")
pm.Normal("y", mu=mu, sigma=sigma, observed=y)

# Sample with and without var_names, but always with the same seed
with model:
idata_1 = pm.sample(tune=100, draws=100, random_seed=seed)
idata_2 = pm.sample(
tune=100, draws=100, var_names=["b_group", "b_x", "sigma"], random_seed=seed
)

assert "mu" in idata_1.posterior
assert "mu" not in idata_2.posterior

assert np.all(idata_1.posterior["b_group"] == idata_2.posterior["b_group"]).item()
assert np.all(idata_1.posterior["b_x"] == idata_2.posterior["b_x"]).item()
assert np.all(idata_1.posterior["sigma"] == idata_2.posterior["sigma"]).item()


class TestAssignStepMethods:
def test_bernoulli(self):
"""Test bernoulli distribution is assigned binary gibbs metropolis method"""
Expand Down

0 comments on commit 512cb01

Please sign in to comment.