diff --git a/pymc/model.py b/pymc/model.py index 91813e81a9e..efb50b32bce 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -560,7 +560,6 @@ def __init__( self.rvs_to_initial_values = treedict(parent=self.parent.rvs_to_initial_values) self.free_RVs = treelist(parent=self.parent.free_RVs) self.observed_RVs = treelist(parent=self.parent.observed_RVs) - self.auto_deterministics = treelist(parent=self.parent.auto_deterministics) self.deterministics = treelist(parent=self.parent.deterministics) self.potentials = treelist(parent=self.parent.potentials) self._coords = self.parent._coords @@ -575,7 +574,6 @@ def __init__( self.rvs_to_initial_values = treedict() self.free_RVs = treelist() self.observed_RVs = treelist() - self.auto_deterministics = treelist() self.deterministics = treelist() self.potentials = treelist() self._coords = {} @@ -1435,10 +1433,11 @@ def make_obs_var( self.observed_RVs.append(observed_rv_var) # Create deterministic that combines observed and missing + # Note: This can greatly increase memory consumption during sampling for large datasets rv_var = at.zeros(data.shape) rv_var = at.set_subtensor(rv_var[mask.nonzero()], missing_rv_var) rv_var = at.set_subtensor(rv_var[antimask_idx], observed_rv_var) - rv_var = Deterministic(name, rv_var, self, dims, auto=True) + rv_var = Deterministic(name, rv_var, self, dims) else: if sps.issparse(data): @@ -1911,7 +1910,7 @@ def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]: } -def Deterministic(name, var, model=None, dims=None, auto=False): +def Deterministic(name, var, model=None, dims=None): """Create a named deterministic variable. Deterministic nodes are only deterministic given all of their inputs, i.e. 
@@ -1974,10 +1973,7 @@ def Deterministic(name, var, model=None, dims=None, auto=False): """ model = modelcontext(model) var = var.copy(model.name_for(name)) - if auto: - model.auto_deterministics.append(var) - else: - model.deterministics.append(var) + model.deterministics.append(var) model.add_named_variable(var, dims) from pymc.printing import str_for_potential_or_deterministic diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index ee5b2f304a2..53847b6cad9 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -35,7 +35,14 @@ import xarray from aesara import tensor as at -from aesara.graph.basic import Apply, Constant, Variable, general_toposort, walk +from aesara.graph.basic import ( + Apply, + Constant, + Variable, + ancestors, + general_toposort, + walk, +) from aesara.graph.fg import FunctionGraph from aesara.tensor.random.var import ( RandomGeneratorSharedVariable, @@ -324,6 +331,18 @@ def draw( return [np.stack(v) for v in drawn_values] +def observed_dependent_deterministics(model: Model): + """Find deterministics that depend directly on observed variables""" + deterministics = model.deterministics + observed_rvs = set(model.observed_RVs) + blockers = model.basic_RVs + return [ + deterministic + for deterministic in deterministics + if observed_rvs & set(ancestors([deterministic], blockers=blockers)) + ] + + def sample_prior_predictive( samples: int = 500, model: Optional[Model] = None, @@ -371,10 +390,7 @@ def sample_prior_predictive( ) if var_names is None: - vars_: Set[str] = { - var.name - for var in model.basic_RVs + model.deterministics + model.auto_deterministics - } + vars_: Set[str] = {var.name for var in model.basic_RVs + model.deterministics} else: vars_ = set(var_names) @@ -570,7 +586,7 @@ def sample_posterior_predictive( if var_names is not None: vars_ = [model[x] for x in var_names] else: - vars_ = model.observed_RVs + model.auto_deterministics + vars_ = model.observed_RVs + 
observed_dependent_deterministics(model) indices = np.arange(samples) if progressbar: diff --git a/pymc/tests/sampling/test_forward.py b/pymc/tests/sampling/test_forward.py index 447775bc01a..59063c6326a 100644 --- a/pymc/tests/sampling/test_forward.py +++ b/pymc/tests/sampling/test_forward.py @@ -38,6 +38,7 @@ from pymc.sampling.forward import ( compile_forward_sampling_function, get_vars_in_point_list, + observed_dependent_deterministics, ) from pymc.tests.helpers import SeededTest, fast_unstable_sampling_mode @@ -1621,3 +1622,19 @@ def test_get_vars_in_point_list(): trace = MultiTrace([strace]) vars_in_trace = get_vars_in_point_list(trace, modelB) assert set(vars_in_trace) == {a} + + +def test_observed_dependent_deterministics(): + with pm.Model() as m: + free = pm.Normal("free") + obs = pm.Normal("obs", observed=1) + + det_free = pm.Deterministic("det_free", free + 1) + det_free2 = pm.Deterministic("det_free2", det_free + 1) + + det_obs = pm.Deterministic("det_obs", obs + 1) + det_obs2 = pm.Deterministic("det_obs2", det_obs + 1) + + det_mixed = pm.Deterministic("det_mixed", free + obs) + + assert set(observed_dependent_deterministics(m)) == {det_obs, det_obs2, det_mixed} diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index 9f767db1f7d..e03cd4507b6 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -1195,21 +1195,29 @@ def test_missing_dual_observations(self): trace = pm.sample(chains=1, tune=5, draws=50) def test_interval_missing_observations(self): + rng = np.random.default_rng(1198) + with pm.Model() as model: obs1 = np.ma.masked_values([1, 2, -1, 4, -1], value=-1) obs2 = np.ma.masked_values([-1, -1, 6, -1, 8], value=-1) - rng = aesara.shared(np.random.RandomState(2323), borrow=True) - with pytest.warns(ImputationWarning): - theta1 = pm.Uniform("theta1", 0, 5, observed=obs1, rng=rng) + theta1 = pm.Uniform("theta1", 0, 5, observed=obs1) with pytest.warns(ImputationWarning): - theta2 = pm.Normal("theta2", mu=theta1, 
observed=obs2, rng=rng) + theta2 = pm.Normal("theta2", mu=theta1, observed=obs2) assert isinstance(model.rvs_to_transforms[model["theta1_missing"]], IntervalTransform) assert model.rvs_to_transforms[model["theta1_observed"]] is None - prior_trace = pm.sample_prior_predictive(return_inferencedata=False) + prior_trace = pm.sample_prior_predictive(random_seed=rng, return_inferencedata=False) + assert set(prior_trace.keys()) == { + "theta1", + "theta1_observed", + "theta1_missing", + "theta2", + "theta2_observed", + "theta2_missing", + } # Make sure the observed + missing combined deterministics have the # same shape as the original observations vectors @@ -1237,23 +1245,47 @@ def test_interval_missing_observations(self): == 0.0 ) - assert {"theta1", "theta2"} <= set(prior_trace.keys()) - trace = pm.sample( - chains=1, draws=50, compute_convergence_checks=False, return_inferencedata=False + chains=1, + draws=50, + compute_convergence_checks=False, + return_inferencedata=False, + random_seed=rng, ) + assert set(trace.varnames) == { + "theta1", + "theta1_missing", + "theta1_missing_interval__", + "theta2", + "theta2_missing", + } + # Make sure that the missing values are newly generated samples and that + # the observed and deterministic match assert np.all(0 < trace["theta1_missing"].mean(0)) assert np.all(0 < trace["theta2_missing"].mean(0)) - assert "theta1" not in trace.varnames - assert "theta2" not in trace.varnames + assert np.isclose(np.mean(trace["theta1"][:, obs1.mask] - trace["theta1_missing"]), 0) + assert np.isclose(np.mean(trace["theta2"][:, obs2.mask] - trace["theta2_missing"]), 0) - # Make sure that the observed values are newly generated samples and that - # the observed and deterministic matche - pp_idata = pm.sample_posterior_predictive(trace) + # Make sure that the observed values are unchanged + assert np.allclose(np.var(trace["theta1"][:, ~obs1.mask], 0), 0.0) + assert np.allclose(np.var(trace["theta2"][:, ~obs2.mask], 0), 0.0) + 
np.testing.assert_array_equal(trace["theta1"][0][~obs1.mask], obs1[~obs1.mask]) + np.testing.assert_array_equal(trace["theta2"][0][~obs2.mask], obs2[~obs2.mask]) + + pp_idata = pm.sample_posterior_predictive(trace, random_seed=rng) pp_trace = pp_idata.posterior_predictive.stack(sample=["chain", "draw"]).transpose( "sample", ... ) + assert set(pp_trace.keys()) == { + "theta1", + "theta1_observed", + "theta2", + "theta2_observed", + } + + # Make sure that the observed values are newly generated samples and that + # the observed and deterministic match assert np.all(np.var(pp_trace["theta1"], 0) > 0.0) assert np.all(np.var(pp_trace["theta2"], 0) > 0.0) assert np.isclose(