Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise ValueError if random variables are present in the logp graph #5614

Merged
merged 1 commit on Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions pymc/distributions/logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from aeppl.logprob import logprob as logp_aeppl
from aeppl.transforms import TransformValuesOpt
from aesara.graph.basic import graph_inputs, io_toposort
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
Expand Down Expand Up @@ -223,6 +224,26 @@ def joint_logpt(
tmp_rvs_to_values, extra_rewrites=transform_opt, use_jacobian=jacobian, **kwargs
)

# Raise if there are unexpected RandomVariables in the logp graph
# Only SimulatorRVs are allowed
from pymc.distributions.simulator import SimulatorRV

unexpected_rv_nodes = [
node
for node in aesara.graph.ancestors(list(temp_logp_var_dict.values()))
if (
node.owner
and isinstance(node.owner.op, RandomVariable)
and not isinstance(node.owner.op, SimulatorRV)
)
]
if unexpected_rv_nodes:
raise ValueError(
f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
"This can happen when DensityDist logp or Interval transform functions "
"reference nonlocal variables."
)

# aeppl returns the logpt for every single value term we provided to it. This includes
# the extra values we plugged in above, so we filter those we actually wanted in the
# same order they were given in.
Expand Down
10 changes: 10 additions & 0 deletions pymc/tests/test_logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Subtensor,
)

from pymc import DensityDist
from pymc.aesaraf import floatX, walk_model
from pymc.distributions.continuous import HalfFlat, Normal, TruncatedNormal, Uniform
from pymc.distributions.discrete import Bernoulli
Expand Down Expand Up @@ -217,3 +218,12 @@ def test_model_unchanged_logprob_access():
model.logpt()
new_inputs = set(aesara.graph.graph_inputs([c]))
assert original_inputs == new_inputs


def test_unexpected_rvs():
    """Ensure logpt raises when a free RV leaks into the logp graph.

    A DensityDist whose logp closes over another model variable (`x`)
    leaves a RandomVariable in the logp graph, which must be rejected.
    """
    with Model() as model:
        x = Normal("x")

        # Deliberately buggy logp: ignores its inputs and returns the
        # nonlocal random variable `x` instead of a value tensor.
        def bad_logp(*args):
            return x

        y = DensityDist("y", logp=bad_logp)

    expected_msg = "^Random variables detected in the logp graph"
    with pytest.raises(ValueError, match=expected_msg):
        model.logpt()
6 changes: 3 additions & 3 deletions pymc/tests/test_parallel_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,11 @@ def test_spawn_densitydist_bound_method():
N = 100
with pm.Model() as model:
mu = pm.Normal("mu", 0, 1)
normal_dist = pm.Normal.dist(mu, 1, size=N)

def logp(x):
def logp(x, mu):
normal_dist = pm.Normal.dist(mu, 1, size=N)
out = pm.logp(normal_dist, x)
return out

obs = pm.DensityDist("density_dist", logp=logp, observed=np.random.randn(N), size=N)
obs = pm.DensityDist("density_dist", mu, logp=logp, observed=np.random.randn(N), size=N)
Comment on lines 201 to +210
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a perfect illustration of the problem that's detected by this PR

pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")