
Make rvs_to_values work with non-RandomVariables #6101

Merged · 1 commit · Sep 9, 2022
106 changes: 36 additions & 70 deletions pymc/aesaraf.py
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import warnings
-
 from typing import (
     Callable,
     Dict,
@@ -147,32 +145,6 @@ def dataframe_to_tensor_variable(df: pd.DataFrame, *args, **kwargs) -> TensorVariable:
     return at.as_tensor_variable(df.to_numpy(), *args, **kwargs)
 
 
-def extract_rv_and_value_vars(
-    var: TensorVariable,
-) -> Tuple[TensorVariable, TensorVariable]:
-    """Return a random variable and it's observations or value variable, or ``None``.
-
-    Parameters
-    ==========
-    var
-        A variable corresponding to a ``RandomVariable``.
-
-    Returns
-    =======
-    The first value in the tuple is the ``RandomVariable``, and the second is the
-    measure/log-likelihood value variable that corresponds with the latter.
-
-    """
-    if not var.owner:
-        return None, None
-
-    if isinstance(var.owner.op, RandomVariable):
-        rv_value = getattr(var.tag, "observations", getattr(var.tag, "value_var", None))
-        return var, rv_value
-
-    return None, None
-
-
 def extract_obs_data(x: TensorVariable) -> np.ndarray:
     """Extract data from observed symbolic variables.
 
@@ -200,20 +172,15 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
 
 def walk_model(
     graphs: Iterable[TensorVariable],
-    walk_past_rvs: bool = False,
     stop_at_vars: Optional[Set[TensorVariable]] = None,
     expand_fn: Callable[[TensorVariable], Iterable[TensorVariable]] = lambda var: [],
 ) -> Generator[TensorVariable, None, None]:
     """Walk model graphs and yield their nodes.
 
-    By default, these walks will not go past ``RandomVariable`` nodes.
-
     Parameters
     ==========
     graphs
         The graphs to walk.
-    walk_past_rvs
-        If ``True``, the walk will not terminate at ``RandomVariable``s.
     stop_at_vars
         A list of variables at which the walk will terminate.
     expand_fn
@@ -225,16 +192,12 @@ def walk_model(
     def expand(var):
         new_vars = expand_fn(var)
 
-        if (
-            var.owner
-            and (walk_past_rvs or not isinstance(var.owner.op, RandomVariable))
-            and (var not in stop_at_vars)
-        ):
+        if var.owner and var not in stop_at_vars:
             new_vars.extend(reversed(var.owner.inputs))
 
         return new_vars
 
-    yield from walk(graphs, expand, False)
+    yield from walk(graphs, expand, bfs=False)
 
 
 def replace_rvs_in_graphs(
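For readers skimming the diff: the net effect of the `walk_model` changes is that graph walks no longer stop at `RandomVariable` nodes, which made the `walk_past_rvs` flag redundant. A minimal sketch of the new behavior, assuming a PyMC build that includes this PR (the model below is illustrative, not part of the diff):

```python
import pymc as pm
from pymc.aesaraf import walk_model

with pm.Model():
    x = pm.Normal("x")
    y = pm.Deterministic("y", x + 1)

# The walk now always continues through x's RandomVariable node, so the
# RV's own inputs (rng, size, dtype, loc, scale) are yielded as well as
# x and y themselves.
ancestors = list(walk_model([y]))
print([v.name for v in ancestors if v.name is not None])
```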
@@ -263,7 +226,11 @@ def replace_rvs_in_graphs(
 
     def expand_replace(var):
         new_nodes = []
-        if var.owner and isinstance(var.owner.op, RandomVariable):
+        if var.owner:
+            # Call replacement_fn to update replacements dict inplace and, optionally,
+            # specify new nodes that should also be walked for replacements. This
+            # includes `value` variables that are not simple input variables, and may
+            # contain other `random` variables in their graphs (e.g., IntervalTransform)
             new_nodes.extend(replacement_fn(var, replacements))
         return new_nodes
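The new `expand_replace` delegates everything to `replacement_fn`, so its contract is worth spelling out: the callback may mutate the `replacements` dict in place, and it returns any extra variables whose graphs should also be walked. A self-contained sketch of a conforming callback, under the assumption that value variables live on `var.tag` as they do in PyMC (the function name itself is hypothetical):

```python
from typing import Dict, List

from aesara.tensor.var import TensorVariable


def example_replacement_fn(
    var: TensorVariable, replacements: Dict[TensorVariable, TensorVariable]
) -> List[TensorVariable]:
    # Mirror of what populate_replacements does: read the value variable
    # that PyMC stashes on the tag, if any.
    value = getattr(var.tag, "value_var", None)
    if value is None:
        return []  # nothing to replace; the walk still descends into var's inputs
    replacements[var] = value
    # Ask replace_rvs_in_graphs to also walk the replacement's own graph
    return [value]
```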
@@ -290,10 +257,10 @@ def expand_replace(var):
 
 def rvs_to_value_vars(
     graphs: Iterable[TensorVariable],
-    apply_transforms: bool = False,
+    apply_transforms: bool = True,
     initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None,
     **kwargs,
-) -> Tuple[TensorVariable, Dict[TensorVariable, TensorVariable]]:
+) -> TensorVariable:
     """Clone and replace random variables in graphs with their value variables.
 
     This will *not* recompute test values in the resulting graphs.
@@ -309,38 +276,30 @@ def rvs_to_value_vars(
 
     """
 
-    # Avoid circular dependency
-    from pymc.distributions import NoDistribution
-
-    def transform_replacements(var, replacements):
-        rv_var, rv_value_var = extract_rv_and_value_vars(var)
-
-        if rv_value_var is None:
-            # If RandomVariable does not have a value_var and corresponds to
-            # a NoDistribution, we allow further replacements in upstream graph
-            if isinstance(rv_var.owner.op, NoDistribution):
-                return rv_var.owner.inputs
-
-            else:
-                warnings.warn(
-                    f"No value variable found for {rv_var}; "
-                    "the random variable will not be replaced."
-                )
-                return []
-
-        transform = getattr(rv_value_var.tag, "transform", None)
-
-        if transform is None or not apply_transforms:
-            replacements[var] = rv_value_var
-            # In case the value variable is itself a graph, we walk it for
-            # potential replacements
-            return [rv_value_var]
-
-        trans_rv_value = transform.backward(rv_value_var, *rv_var.owner.inputs)
-        replacements[var] = trans_rv_value
-
-        # Walk the transformed variable and make replacements
-        return [trans_rv_value]
+    def populate_replacements(
+        random_var: TensorVariable, replacements: Dict[TensorVariable, TensorVariable]
+    ) -> List[TensorVariable]:
+        # Populate replacements dict with {rv: value} pairs indicating which graph
+        # RVs should be replaced by what value variables.
+
+        value_var = getattr(
+            random_var.tag, "observations", getattr(random_var.tag, "value_var", None)
+        )
+
+        # No value variable to replace RV with
+        if value_var is None:
+            return []
+
+        transform = getattr(value_var.tag, "transform", None)
+        if transform is not None and apply_transforms:
+            # We want to replace uses of the RV by the back-transformation of its value
+            value_var = transform.backward(value_var, *random_var.owner.inputs)
+
+        replacements[random_var] = value_var
+
+        # Also walk the graph of the value variable to make any additional replacements
+        # if that is not a simple input variable
+        return [value_var]
 
     # Clone original graphs
     inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
@@ -352,7 +311,14 @@ def transform_replacements(var, replacements):
             equiv.get(k, k): equiv.get(v, v) for k, v in initial_replacements.items()
         }
 
-    return replace_rvs_in_graphs(graphs, transform_replacements, initial_replacements, **kwargs)
+    graphs, _ = replace_rvs_in_graphs(
+        graphs,
+        replacement_fn=populate_replacements,
+        initial_replacements=initial_replacements,
+        **kwargs,
+    )
+
+    return graphs
 
 
 def inputvars(a):
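Taken together, the `pymc/aesaraf.py` changes turn `rvs_to_value_vars` into a single-return helper that applies transforms by default. A hedged usage sketch, assuming a PyMC build with this PR (`sigma_log__` refers to PyMC's conventional name for a HalfNormal's log-scale value variable):

```python
import pymc as pm
from pymc.aesaraf import rvs_to_value_vars

with pm.Model() as m:
    sigma = pm.HalfNormal("sigma")  # its value variable lives on the log scale
    det = sigma**2

# Old call sites: (det_value,), _ = rvs_to_value_vars((det,), apply_transforms=True)
# New call sites: transforms are on by default and only the graphs are returned.
(det_value,) = rvs_to_value_vars((det,))
# det_value now depends on sigma's value variable via the backward transform,
# roughly exp(sigma_log__) ** 2; no RandomVariable remains in its graph.
```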
2 changes: 1 addition & 1 deletion pymc/gp/util.py
@@ -57,7 +57,7 @@ def replace_with_values(vars_needed, replacements=None, model=None):
     model = modelcontext(model)
 
     inputs, input_names = [], []
-    for rv in walk_model(vars_needed, walk_past_rvs=True):
+    for rv in walk_model(vars_needed):
         if rv in model.named_vars.values() and not isinstance(rv, SharedVariable):
             inputs.append(rv)
             input_names.append(rv.name)
8 changes: 4 additions & 4 deletions pymc/model.py
@@ -761,7 +761,7 @@ def logp(
         # Replace random variables by their value variables in potential terms
         potential_logps = []
         if potentials:
-            potential_logps, _ = rvs_to_value_vars(potentials, apply_transforms=True)
+            potential_logps = rvs_to_value_vars(potentials)
 
         logp_factors = [None] * len(varlist)
         for logp_order, logp in zip((rv_order + potential_order), (rv_logps + potential_logps)):
@@ -935,7 +935,7 @@ def potentiallogp(self) -> Variable:
         """Aesara scalar of log-probability of the Potential terms"""
         # Convert random variables in Potential expression into their log-likelihood
         # inputs and apply their transforms, if any
-        potentials, _ = rvs_to_value_vars(self.potentials, apply_transforms=True)
+        potentials = rvs_to_value_vars(self.potentials)
         if potentials:
             return at.sum([at.sum(factor) for factor in potentials])
         else:
@@ -976,10 +976,10 @@ def unobserved_value_vars(self):
             vars.append(value_var)
 
         # Remove rvs from untransformed values graph
-        untransformed_vars, _ = rvs_to_value_vars(untransformed_vars, apply_transforms=True)
+        untransformed_vars = rvs_to_value_vars(untransformed_vars)
 
         # Remove rvs from deterministics graph
-        deterministics, _ = rvs_to_value_vars(self.deterministics, apply_transforms=True)
+        deterministics = rvs_to_value_vars(self.deterministics)
 
         return vars + untransformed_vars + deterministics
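The `pymc/model.py` updates are call-site simplifications, but they show where the replacement matters in practice: a `Potential` expression can reference RVs directly, and those references must be rewired to value variables before the logp graph is compiled. An illustrative sketch, not taken from the PR:

```python
import pymc as pm

with pm.Model() as m:
    x = pm.Normal("x")
    pm.Potential("penalty", -(x**2))  # the expression references the RV x

# Internally, Model.potentiallogp now calls rvs_to_value_vars(m.potentials),
# so the penalty enters the joint logp as a function of x's value variable
# rather than of a fresh random draw.
```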
2 changes: 1 addition & 1 deletion pymc/step_methods/metropolis.py
@@ -583,7 +583,7 @@ def __init__(self, vars, proposal="uniform", order="random", model=None):
 
             if isinstance(distr, CategoricalRV):
                 k_graph = rv_var.owner.inputs[3].shape[-1]
-                (k_graph,), _ = rvs_to_value_vars((k_graph,), apply_transforms=True)
+                (k_graph,) = rvs_to_value_vars((k_graph,))
                 k = model.compile_fn(k_graph, inputs=model.value_vars, on_unused_input="ignore")(
                     initial_point
                 )
2 changes: 1 addition & 1 deletion pymc/tests/distributions/test_logprob.py
@@ -129,7 +129,7 @@ def test_joint_logp_basic():
     with pytest.warns(FutureWarning):
         b_logpt = joint_logpt(b, b_value_var, sum=False)
 
-    res_ancestors = list(walk_model(b_logp, walk_past_rvs=True))
+    res_ancestors = list(walk_model(b_logp))
     res_rv_ancestors = [
         v for v in res_ancestors if v.owner and isinstance(v.owner.op, RandomVariable)
     ]