Distinguish better observed from constant data
This avoids having to set dummy observed data for sample_posterior_predictive when the observed data is not part of the generative graph.
ricardoV94 committed May 3, 2024
1 parent 0ad689c commit 40a9ef1
Showing 6 changed files with 106 additions and 35 deletions.
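
For context, a minimal sketch of the workflow this commit enables (illustrative only, not part of the diff; model and variable names are made up): observed data that feeds nothing else in the graph no longer needs a dummy update before posterior predictive sampling.

    import numpy as np
    import pymc as pm

    with pm.Model(coords={"trial": range(5)}) as model:
        x = pm.Data("x", np.random.normal(size=(5,)), dims=("trial",))
        y = pm.Data("y", np.random.normal(size=(5,)), dims=("trial",))
        sigma = pm.HalfNormal("sigma")
        # y is only used as the observed value; it is not an input to mu
        pm.Normal("obs", mu=x, sigma=sigma, observed=y, shape=x.shape, dims=("trial",))
        idata = pm.sample()

    with model:
        # Only x needs updating; the stale 5-element y is ignored
        pm.set_data({"x": np.random.normal(size=(3,))}, coords={"trial": range(3)})
        pp = pm.sample_posterior_predictive(idata, predictions=True)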
38 changes: 14 additions & 24 deletions pymc/backends/arviz.py
@@ -30,7 +30,7 @@
 
 from arviz import InferenceData, concat, rcParams
 from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
-from pytensor.graph.basic import Constant
+from pytensor.graph import ancestors
 from pytensor.tensor.sharedvar import SharedVariable
 from rich.progress import Console, Progress
 from rich.theme import Theme
@@ -72,31 +72,21 @@ def find_observations(model: "Model") -> dict[str, Var]:
 
 
 def find_constants(model: "Model") -> dict[str, Var]:
     """If there are constants available, return them as a dictionary."""
+    model_vars = model.basic_RVs + model.deterministics + model.potentials
+    value_vars = set(model.rvs_to_values.values())
 
-    # The constant data vars must be either pm.Data or TensorConstant or SharedVariable
-    def is_data(name, var, model) -> bool:
-        observations = find_observations(model)
-        return (
-            var not in model.deterministics
-            and var not in model.observed_RVs
-            and var not in model.free_RVs
-            and var not in model.potentials
-            and var not in model.value_vars
-            and name not in observations
-            and isinstance(var, Constant | SharedVariable)
-        )
-
-    # The assumption is that constants (like pm.Data) are named
-    # variables that aren't observed or free RVs, nor are they
-    # deterministics, and then we eliminate observations.
     constant_data = {}
-    for name, var in model.named_vars.items():
-        if is_data(name, var, model):
-            if hasattr(var, "get_value"):
-                var = var.get_value()
-            elif hasattr(var, "data"):
-                var = var.data
-            constant_data[name] = var
+    for var in model.data_vars:
+        if var in value_vars:
+            # An observed value variable could also be part of the generative graph
+            if var not in ancestors(model_vars):
+                continue
+
+        if isinstance(var, SharedVariable):
+            var_value = var.get_value()
+        else:
+            var_value = var.data
+        constant_data[var.name] = var_value
 
     return constant_data

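A sketch of the resulting grouping (illustrative, not from the commit): a pm.Data variable used purely as an observed value is no longer reported by find_constants, while genuine inputs to the graph still are.

    import numpy as np
    import pymc as pm
    from pymc.backends.arviz import find_constants

    with pm.Model() as m:
        x = pm.Data("x", np.array([1.0, 2.0, 3.0]))  # feeds the generative graph
        y = pm.Data("y", np.array([0.5, 1.5, 2.5]))  # used only as observed data
        sigma = pm.HalfNormal("sigma")
        pm.Normal("obs", mu=x, sigma=sigma, observed=y)

    # y is a value variable that is not an ancestor of any model variable,
    # so only x is treated as constant data
    assert set(find_constants(m)) == {"x"}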
2 changes: 1 addition & 1 deletion pymc/data.py
@@ -444,6 +444,6 @@ def Data(
                     length=xshape[d],
                 )
 
-    model.add_named_variable(x, dims=dims)
+    model.register_data_var(x, dims=dims)
 
     return x
7 changes: 7 additions & 0 deletions pymc/model/core.py
@@ -531,6 +531,7 @@ def __init__(
             self.observed_RVs = treelist(parent=self.parent.observed_RVs)
             self.deterministics = treelist(parent=self.parent.deterministics)
             self.potentials = treelist(parent=self.parent.potentials)
+            self.data_vars = treelist(parent=self.parent.data_vars)
             self._coords = self.parent._coords
             self._dim_lengths = self.parent._dim_lengths
         else:
@@ -544,6 +545,7 @@ def __init__(
             self.observed_RVs = treelist()
             self.deterministics = treelist()
             self.potentials = treelist()
+            self.data_vars = treelist()
             self._coords = {}
             self._dim_lengths = {}
         self.add_coords(coords)
@@ -1483,6 +1485,11 @@ def create_value_var(
 
         return value_var
 
+    def register_data_var(self, data, dims=None):
+        """Register a data variable with the model."""
+        self.data_vars.append(data)
+        self.add_named_variable(data, dims=dims)
+
     def add_named_variable(self, var, dims: tuple[str | None, ...] | None = None):
         """Add a random graph variable to the named variables of the model.
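A quick sketch of the new bookkeeping (assumed from this diff, not part of it): pm.Data now registers its variable in model.data_vars in addition to model.named_vars.

    import pymc as pm

    with pm.Model() as m:
        x = pm.Data("x", [1.0, 2.0, 3.0])

    assert x in m.data_vars      # tracked via register_data_var
    assert "x" in m.named_vars   # still a named variable, as before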
17 changes: 9 additions & 8 deletions pymc/model/fgraph.py
@@ -164,30 +164,30 @@ def fgraph_from_model(
     free_rvs = model.free_RVs
     observed_rvs = model.observed_RVs
     potentials = model.potentials
-    named_vars = model.named_vars.values()
     # We copy Deterministics (Identity Op) so that they don't show in between "main" variables
     # We later remove these Identity Ops when we have a Deterministic ModelVar Op as a separator
     old_deterministics = model.deterministics
     deterministics = [det if inlined_views else det.copy(det.name) for det in old_deterministics]
     # Value variables (we also have to decide whether to inline named ones)
     old_value_vars = list(rvs_to_values.values())
-    unnamed_value_vars = [val for val in old_value_vars if val not in named_vars]
+    data_vars = model.data_vars
+    unnamed_value_vars = [val for val in old_value_vars if val not in data_vars]
     named_value_vars = [
-        val if inlined_views else val.copy(val.name) for val in old_value_vars if val in named_vars
+        val if inlined_views else val.copy(name=val.name)
+        for val in old_value_vars
+        if val in data_vars
     ]
     value_vars = old_value_vars.copy()
     if inlined_views:
         # In this case we want to use the named_value_vars as the value_vars in RVs
         for named_val in named_value_vars:
             idx = value_vars.index(named_val)
             value_vars[idx] = named_val
-    # Other variables that are in named_vars but are not any of the categories above (e.g., Data)
-    # We use the same trick as deterministics!
-    accounted_for = set(free_rvs + observed_rvs + potentials + old_deterministics + old_value_vars)
+    # Data vars that are not value vars
     other_named_vars = [
         var if inlined_views else var.copy(var.name)
-        for var in named_vars
-        if var not in accounted_for
+        for var in data_vars
+        if var not in old_value_vars
     ]
 
     model_vars = (
@@ -339,6 +339,7 @@ def first_non_model_var(var):
             model.deterministics.append(var)
         elif isinstance(model_var.owner.op, ModelNamed):
             var, *dims = model_var.owner.inputs
+            model.data_vars.append(var)
         else:
             raise TypeError(f"Unexpected ModelVar type {type(model_var)}")
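Because model_from_fgraph now repopulates data_vars from ModelNamed variables, a model-to-fgraph round trip should preserve the grouping. A sketch under that assumption:

    import pymc as pm
    from pymc.model.fgraph import fgraph_from_model, model_from_fgraph

    with pm.Model() as m:
        x = pm.Data("x", [1.0, 2.0, 3.0])
        y = pm.Normal("y", mu=x.sum())

    fgraph, _ = fgraph_from_model(m)
    m_clone = model_from_fgraph(fgraph)
    assert [var.name for var in m_clone.data_vars] == ["x"]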
28 changes: 26 additions & 2 deletions tests/backends/test_arviz.py
@@ -454,14 +454,38 @@ def test_constant_data(self, use_context):
         test_dict = {
             "posterior": ["beta"],
             "observed_data": ["obs"],
-            "constant_data": ["x", "y", "beta_sigma"],
+            "constant_data": ["x", "beta_sigma"],
         }
         fails = check_multiple_attrs(test_dict, inference_data)
         assert not fails
         assert inference_data.log_likelihood["obs"].shape == (2, 100, 3)
         # test that scalars are dimensionless in constant_data (issue #6755)
         assert inference_data.constant_data["beta_sigma"].ndim == 0
 
+    @pytest.mark.parametrize("constant_in_generative_graph", [True, False])
+    def test_observed_data_also_constant(self, constant_in_generative_graph):
+        """Test that when the same variable is used as constant data and observed data, it shows up in both groups."""
+        with pm.Model(coords={"trial": [0, 1, 2]}) as model:
+            x = pm.Data("x", [1.0, 2.0, 3.0], dims=["trial"])
+            sigma = pm.HalfNormal("sigma", 1)
+            mu = x - 1 if constant_in_generative_graph else 0
+            pm.Normal("y", mu, sigma, observed=x, dims=["trial"])
+
+            trace = pm.sample_prior_predictive(100, return_inferencedata=False)
+
+        inference_data = to_inference_data(prior=trace, model=model, log_likelihood=False)
+
+        test_dict = {
+            "prior": ["sigma"],
+            "observed_data": ["y"],
+        }
+        if constant_in_generative_graph:
+            test_dict["constant_data"] = ["x"]
+        else:
+            test_dict["~constant_data"] = []
+        fails = check_multiple_attrs(test_dict, inference_data)
+        assert not fails
+
     def test_predictions_constant_data(self):
         with pm.Model():
             x = pm.Data("x", [1.0, 2.0, 3.0])
@@ -548,7 +572,7 @@ def test_priors_separation(self, use_context):
             "prior": ["beta", "~obs"],
             "observed_data": ["obs"],
             "prior_predictive": ["obs"],
-            "constant_data": ["x", "y"],
+            "constant_data": ["x"],
         }
         if use_context:
             with model:
49 changes: 49 additions & 0 deletions tests/sampling/test_forward.py
@@ -1017,6 +1017,55 @@ def test_logging_sampled_basic_rvs_posterior_mutable(self, mock_sample_results,
             ]
             caplog.clear()
 
+    def test_observed_data_needed_in_pp(self):
+        # Model where y_data is not part of the generative graph.
+        # There should be no need to set a dummy value for posterior predictive sampling
+
+        with pm.Model(coords={"trial": range(5), "feature": range(3)}) as m:
+            x_data = pm.Data("x_data", np.random.normal(size=(5, 3)), dims=("trial", "feature"))
+            y_data = pm.Data("y_data", np.random.normal(size=(5,)), dims=("trial",))
+            sigma = pm.HalfNormal("sigma")
+            mu = x_data.sum(-1)
+            pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",))
+
+            prior = pm.sample_prior_predictive(samples=25).prior
+
+        fake_idata = InferenceData(posterior=prior)
+
+        new_coords = {"trial": range(2), "feature": range(3)}
+        new_x_data = np.random.normal(size=(2, 3))
+        with m:
+            pm.set_data(
+                {
+                    "x_data": new_x_data,
+                },
+                coords=new_coords,
+            )
+            pp = pm.sample_posterior_predictive(fake_idata, predictions=True, progressbar=False)
+        assert pp.predictions["y"].shape == (1, 25, 2)
+
+        # In this case y_data is part of the generative graph, so we must set it to a compatible value
+        with pm.Model(coords={"trial": range(5), "feature": range(3)}) as m:
+            x_data = pm.Data("x_data", np.random.normal(size=(5, 3)), dims=("trial", "feature"))
+            y_data = pm.Data("y_data", np.random.normal(size=(5,)), dims=("trial",))
+            sigma = pm.HalfNormal("sigma")
+            mu = (y_data.sum() * x_data).sum(-1)
+            pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",))
+
+            prior = pm.sample_prior_predictive(samples=25).prior
+
+        fake_idata = InferenceData(posterior=prior)
+
+        with m:
+            pm.set_data({"x_data": new_x_data}, coords=new_coords)
+            with pytest.raises(ValueError, match="conflicting sizes for dimension 'trial'"):
+                pm.sample_posterior_predictive(fake_idata, predictions=True, progressbar=False)
+
+        new_y_data = np.random.normal(size=(2,))
+        with m:
+            pm.set_data({"y_data": new_y_data})
+            pp = pm.sample_posterior_predictive(fake_idata, predictions=True, progressbar=False)
+        assert pp.predictions["y"].shape == (1, 25, 2)
 
 
 @pytest.fixture(scope="class")
 def point_list_arg_bug_fixture() -> tuple[pm.Model, pm.backends.base.MultiTrace]: