Skip to content

Commit

Permalink
Make rvs_to_values work with non-RandomVariables
Browse files Browse the repository at this point in the history
No longer returns replacements dict from `rvs_to_values` as the keys will be the cloned `rvs` which are not useful to the caller.
  • Loading branch information
ricardoV94 committed Sep 5, 2022
1 parent 903aa60 commit 9be5a54
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 115 deletions.
128 changes: 57 additions & 71 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from typing import (
Callable,
Dict,
Expand Down Expand Up @@ -147,32 +145,6 @@ def dataframe_to_tensor_variable(df: pd.DataFrame, *args, **kwargs) -> TensorVar
return at.as_tensor_variable(df.to_numpy(), *args, **kwargs)


def extract_rv_and_value_vars(
var: TensorVariable,
) -> Tuple[TensorVariable, TensorVariable]:
"""Return a random variable and it's observations or value variable, or ``None``.
Parameters
==========
var
A variable corresponding to a ``RandomVariable``.
Returns
=======
The first value in the tuple is the ``RandomVariable``, and the second is the
measure/log-likelihood value variable that corresponds with the latter.
"""
if not var.owner:
return None, None

if isinstance(var.owner.op, RandomVariable):
rv_value = getattr(var.tag, "observations", getattr(var.tag, "value_var", None))
return var, rv_value

return None, None


def extract_obs_data(x: TensorVariable) -> np.ndarray:
"""Extract data from observed symbolic variables.
Expand All @@ -198,22 +170,40 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
raise TypeError(f"Data cannot be extracted from {x}")


def extract_rv_and_value_vars(
random_var: TensorVariable,
) -> Tuple[TensorVariable, TensorVariable]:
"""Return a random variable and it's observations or value variable, or ``None``.
Parameters
==========
var
A variable corresponding to a ``RandomVariable``.
Returns
=======
The first value in the tuple is the ``RandomVariable``, and the second is the
measure/log-likelihood value variable that corresponds with the latter.
"""
if not random_var.owner:
return None, None

value_var = getattr(random_var.tag, "observations", getattr(random_var.tag, "value_var", None))
return random_var, value_var


def walk_model(
graphs: Iterable[TensorVariable],
walk_past_rvs: bool = False,
stop_at_vars: Optional[Set[TensorVariable]] = None,
expand_fn: Callable[[TensorVariable], Iterable[TensorVariable]] = lambda var: [],
) -> Generator[TensorVariable, None, None]:
"""Walk model graphs and yield their nodes.
By default, these walks will not go past ``RandomVariable`` nodes.
Parameters
==========
graphs
The graphs to walk.
walk_past_rvs
If ``True``, the walk will not terminate at ``RandomVariable``s.
stop_at_vars
A list of variables at which the walk will terminate.
expand_fn
Expand All @@ -225,16 +215,12 @@ def walk_model(
def expand(var):
new_vars = expand_fn(var)

if (
var.owner
and (walk_past_rvs or not isinstance(var.owner.op, RandomVariable))
and (var not in stop_at_vars)
):
if var.owner and var not in stop_at_vars:
new_vars.extend(reversed(var.owner.inputs))

return new_vars

yield from walk(graphs, expand, False)
yield from walk(graphs, expand, bfs=False)


def replace_rvs_in_graphs(
Expand Down Expand Up @@ -263,7 +249,11 @@ def replace_rvs_in_graphs(

def expand_replace(var):
new_nodes = []
if var.owner and isinstance(var.owner.op, RandomVariable):
if var.owner:
# Call replacement_fn to update replacements dict inplace and, optionally,
# specify new nodes that should also be walked for replacements. This
# includes `value` variables that are not simple input variables, and may
# contain other `random` variables in their graphs (e.g., IntervalTransform)
new_nodes.extend(replacement_fn(var, replacements))
return new_nodes

Expand All @@ -290,10 +280,10 @@ def expand_replace(var):

def rvs_to_value_vars(
graphs: Iterable[TensorVariable],
apply_transforms: bool = False,
apply_transforms: bool = True,
initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None,
**kwargs,
) -> Tuple[TensorVariable, Dict[TensorVariable, TensorVariable]]:
) -> TensorVariable:
"""Clone and replace random variables in graphs with their value variables.
This will *not* recompute test values in the resulting graphs.
Expand All @@ -309,38 +299,27 @@ def rvs_to_value_vars(
"""

# Avoid circular dependency
from pymc.distributions import NoDistribution

def transform_replacements(var, replacements):
rv_var, rv_value_var = extract_rv_and_value_vars(var)
def populate_replacements(
var: TensorVariable, replacements: Dict[TensorVariable, TensorVariable]
) -> List[TensorVariable]:
# Populate replacements dict with {rv: value} pairs indicating which graph
# RVs should be replaced by what value variables.
random_var, value_var = extract_rv_and_value_vars(var)

if rv_value_var is None:
# If RandomVariable does not have a value_var and corresponds to
# a NoDistribution, we allow further replacements in upstream graph
if isinstance(rv_var.owner.op, NoDistribution):
return rv_var.owner.inputs
# No value variable to replace RV with
if value_var is None:
return []

else:
warnings.warn(
f"No value variable found for {rv_var}; "
"the random variable will not be replaced."
)
return []

transform = getattr(rv_value_var.tag, "transform", None)

if transform is None or not apply_transforms:
replacements[var] = rv_value_var
# In case the value variable is itself a graph, we walk it for
# potential replacements
return [rv_value_var]
transform = getattr(value_var.tag, "transform", None)
if transform is not None and apply_transforms:
# We want to replace uses of the RV by the back-transformation of its value
value_var = transform.backward(value_var, *random_var.owner.inputs)

trans_rv_value = transform.backward(rv_value_var, *rv_var.owner.inputs)
replacements[var] = trans_rv_value
replacements[random_var] = value_var

# Walk the transformed variable and make replacements
return [trans_rv_value]
# Also walk the graph of the value variable to make any additional replacements
# if that is not a simple input variable
return [value_var]

# Clone original graphs
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
Expand All @@ -352,7 +331,14 @@ def transform_replacements(var, replacements):
equiv.get(k, k): equiv.get(v, v) for k, v in initial_replacements.items()
}

return replace_rvs_in_graphs(graphs, transform_replacements, initial_replacements, **kwargs)
graphs, _ = replace_rvs_in_graphs(
graphs,
replacement_fn=populate_replacements,
initial_replacements=initial_replacements,
**kwargs,
)

return graphs


def inputvars(a):
Expand Down
2 changes: 1 addition & 1 deletion pymc/gp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def replace_with_values(vars_needed, replacements=None, model=None):
model = modelcontext(model)

inputs, input_names = [], []
for rv in walk_model(vars_needed, walk_past_rvs=True):
for rv in walk_model(vars_needed):
if rv in model.named_vars.values() and not isinstance(rv, SharedVariable):
inputs.append(rv)
input_names.append(rv.name)
Expand Down
8 changes: 4 additions & 4 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ def logp(
# Replace random variables by their value variables in potential terms
potential_logps = []
if potentials:
potential_logps, _ = rvs_to_value_vars(potentials, apply_transforms=True)
potential_logps = rvs_to_value_vars(potentials)

logp_factors = [None] * len(varlist)
for logp_order, logp in zip((rv_order + potential_order), (rv_logps + potential_logps)):
Expand Down Expand Up @@ -935,7 +935,7 @@ def potentiallogp(self) -> Variable:
"""Aesara scalar of log-probability of the Potential terms"""
# Convert random variables in Potential expression into their log-likelihood
# inputs and apply their transforms, if any
potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
potentials = rvs_to_value_vars(self.potentials)
if potentials:
return at.sum([at.sum(factor) for factor in potentials])
else:
Expand Down Expand Up @@ -976,10 +976,10 @@ def unobserved_value_vars(self):
vars.append(value_var)

# Remove rvs from untransformed values graph
untransformed_vars, _ = rvs_to_value_vars(untransformed_vars, apply_transforms=True)
untransformed_vars = rvs_to_value_vars(untransformed_vars)

# Remove rvs from deterministics graph
deterministics, _ = rvs_to_value_vars(self.deterministics, apply_transforms=True)
deterministics = rvs_to_value_vars(self.deterministics)

return vars + untransformed_vars + deterministics

Expand Down
2 changes: 1 addition & 1 deletion pymc/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):

if isinstance(distr, CategoricalRV):
k_graph = rv_var.owner.inputs[3].shape[-1]
(k_graph,), _ = rvs_to_value_vars((k_graph,), apply_transforms=True)
(k_graph,) = rvs_to_value_vars((k_graph,))
k = model.compile_fn(k_graph, inputs=model.value_vars, on_unused_input="ignore")(
initial_point
)
Expand Down
2 changes: 1 addition & 1 deletion pymc/tests/distributions/test_logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_joint_logp_basic():
with pytest.warns(FutureWarning):
b_logpt = joint_logpt(b, b_value_var, sum=False)

res_ancestors = list(walk_model(b_logp, walk_past_rvs=True))
res_ancestors = list(walk_model(b_logp))
res_rv_ancestors = [
v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable)
]
Expand Down
Loading

0 comments on commit 9be5a54

Please sign in to comment.