Skip to content

Commit

Permalink
Rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 committed Mar 8, 2023
1 parent 9954976 commit 8d3ec11
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 60 deletions.
4 changes: 2 additions & 2 deletions causal_pyro/counterfactual/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import pyro

from causal_pyro.counterfactual.internals import IndexPlatesMessenger
from causal_pyro.primitives import IndexSet, scatter
from causal_pyro.counterfactual.conditioning import (
AmbiguousConditioningReparam,
AmbiguousConditioningStrategy,
AutoFactualConditioning,
)
from causal_pyro.counterfactual.internals import IndexPlatesMessenger
from causal_pyro.primitives import IndexSet, scatter

CondStrategy = Union[
Dict[str, AmbiguousConditioningReparam], AmbiguousConditioningStrategy
Expand Down
2 changes: 1 addition & 1 deletion causal_pyro/counterfactual/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _scatter_dict(
):
"""
Scatters a dictionary of disjoint masked values into a single value
using repeated calls to :func:``causal_pyro.internals.scatter``.
using repeated calls to :func:``scatter``.
:param partitioned_values: A dictionary mapping index sets to values.
:return: A single value.
Expand Down
57 changes: 2 additions & 55 deletions causal_pyro/counterfactual/selection.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
from typing import Any, Dict, Optional, TypeVar
from typing import Any, Dict

import pyro
import pyro.infer.reparam
import torch

from causal_pyro.counterfactual.internals import (
get_index_plates,
get_sample_msg_device,
indexset_as_mask,
)
from causal_pyro.primitives import IndexSet, merge


T = TypeVar("T")
from causal_pyro.primitives import IndexSet


class IndexSetMaskMessenger(pyro.poutine.messenger.Messenger):
Expand Down Expand Up @@ -50,51 +45,3 @@ class SelectFactual(IndexSetMaskMessenger):
@property
def indices(self) -> IndexSet:
return IndexSet(**{f.name: {0} for f in get_index_plates().values()})


class FactualConditioningReparam(pyro.infer.reparam.reparam.Reparam):
@pyro.poutine.infer_config(config_fn=lambda msg: {"_cf_conditioned": True})
def apply(self, msg):
with SelectFactual() as fw:
fv = pyro.sample(msg["name"] + "_factual", msg["fn"], obs=msg["value"])

with SelectCounterfactual() as cw:
cv = pyro.sample(msg["name"] + "_counterfactual", msg["fn"])

event_dim = len(msg["fn"].event_shape)
new_value = merge({fw.indices: fv, cw.indices: cv}, event_dim=event_dim)
new_fn = pyro.distributions.Delta(new_value, event_dim=event_dim).mask(False)
return {"fn": new_fn, "value": new_value, "is_observed": msg["is_observed"]}


class FactualConditioning(pyro.infer.reparam.strategies.Strategy):
@staticmethod
def _expand_msg_value(msg: dict) -> None:
_custom_init = getattr(msg["value"], "_pyro_custom_init", False)
msg["value"] = msg["value"].expand(
torch.broadcast_shapes(
msg["fn"].batch_shape + msg["fn"].event_shape, msg["value"].shape
)
)
msg["value"]._pyro_custom_init = _custom_init

def configure(self, msg) -> Optional[FactualConditioningReparam]:
if (
not msg["is_observed"]
or pyro.poutine.util.site_is_subsample(msg)
or msg["infer"].get("_cf_conditioned", False) # don't apply recursively
):
return None

if msg["is_observed"] and msg["value"] is not None:
# XXX slightly gross workaround that mutates the msg in place to avoid
# triggering overzealous validation logic in pyro.poutine.reparam
# that uses cheaper tensor shape and identity equality checks as
# a conservative proxy for an expensive tensor value equality check.
# (see https://github.com/pyro-ppl/pyro/blob/685c7adee65bbcdd6bd6c84c834a0a460f2224eb/pyro/poutine/reparam_messenger.py#L99) # noqa: E501
# This workaround is correct because FactualConditioningReparam does not change
# the values of the observation, it just packs counterfactual values around it;
# the equality check being approximated by that logic would still pass.
self._expand_msg_value(msg)

return FactualConditioningReparam()
2 changes: 0 additions & 2 deletions tests/test_mediation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
MultiWorldCounterfactual,
TwinWorldCounterfactual,
)
from causal_pyro.counterfactual.selection import FactualConditioning
from causal_pyro.query.do_messenger import DoMessenger, do

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -191,7 +190,6 @@ def direct_effect(model, x, x_prime, w_obs, x_obs, z_obs, y_obs) -> Callable:
y_obs = torch.randn(N)

extended_model = direct_effect(model, x, x_prime, w_obs, x_obs, z_obs, y_obs)
extended_model = FactualConditioning()(extended_model)

with MultiWorldCounterfactual(-2):
W, X, Z, Y = extended_model()
Expand Down

0 comments on commit 8d3ec11

Please sign in to comment.