Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with sampling of PartialObservedRVs #7071

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,7 +1337,8 @@ def create_partial_observed_rv(
ndim_supp=rv.owner.op.ndim_supp,
)(rv, mask)

joined_rv = pt.empty(rv.shape, dtype=rv.type.dtype)
[rv_shape] = constant_fold([rv.shape], raise_not_constant=False)
joined_rv = pt.empty(rv_shape, dtype=rv.type.dtype)
joined_rv = pt.set_subtensor(joined_rv[mask], unobserved_rv)
joined_rv = pt.set_subtensor(joined_rv[antimask], observed_rv)

Expand Down
4 changes: 2 additions & 2 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
RandomGeneratorSharedVariable,
RandomStateSharedVariable,
)
from pytensor.tensor.rewriting.basic import topo_constant_folding
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
Expand Down Expand Up @@ -1015,7 +1014,8 @@ def constant_fold(
"""
fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], clone=True)

folded_xs = rewrite_graph(fg, custom_rewrite=topo_constant_folding).outputs
# By default, rewrite_graph includes canonicalize which includes constant-folding as the final rewrite
folded_xs = rewrite_graph(fg).outputs

if raise_not_constant and not all(isinstance(folded_x, Constant) for folded_x in folded_xs):
raise NotConstantValueError
Expand Down
14 changes: 14 additions & 0 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,3 +1061,17 @@ def test_wrong_mask(self):
invalid_mask = np.zeros((1, 5), dtype=bool)
with pytest.raises(ValueError, match="mask can't have more dims than rv"):
create_partial_observed_rv(rv, invalid_mask)

@pytest.mark.filterwarnings("error")
def test_default_updates(self):
mask = np.array([True, True, False])
rv = pm.Normal.dist(shape=(3,))
(obs_rv, _), (unobs_rv, _), joined_rv = create_partial_observed_rv(rv, mask)

draws_obs_rv, draws_unobs_rv, draws_joined_rv = pm.draw(
[obs_rv, unobs_rv, joined_rv], draws=2
)

assert np.all(draws_obs_rv[0] != draws_obs_rv[1])
assert np.all(draws_unobs_rv[0] != draws_unobs_rv[1])
assert np.all(draws_joined_rv[0] != draws_joined_rv[1])
12 changes: 12 additions & 0 deletions tests/sampling/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import pymc as pm

from pymc import ImputationWarning
from pymc.distributions.multivariate import PosDefMatrix
from pymc.sampling.jax import (
_get_batched_jittered_initial_points,
Expand Down Expand Up @@ -459,3 +460,14 @@ def test_idata_contains_stats(sampler_name: str):
for stat_var, stat_var_dims in stat_vars.items():
assert stat_var in stats.variables
assert stats.get(stat_var).values.shape == stat_var_dims


def test_sample_partially_observed():
with pm.Model() as m:
with pytest.warns(ImputationWarning):
x = pm.Normal("x", observed=np.array([0, 1, np.nan]))
idata = pm.sample(nuts_sampler="numpyro", chains=1, draws=10, tune=10)

assert idata.observed_data["x_observed"].shape == (2,)
assert idata.posterior["x_unobserved"].shape == (1, 10, 1)
assert idata.posterior["x"].shape == (1, 10, 3)
Loading