Skip to content

Commit

Permalink
Refactor utility to ignore the logprob of multiple variables while ke…
Browse files Browse the repository at this point in the history
…eping their interdependencies intact
  • Loading branch information
ricardoV94 committed Mar 13, 2023
1 parent 1cc9863 commit 9836d00
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 22 deletions.
22 changes: 2 additions & 20 deletions pymc/logprob/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

from pymc.logprob.abstract import MeasurableVariable, _logprob, logprob
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import ignore_logprob
from pymc.logprob.utils import ignore_logprob, ignore_logprob_multiple_vars


@node_rewriter([BroadcastTo])
Expand Down Expand Up @@ -228,25 +228,7 @@ def find_measurable_stacks(
):
return None # pragma: no cover

# Make base_vars unmeasurable
base_to_unmeasurable_vars = {base_var: ignore_logprob(base_var) for base_var in base_vars}

def replacement_fn(var, replacements):
if var in base_to_unmeasurable_vars:
replacements[var] = base_to_unmeasurable_vars[var]
# We don't want to clone valued nodes. Assigning a var to itself in the
# replacements prevents this
elif var in rvs_to_values:
replacements[var] = var

return []

# TODO: Fix this import circularity!
from pymc.pytensorf import _replace_rvs_in_graphs

unmeasurable_base_vars, _ = _replace_rvs_in_graphs(
graphs=base_vars, replacement_fn=replacement_fn
)
unmeasurable_base_vars = ignore_logprob_multiple_vars(base_vars, rvs_to_values)

if is_join:
measurable_stack = MeasurableJoin()(axis, *unmeasurable_base_vars)
Expand Down
43 changes: 41 additions & 2 deletions pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,17 @@
import warnings

from copy import copy
from typing import Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple
from typing import (
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
)

import numpy as np

Expand Down Expand Up @@ -265,7 +275,7 @@ def diracdelta_logprob(op, values, *inputs, **kwargs):
def ignore_logprob(rv: TensorVariable) -> TensorVariable:
"""Return a duplicated variable that is ignored when creating logprob graphs
This is used in SymbolicDistributions that use other RVs as inputs but account
This is used in by MeasurableRVs that use other RVs as inputs but account
for their logp terms explicitly.
If the variable is already ignored, it is returned directly.
Expand Down Expand Up @@ -298,3 +308,32 @@ def reconsider_logprob(rv: TensorVariable) -> TensorVariable:
new_node.op = copy(new_node.op)
new_node.op.__class__ = original_op_type
return new_node.outputs[node.outputs.index(rv)]


def ignore_logprob_multiple_vars(
vars: Sequence[TensorVariable], rvs_to_values: Dict[TensorVariable, TensorVariable]
) -> List[TensorVariable]:
"""Return duplicated variables that are ignored when creating logprob graphs.
This function keeps any interdependencies between variables intact, after
making each "unmeasurable", whereas a sequential call to `ignore_logprob`
would not do this correctly.
"""
from pymc.pytensorf import _replace_rvs_in_graphs

measurable_vars_to_unmeasurable_vars = {
measurable_var: ignore_logprob(measurable_var) for measurable_var in vars
}

def replacement_fn(var, replacements):
if var in measurable_vars_to_unmeasurable_vars:
replacements[var] = measurable_vars_to_unmeasurable_vars[var]
# We don't want to clone valued nodes. Assigning a var to itself in the
# replacements prevents this
elif var in rvs_to_values:
replacements[var] = var

return []

unmeasurable_vars, _ = _replace_rvs_in_graphs(graphs=vars, replacement_fn=replacement_fn)
return unmeasurable_vars

0 comments on commit 9836d00

Please sign in to comment.