
Commit

Fix need for dummy data when changing coordinates in sample_posterior_predictive
ricardoV94 committed May 3, 2024
1 parent 0ad689c commit 94cc61d
Showing 6 changed files with 52 additions and 40 deletions.
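In practice this removes the need to pass dummy observed data of the new shape when resizing model coordinates before posterior predictive sampling. A minimal sketch of the workflow the commit enables, mirroring the new test added in tests/sampling/test_forward.py below (model and variable names are illustrative):

import numpy as np
import pymc as pm

with pm.Model(coords={"trial": range(5)}) as model:
    x = pm.Data("x", np.random.normal(size=(5,)), dims=("trial",))
    y = pm.Data("y", np.random.normal(size=(5,)), dims=("trial",))
    sigma = pm.HalfNormal("sigma")
    pm.Normal("obs", mu=x, sigma=sigma, observed=y, shape=x.shape, dims=("trial",))
    idata = pm.sample()

with model:
    # Only the predictor needs updating; previously a dummy "y" of the
    # new shape had to be set as well for the coordinate change to work.
    pm.set_data({"x": np.random.normal(size=(2,))}, coords={"trial": range(2)})
    preds = pm.sample_posterior_predictive(idata, predictions=True)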
34 changes: 10 additions & 24 deletions pymc/backends/arviz.py
@@ -30,7 +30,6 @@

from arviz import InferenceData, concat, rcParams
from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
from pytensor.graph.basic import Constant
from pytensor.tensor.sharedvar import SharedVariable
from rich.progress import Console, Progress
from rich.theme import Theme
@@ -72,31 +71,18 @@ def find_observations(model: "Model") -> dict[str, Var]:

def find_constants(model: "Model") -> dict[str, Var]:
"""If there are constants available, return them as a dictionary."""
value_vars = set(model.rvs_to_values.values())

# The constant data vars must be either pm.Data or TensorConstant or SharedVariable
def is_data(name, var, model) -> bool:
observations = find_observations(model)
return (
var not in model.deterministics
and var not in model.observed_RVs
and var not in model.free_RVs
and var not in model.potentials
and var not in model.value_vars
and name not in observations
and isinstance(var, Constant | SharedVariable)
)

# The assumption is that constants (like pm.Data) are named
# variables that aren't observed or free RVs, nor are they
# deterministics, and then we eliminate observations.
constant_data = {}
for name, var in model.named_vars.items():
if is_data(name, var, model):
if hasattr(var, "get_value"):
var = var.get_value()
elif hasattr(var, "data"):
var = var.data
constant_data[name] = var
for var in model.data_vars:
if var in value_vars:
continue

if isinstance(var, SharedVariable):
var_value = var.get_value()
else:
var_value = var.data
constant_data[var.name] = var_value

return constant_data
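The rewritten find_constants walks the model's explicit data_vars list instead of filtering every named variable, and skips data containers that act as value variables (i.e. observed data). A quick sanity check of the intended behavior, consistent with the test updates in tests/backends/test_arviz.py below:

import numpy as np
import pymc as pm
from pymc.backends.arviz import find_constants

with pm.Model() as m:
    x = pm.Data("x", np.zeros(3))  # plain model input
    y = pm.Data("y", np.zeros(3))  # becomes the value variable of "obs"
    pm.Normal("obs", mu=x, observed=y)

# "y" now lands in observed_data rather than constant_data.
assert set(find_constants(m)) == {"x"}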

2 changes: 1 addition & 1 deletion pymc/data.py
@@ -444,6 +444,6 @@ def Data(
length=xshape[d],
)

model.add_named_variable(x, dims=dims)
model.register_data_var(x, dims=dims)

return x
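pm.Data now goes through the new Model.register_data_var helper (added in pymc/model/core.py below), so data containers are tracked explicitly rather than inferred by elimination. A small illustration of the expected bookkeeping:

import numpy as np
import pymc as pm

with pm.Model() as m:
    x = pm.Data("x", np.arange(3.0))

# register_data_var appends to model.data_vars and also registers the
# usual named variable, so name-based lookups keep working.
assert x in m.data_vars
assert m["x"] is x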
7 changes: 7 additions & 0 deletions pymc/model/core.py
@@ -531,6 +531,7 @@ def __init__(
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
self.deterministics = treelist(parent=self.parent.deterministics)
self.potentials = treelist(parent=self.parent.potentials)
self.data_vars = treelist(parent=self.parent.data_vars)
self._coords = self.parent._coords
self._dim_lengths = self.parent._dim_lengths
else:
@@ -544,6 +545,7 @@
self.observed_RVs = treelist()
self.deterministics = treelist()
self.potentials = treelist()
self.data_vars = treelist()
self._coords = {}
self._dim_lengths = {}
self.add_coords(coords)
@@ -1483,6 +1485,11 @@ def create_value_var(

return value_var

def register_data_var(self, data, dims=None):
"""Register a data variable with the model."""
self.data_vars.append(data)
self.add_named_variable(data, dims=dims)

def add_named_variable(self, var, dims: tuple[str | None, ...] | None = None):
"""Add a random graph variable to the named variables of the model.
22 changes: 9 additions & 13 deletions pymc/model/fgraph.py
@@ -164,34 +164,29 @@ def fgraph_from_model(
free_rvs = model.free_RVs
observed_rvs = model.observed_RVs
potentials = model.potentials
data_vars = model.data_vars
named_vars = model.named_vars.values()
# We copy Deterministics (Identity Op) so that they don't show in between "main" variables
# We later remove these Identity Ops when we have a Deterministic ModelVar Op as a separator
old_deterministics = model.deterministics
deterministics = [det if inlined_views else det.copy(det.name) for det in old_deterministics]
# Value variables (we also have to decide whether to inline named ones)
old_value_vars = list(rvs_to_values.values())
unnamed_value_vars = [val for val in old_value_vars if val not in named_vars]
data_vars = model.data_vars
unnamed_value_vars = [val for val in old_value_vars if val not in data_vars]
named_value_vars = [
val if inlined_views else val.copy(val.name) for val in old_value_vars if val in named_vars
val if inlined_views else val.copy() for val in old_value_vars if val in named_vars
]
value_vars = old_value_vars.copy()
if inlined_views:
# In this case we want to use the named_value_vars as the value_vars in RVs
for named_val in named_value_vars:
idx = value_vars.index(named_val)
value_vars[idx] = named_val
# Other variables that are in named_vars but are not any of the categories above (e.g., Data)
# We use the same trick as deterministics!
accounted_for = set(free_rvs + observed_rvs + potentials + old_deterministics + old_value_vars)
other_named_vars = [
var if inlined_views else var.copy(var.name)
for var in named_vars
if var not in accounted_for
]

# Data vars can be value variables or just named variables
data_vars = [var for var in data_vars if var not in old_value_vars]
model_vars = (
rvs + potentials + deterministics + other_named_vars + named_value_vars + unnamed_value_vars
rvs + potentials + deterministics + data_vars + named_value_vars + unnamed_value_vars
)

memo = {}
@@ -230,7 +225,7 @@
}
potentials = [memo[k] for k in potentials]
deterministics = [memo[k] for k in deterministics]
named_vars = [memo[k] for k in other_named_vars + named_value_vars]
named_vars = [memo[k] for k in data_vars + named_value_vars]

vars = fgraph.outputs
new_vars = []
@@ -339,6 +334,7 @@ def first_non_model_var(var):
model.deterministics.append(var)
elif isinstance(model_var.owner.op, ModelNamed):
var, *dims = model_var.owner.inputs
model.data_vars.append(var)
else:
raise TypeError(f"Unexpected ModelVar type {type(model_var)}")
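Because data variables are now emitted and restored explicitly, a model cloned through the fgraph representation should keep its data_vars. A hedged round-trip sketch, assuming the fgraph_from_model/model_from_fgraph pair exported by this module:

import numpy as np
import pymc as pm
from pymc.model.fgraph import fgraph_from_model, model_from_fgraph

with pm.Model() as m:
    x = pm.Data("x", np.zeros(3))
    pm.Normal("y", mu=x)

fgraph, _ = fgraph_from_model(m)
cloned = model_from_fgraph(fgraph)
# The cloned model should track the same data variable by name.
assert [var.name for var in cloned.data_vars] == ["x"]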

4 changes: 2 additions & 2 deletions tests/backends/test_arviz.py
@@ -454,7 +454,7 @@ def test_constant_data(self, use_context):
test_dict = {
"posterior": ["beta"],
"observed_data": ["obs"],
"constant_data": ["x", "y", "beta_sigma"],
"constant_data": ["x", "beta_sigma"],
}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
@@ -548,7 +548,7 @@ def test_priors_separation(self, use_context):
"prior": ["beta", "~obs"],
"observed_data": ["obs"],
"prior_predictive": ["obs"],
"constant_data": ["x", "y"],
"constant_data": ["x"],
}
if use_context:
with model:
23 changes: 23 additions & 0 deletions tests/sampling/test_forward.py
@@ -1017,6 +1017,29 @@ def test_logging_sampled_basic_rvs_posterior_mutable(self, mock_sample_results,
]
caplog.clear()

def test_dummy_data_not_needed_in_pp(self):
with pm.Model(coords={"trial": range(5), "feature": range(3)}) as m:
x_data = pm.Data("x_data", np.random.normal(size=(5, 3)), dims=("trial", "feat"))
y_data = pm.Data("y_data", np.random.normal(size=(5,)), dims=("trial",))
sigma = pm.HalfNormal("sigma")
mu = pm.math.ones_like(x_data).sum(-1)
pm.Normal("y", mu=mu, sigma=sigma, observed=y_data, shape=mu.shape, dims=("trial",))

prior = pm.sample_prior_predictive(samples=25).prior

fake_idata = InferenceData(posterior=prior)

new_coords = {"trial": range(2), "feature": range(3)}
with m:
pm.set_data(
{
"x_data": np.random.normal(size=(2, 3)),
},
coords=new_coords,
)
pp = pm.sample_posterior_predictive(fake_idata, predictions=True, progressbar=False)
assert pp.predictions["y"].shape == (1, 25, 2)
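The asserted shape (1, 25, 2) is (chain, draw, trial): one chain of 25 posterior draws, with the "trial" dimension resized from 5 to 2 even though y_data was never updated.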


@pytest.fixture(scope="class")
def point_list_arg_bug_fixture() -> tuple[pm.Model, pm.backends.base.MultiTrace]:
