Do not rely on tag information when transforming between rv and logp graphs

ricardoV94 committed Nov 9, 2022
1 parent f80e45c commit 2cb57e1
Showing 15 changed files with 156 additions and 100 deletions.
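
For context, the change replaces reads of graph tags such as rv.tag.observations and value_var.tag.transform with the model-level mappings used throughout the hunks below (model.rvs_to_values, model.rvs_to_transforms). A minimal sketch of the new lookup pattern, assuming a pymc build that already includes this commit (the variable names are illustrative only):

import pymc as pm

with pm.Model() as model:
    sigma = pm.HalfNormal("sigma")  # free RV with a log transform
    y = pm.Normal("y", 0.0, sigma, observed=[0.1, -0.3, 0.5])

# New style: query the model-level dicts
value_var = model.rvs_to_values[sigma]      # value variable backing sigma
transform = model.rvs_to_transforms[sigma]  # its transform (None for untransformed RVs)
observed = model.rvs_to_values[y]           # observed data behind y

# Old style relied on tag attributes instead, e.g.
#   sigma.tag.value_var, value_var.tag.transform, y.tag.observations
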
4 changes: 2 additions & 2 deletions pymc/backends/arviz.py
@@ -47,7 +47,7 @@ def find_observations(model: "Model") -> Dict[str, Var]:
     """If there are observations available, return them as a dictionary."""
     observations = {}
     for obs in model.observed_RVs:
-        aux_obs = getattr(obs.tag, "observations", None)
+        aux_obs = model.rvs_to_values.get(obs, None)
         if aux_obs is not None:
             try:
                 obs_data = extract_obs_data(aux_obs)
@@ -261,7 +261,7 @@ def log_likelihood_vals_point(self, point, var, log_like_fun):

         if isinstance(var.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)):
             try:
-                obs_data = extract_obs_data(var.tag.observations)
+                obs_data = extract_obs_data(self.model.rvs_to_values[var])
             except TypeError:
                 warnings.warn(f"Could not extract data from symbolic observation {var}")

2 changes: 0 additions & 2 deletions pymc/distributions/__init__.py
@@ -16,7 +16,6 @@
     logcdf,
     logp,
     joint_logp,
-    joint_logpt,
 )

 from pymc.distributions.bound import Bound
@@ -199,7 +198,6 @@
     "Censored",
     "CAR",
     "PolyaGamma",
-    "joint_logpt",
     "joint_logp",
     "logp",
     "logcdf",
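
Code that imported the removed joint_logpt alias needs updating. A hedged sketch of the replacement using the public Model API (joint_logp itself is deprecated in the logprob.py hunks below); the call pattern is assumed from the pymc API of this era:

import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x")

# joint_logpt is gone and joint_logp now emits a FutureWarning;
# Model.logp builds the same joint log-probability graph.
logp_graph = model.logp(sum=True)      # symbolic scalar logp
logp_fn = model.compile_logp()         # compiled logp over a point dict
print(logp_fn(model.initial_point()))
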
98 changes: 71 additions & 27 deletions pymc/distributions/logprob.py
@@ -25,7 +25,7 @@
 from aeppl.logprob import logcdf as logcdf_aeppl
 from aeppl.logprob import logprob as logp_aeppl
 from aeppl.tensor import MeasurableJoin
-from aeppl.transforms import TransformValuesRewrite
+from aeppl.transforms import RVTransform, TransformValuesRewrite
 from aesara import tensor as at
 from aesara.graph.basic import graph_inputs, io_toposort
 from aesara.tensor.random.op import RandomVariable
@@ -122,12 +122,26 @@ def _get_scaling(
     )


-def joint_logpt(*args, **kwargs):
-    warnings.warn(
-        "joint_logpt has been deprecated. Use joint_logp instead.",
-        FutureWarning,
-    )
-    return joint_logp(*args, **kwargs)
+def _check_no_rvs(logp_terms: Sequence[TensorVariable]):
+    # Raise if there are unexpected RandomVariables in the logp graph
+    # Only SimulatorRVs are allowed
+    from pymc.distributions.simulator import SimulatorRV
+
+    unexpected_rv_nodes = [
+        node
+        for node in aesara.graph.ancestors(logp_terms)
+        if (
+            node.owner
+            and isinstance(node.owner.op, RandomVariable)
+            and not isinstance(node.owner.op, SimulatorRV)
+        )
+    ]
+    if unexpected_rv_nodes:
+        raise ValueError(
+            f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
+            "This can happen when DensityDist logp or Interval transform functions "
+            "reference nonlocal variables."
+        )


 def joint_logp(
@@ -169,6 +183,10 @@ def joint_logp(
         Sum the log-likelihood or return each term as a separate list item.
     """
+    warnings.warn(
+        "joint_logp has been deprecated, use model.logp instead",
+        FutureWarning,
+    )
     # TODO: In future when we drop support for tag.value_var most of the following
     # logic can be removed and logp can just be a wrapper function that calls aeppl's
     # joint_logprob directly.
@@ -241,33 +259,15 @@
         **kwargs,
     )

-    # Raise if there are unexpected RandomVariables in the logp graph
-    # Only SimulatorRVs are allowed
-    from pymc.distributions.simulator import SimulatorRV
-
-    unexpected_rv_nodes = [
-        node
-        for node in aesara.graph.ancestors(list(temp_logp_var_dict.values()))
-        if (
-            node.owner
-            and isinstance(node.owner.op, RandomVariable)
-            and not isinstance(node.owner.op, SimulatorRV)
-        )
-    ]
-    if unexpected_rv_nodes:
-        raise ValueError(
-            f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
-            "This can happen when DensityDist logp or Interval transform functions "
-            "reference nonlocal variables."
-        )
-
     # aeppl returns the logp for every single value term we provided to it. This includes
     # the extra values we plugged in above, so we filter those we actually wanted in the
     # same order they were given in.
     logp_var_dict = {}
     for value_var in rv_values.values():
         logp_var_dict[value_var] = temp_logp_var_dict[value_var]

+    _check_no_rvs(list(logp_var_dict.values()))
+
     if scaling:
         for value_var in logp_var_dict.keys():
             if value_var in rv_scalings:
@@ -281,6 +281,50 @@
     return logp_var


+def _joint_logp(
+    rvs: Sequence[TensorVariable],
+    *,
+    rvs_to_values: Dict[TensorVariable, TensorVariable],
+    rvs_to_transforms: Optional[Dict[TensorVariable, RVTransform]] = None,
+    jacobian: bool = True,
+    rvs_to_total_size: Optional[Dict[TensorVariable, TensorVariable]] = None,
+    sum: bool = True,
+    **kwargs,
+) -> List[TensorVariable]:
+
+    transform_rewrite = None
+    if rvs_to_transforms:
+        values_to_transforms = {
+            rvs_to_values[rv]: transform for rv, transform in rvs_to_transforms.items()
+        }
+        transform_rewrite = TransformValuesRewrite(values_to_transforms)
+
+    temp_logp_terms = factorized_joint_logprob(
+        rvs_to_values,
+        extra_rewrites=transform_rewrite,
+        use_jacobian=jacobian,
+        **kwargs,
+    )
+
+    # aeppl returns the logp for every single value term we provided to it. This includes
+    # the extra values we plugged in above, so we filter those we actually wanted in the
+    # same order they were given in.
+    logp_terms = {}
+    for rv in rvs:
+        value_var = rvs_to_values[rv]
+        logp_term = temp_logp_terms[value_var]
+        total_size = rvs_to_total_size.get(rv, None) if rvs_to_total_size else None
+        if total_size:
+            scaling = _get_scaling(total_size, value_var.shape, value_var.ndim)
+            logp_term *= scaling
+        logp_terms[value_var] = logp_term
+
+    _check_no_rvs(list(logp_terms.values()))
+
+    if sum:
+        return at.sum([at.sum(factor) for factor in logp_terms.values()])
+    else:
+        return list(logp_terms.values())
+
+
 def logp(rv: TensorVariable, value) -> TensorVariable:
     """Return the log-probability graph of a Random Variable"""

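
A rough usage sketch for the new private _joint_logp helper, inferred only from the signature above; the real call sites (e.g. in pymc/model.py) are not among the hunks shown, so the argument plumbing here is an assumption:

import pymc as pm

from pymc.distributions.logprob import _joint_logp

with pm.Model() as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma")

# One logp term per requested RV; transforms and (optional) total-size scaling
# are looked up from the model-level dicts rather than from tags.
logp_terms = _joint_logp(
    rvs=model.free_RVs,
    rvs_to_values=model.rvs_to_values,
    rvs_to_transforms=model.rvs_to_transforms,
    rvs_to_total_size={},  # no minibatch scaling in this sketch
    sum=False,
)
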
13 changes: 6 additions & 7 deletions pymc/initial_point.py
@@ -20,6 +20,7 @@
 import aesara.tensor as at
 import numpy as np

+from aeppl.transforms import RVTransform
 from aesara.graph.basic import Variable
 from aesara.graph.fg import FunctionGraph
 from aesara.tensor.var import TensorVariable
@@ -43,9 +44,7 @@ def convert_str_to_rv_dict(
         if isinstance(key, str):
             if is_transformed_name(key):
                 rv = model[get_untransformed_name(key)]
-                initvals[rv] = model.rvs_to_values[rv].tag.transform.backward(
-                    initval, *rv.owner.inputs
-                )
+                initvals[rv] = model.rvs_to_transforms[rv].backward(initval, *rv.owner.inputs)
             else:
                 initvals[model[key]] = initval
         else:
@@ -158,7 +157,7 @@ def make_initial_point_fn(

     initial_values = make_initial_point_expression(
         free_rvs=model.free_RVs,
-        rvs_to_values=model.rvs_to_values,
+        rvs_to_transforms=model.rvs_to_transforms,
         initval_strategies=initval_strats,
         jitter_rvs=jitter_rvs,
         default_strategy=default_strategy,
@@ -172,7 +171,7 @@

     varnames = []
     for var in model.free_RVs:
-        transform = getattr(model.rvs_to_values[var].tag, "transform", None)
+        transform = model.rvs_to_transforms.get(var, None)
         if transform is not None and return_transformed:
             name = get_transformed_name(var.name, transform)
         else:
@@ -197,7 +196,7 @@ def inner(seed, *args, **kwargs):
 def make_initial_point_expression(
     *,
     free_rvs: Sequence[TensorVariable],
-    rvs_to_values: Dict[TensorVariable, TensorVariable],
+    rvs_to_transforms: Dict[TensorVariable, RVTransform],
     initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]],
     jitter_rvs: Set[TensorVariable] = None,
     default_strategy: str = "moment",
@@ -265,7 +264,7 @@ def make_initial_point_expression(
         else:
             value = at.as_tensor(strategy, dtype=variable.dtype).astype(variable.dtype)

-        transform = getattr(rvs_to_values[variable].tag, "transform", None)
+        transform = rvs_to_transforms.get(variable, None)

         if transform is not None:
             value = transform.forward(value, *variable.owner.inputs)
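
The initial-point machinery now takes transforms from model.rvs_to_transforms instead of value_var.tag.transform. A small hedged example of the public entry point these hunks feed into, with the keyword set assumed from the code above:

import pymc as pm

from pymc.initial_point import make_initial_point_fn

with pm.Model() as model:
    sigma = pm.HalfNormal("sigma")

fn = make_initial_point_fn(model=model, return_transformed=True)
print(fn(42))  # e.g. {"sigma_log__": ...}: the moment-based initial value in transformed space
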