Skip to content

Commit

Permalink
[FEA] Generalized Adjustment Criterion (#1292)
Browse files Browse the repository at this point in the history
* This PR adds support for identifying generalized (non-backdoor) adjustment sets. Specifically, it adds support for finding a minimal adjustment set if one exists (it is guaranteed to find a set if one does exist). Ongoing work in the pywhy-graphs library to enumerate all m-separating sets in causal graphs will later unlock the ability to enumerate all generalized adjustment sets.

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* adding default case

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* adding minimal test

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* poe format

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* adding test, throwing on unsupported

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* tweaks

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* dependency bump

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* delete misc files

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* fix dictionary mapping

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* make test check python version

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* adding another test

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* adding docs

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* restore notebooks I don't want to change

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* remove extraneous comment

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* remove comment and print statement from example notebook

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* add comma

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

* address typos

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>

---------

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
  • Loading branch information
nparent1 authored Jan 21, 2025
1 parent ffb761f commit 3114151
Show file tree
Hide file tree
Showing 15 changed files with 740 additions and 108 deletions.

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions dowhy/causal_identifier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
AutoIdentifier,
BackdoorAdjustment,
EstimandType,
construct_backdoor_estimand,
GeneralizedAdjustment,
construct_adjustment_estimand,
construct_frontdoor_estimand,
construct_iv_estimand,
identify_effect_auto,
Expand All @@ -16,11 +17,12 @@
"identify_effect_auto",
"identify_effect_id",
"BackdoorAdjustment",
"GeneralizedAdjustment",
"EstimandType",
"IdentifiedEstimand",
"IDIdentifier",
"identify_effect",
"construct_backdoor_estimand",
"construct_adjustment_estimand",
"construct_frontdoor_estimand",
"construct_iv_estimand",
]
29 changes: 29 additions & 0 deletions dowhy/causal_identifier/adjustment_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
class AdjustmentSet:
    """Class for storing an adjustment set.

    An instance bundles the kind of adjustment (one of the class-level type
    constants), the variables to adjust for, and an optional count of paths
    blocked by observed nodes.
    """

    BACKDOOR = "backdoor"
    # General adjustment sets generalize backdoor sets, but we will differentiate
    # between the two given the ubiquity of the backdoor criterion.
    GENERAL = "general"

    def __init__(
        self,
        adjustment_type,
        adjustment_variables,
        num_paths_blocked_by_observed_nodes=None,
    ):
        """Store the given values unchanged.

        :param adjustment_type: the technique this set belongs to, e.g. ``AdjustmentSet.BACKDOOR``
            or ``AdjustmentSet.GENERAL``
        :param adjustment_variables: the collection of variables that make up the adjustment set
        :param num_paths_blocked_by_observed_nodes: optional number of paths blocked by observed nodes
        """
        self.adjustment_type = adjustment_type
        self.adjustment_variables = adjustment_variables
        self.num_paths_blocked_by_observed_nodes = num_paths_blocked_by_observed_nodes

    def __repr__(self):
        # Without this, printing a list of adjustment sets shows only opaque
        # object addresses instead of the type and variables of each set.
        return (
            "{0}(adjustment_type={1!r}, adjustment_variables={2!r}, "
            "num_paths_blocked_by_observed_nodes={3!r})"
        ).format(
            type(self).__name__,
            self.adjustment_type,
            self.adjustment_variables,
            self.num_paths_blocked_by_observed_nodes,
        )

    def get_adjustment_type(self):
        """Return the technique associated with this adjustment set (backdoor, etc.)"""
        return self.adjustment_type

    def get_adjustment_variables(self):
        """Return the adjustment variables exactly as they were passed to the constructor"""
        return self.adjustment_variables

    def get_num_paths_blocked_by_observed_nodes(self):
        """Return the number of paths blocked by observed nodes (None if it was not provided)"""
        return self.num_paths_blocked_by_observed_nodes
187 changes: 142 additions & 45 deletions dowhy/causal_identifier/auto_identifier.py

Large diffs are not rendered by default.

13 changes: 8 additions & 5 deletions dowhy/causal_identifier/backdoor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import networkx as nx

from dowhy.causal_identifier.adjustment_set import AdjustmentSet
from dowhy.utils.graph_operations import adjacency_matrix_to_adjacency_list


Expand Down Expand Up @@ -113,11 +114,13 @@ def get_backdoor_vars(self):
self._path_search(adjlist, node1, node2, path_dict)
if len(path_dict) != 0:
obj = HittingSetAlgorithm(path_dict[(node1, node2)].get_condition_vars(), self._colliders)

backdoor_set = {}
backdoor_set["backdoor_set"] = tuple(obj.find_set())
backdoor_set["num_paths_blocked_by_observed_nodes"] = obj.num_sets()
backdoor_sets.append(backdoor_set)
backdoor_sets.append(
AdjustmentSet(
adjustment_type=AdjustmentSet.BACKDOOR,
adjustment_variables=tuple(obj.find_set()),
num_paths_blocked_by_observed_nodes=obj.num_sets(),
)
)

return backdoor_sets

Expand Down
16 changes: 16 additions & 0 deletions dowhy/causal_identifier/identified_estimand.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,22 @@ def __init__(
estimand_type=None,
estimands=None,
backdoor_variables=None,
general_adjustment_variables=None,
instrumental_variables=None,
frontdoor_variables=None,
mediator_variables=None,
mediation_first_stage_confounders=None,
mediation_second_stage_confounders=None,
default_backdoor_id=None,
default_adjustment_set_id=None,
identifier_method=None,
no_directed_path=False,
):
self.identifier = identifier
self.treatment_variable = parse_state(treatment_variable)
self.outcome_variable = parse_state(outcome_variable)
self.backdoor_variables = backdoor_variables
self.general_adjustment_variables = general_adjustment_variables
self.instrumental_variables = parse_state(instrumental_variables)
self.frontdoor_variables = parse_state(frontdoor_variables)
self.mediator_variables = parse_state(mediator_variables)
Expand All @@ -38,6 +41,7 @@ def __init__(
self.estimand_type = estimand_type
self.estimands = estimands
self.default_backdoor_id = default_backdoor_id
self.default_adjustment_set_id = default_adjustment_set_id
self.identifier_method = identifier_method
self.no_directed_path = no_directed_path

Expand Down Expand Up @@ -78,6 +82,13 @@ def get_instrumental_variables(self):
"""Return a list containing the instrumental variables (if present)"""
return self.instrumental_variables

def get_general_adjustment_variables(self, key: Optional[str] = None):
    """Return the general adjustment variables stored under ``key``.

    When ``key`` is None, the entry stored under ``self.default_adjustment_set_id``
    is returned instead.
    """
    # Fall back to the default adjustment set only when no key was given.
    lookup_key = self.default_adjustment_set_id if key is None else key
    return self.general_adjustment_variables[lookup_key]

def __deepcopy__(self, memo):
return IdentifiedEstimand(
self.identifier, # not deep copied
Expand All @@ -86,10 +97,12 @@ def __deepcopy__(self, memo):
estimand_type=copy.deepcopy(self.estimand_type),
estimands=copy.deepcopy(self.estimands),
backdoor_variables=copy.deepcopy(self.backdoor_variables),
general_adjustment_variables=copy.deepcopy(self.general_adjustment_variables),
instrumental_variables=copy.deepcopy(self.instrumental_variables),
frontdoor_variables=copy.deepcopy(self.frontdoor_variables),
mediator_variables=copy.deepcopy(self.mediator_variables),
default_backdoor_id=copy.deepcopy(self.default_backdoor_id),
default_adjustment_set_id=copy.deepcopy(self.default_adjustment_set_id),
identifier_method=copy.deepcopy(self.identifier_method),
)

Expand All @@ -112,6 +125,9 @@ def __str__(self, only_target_estimand: bool = False, show_all_backdoor_sets: bo
# Just show the default backdoor set
if k.startswith("backdoor") and k != "backdoor":
continue
# Just show the default generalized adjustment set
if k.startswith("general") and k != "general_adjustment":
continue
if only_target_estimand and k != self.identifier_method:
continue
s += "\n### Estimand : {0}\n".format(i)
Expand Down
54 changes: 54 additions & 0 deletions dowhy/graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""This module defines the fundamental interfaces and functions related to causal graphs."""

import copy
import itertools
import logging
import re
Expand Down Expand Up @@ -187,13 +188,66 @@ def is_blocked(graph: nx.DiGraph, path, conditioned_nodes):
return False


def get_ancestors(graph: nx.DiGraph, nodes):
    """Return the union of the ancestor sets of every node in ``nodes``.

    :param graph: the causal graph in question
    :param nodes: iterable of node names whose ancestors are collected
    :returns: a set with the ancestors (as computed by ``nx.ancestors``) of all given nodes
    """
    # set().union() with no arguments yields an empty set, so an empty
    # ``nodes`` still returns set().
    return set().union(*(nx.ancestors(graph, member) for member in nodes))


def get_descendants(graph: nx.DiGraph, nodes):
    """Return the union of the descendant sets of every node in ``nodes``.

    :param graph: the causal graph in question
    :param nodes: iterable of node names whose descendants are collected
    :returns: a set with the descendants (as computed by ``nx.descendants``) of all given nodes
    """
    # set().union() with no arguments yields an empty set, so an empty
    # ``nodes`` still returns set().
    return set().union(*(nx.descendants(graph, member) for member in nodes))


def get_proper_causal_path_nodes(graph: nx.DiGraph, action_nodes, outcome_nodes):
    """Return the nodes lying on proper causal paths from the action nodes to the outcome nodes.

    Implements the construction described in van der Zander et al., "Constructing Separators and
    Adjustment Sets in Ancestral Graphs", Section 4.1. We cannot use do_surgery() since we require
    deep copies of the given graph; all edge removals below happen on those copies.

    :param graph: the causal graph in question
    :param action_nodes: the action nodes
    :param outcome_nodes: the outcome nodes
    :returns: the set of nodes that lie on proper causal paths from X to Y
    """
    # Copy of the graph in which no arrow points into an action node.
    without_incoming = copy.deepcopy(graph)
    without_incoming.remove_edges_from(list(without_incoming.in_edges(action_nodes)))

    # Copy of the graph in which the action nodes are sinks (no outbound arrows).
    without_outgoing = copy.deepcopy(graph)
    without_outgoing.remove_edges_from(list(without_outgoing.out_edges(action_nodes)))

    # Nodes reachable from the action nodes, excluding the action nodes themselves.
    reachable_from_actions = get_descendants(without_incoming, action_nodes) - set(action_nodes)
    # Outcome nodes together with everything that can still reach them.
    leads_to_outcomes = get_ancestors(without_outgoing, outcome_nodes).union(outcome_nodes)
    return reachable_from_actions & leads_to_outcomes


def get_proper_backdoor_graph(graph: nx.DiGraph, action_nodes, outcome_nodes):
    """Method to get the proper backdoor graph from a causal graph, as described in van der Zander et al.
    "Constructing Separators and Adjustment Sets in Ancestral Graphs", Section 4.1. We cannot use
    do_surgery() since we require deep copies of the given graph.

    :param graph: the causal graph in question
    :param action_nodes: the action nodes
    :param outcome_nodes: the outcome nodes
    :returns: a new graph which is the proper backdoor graph of the original
    """
    # The proper causal path nodes do not depend on which action node is being
    # processed, so compute them once rather than once per action node (each
    # computation deep-copies the graph twice).
    causal_path_nodes = get_proper_causal_path_nodes(graph, action_nodes, outcome_nodes)

    # Remove the edges from the action nodes to the proper causal path nodes,
    # working on a copy so the caller's graph is left untouched.
    graph_pbd = copy.deepcopy(graph)
    graph_pbd.remove_edges_from([(u, v) for u in action_nodes for v in causal_path_nodes])
    return graph_pbd


def check_dseparation(graph: nx.DiGraph, nodes1, nodes2, nodes3, new_graph=None, dseparation_algo="default"):
if dseparation_algo == "default":
if new_graph is None:
Expand Down
59 changes: 24 additions & 35 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ pandas = [
{version = "<2.0", python = "<3.9"},
{version = ">1.0", python = ">=3.9"}
]
networkx = ">=2.8.5"
networkx = [
{version = ">=3.3", python = ">=3.10"},
{version = ">=2.8.5", python = "<3.10"}
]
sympy = ">=1.10.1"
scikit-learn = ">1.0"
pydot = { version = "^1.4.2", optional = true }
Expand Down
41 changes: 38 additions & 3 deletions tests/causal_identifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from dowhy.graph import build_graph_from_str

from .example_graphs import TEST_FRONTDOOR_GRAPH_SOLUTIONS, TEST_GRAPH_SOLUTIONS
from .example_graphs import (
TEST_FRONTDOOR_GRAPH_SOLUTIONS,
TEST_GRAPH_SOLUTIONS,
TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT,
)


class IdentificationTestGraphSolution(object):
Expand Down Expand Up @@ -34,15 +38,39 @@ def __init__(
observed_variables,
valid_frontdoor_sets,
invalid_frontdoor_sets,
action_nodes=None,
outcome_nodes=None,
):
if outcome_nodes is None:
outcome_nodes = ["Y"]
if action_nodes is None:
action_nodes = ["X"]
self.graph = build_graph_from_str(graph_str)
self.action_nodes = ["X"]
self.outcome_nodes = ["Y"]
self.action_nodes = action_nodes
self.outcome_nodes = outcome_nodes
self.observed_nodes = observed_variables
self.valid_frontdoor_sets = valid_frontdoor_sets
self.invalid_frontdoor_sets = invalid_frontdoor_sets


class IdentificationTestGeneralCovariateAdjustmentGraphSolution(object):
    """Test-case record for general covariate adjustment identification.

    Holds the graph built from ``graph_str`` together with the action, outcome and
    observed nodes and the adjustment sets the test case supplies.
    """

    def __init__(
        self,
        graph_str,
        observed_variables,
        action_nodes,
        outcome_nodes,
        minimal_adjustment_sets,
        exhaustive_adjustment_sets=None,
    ):
        # Scenario definition.
        self.graph = build_graph_from_str(graph_str)
        self.action_nodes, self.outcome_nodes = action_nodes, outcome_nodes
        self.observed_nodes = observed_variables
        # Adjustment sets supplied by the test case; the exhaustive list is optional.
        self.minimal_adjustment_sets = minimal_adjustment_sets
        self.exhaustive_adjustment_sets = exhaustive_adjustment_sets


@pytest.fixture(params=TEST_GRAPH_SOLUTIONS.keys())
def example_graph_solution(request):
    """Build one IdentificationTestGraphSolution per entry of TEST_GRAPH_SOLUTIONS."""
    solution_spec = TEST_GRAPH_SOLUTIONS[request.param]
    return IdentificationTestGraphSolution(**solution_spec)
Expand All @@ -51,3 +79,10 @@ def example_graph_solution(request):
@pytest.fixture(params=TEST_FRONTDOOR_GRAPH_SOLUTIONS.keys())
def example_frontdoor_graph_solution(request):
    """Build one IdentificationTestFrontdoorGraphSolution per entry of TEST_FRONTDOOR_GRAPH_SOLUTIONS."""
    solution_spec = TEST_FRONTDOOR_GRAPH_SOLUTIONS[request.param]
    return IdentificationTestFrontdoorGraphSolution(**solution_spec)


@pytest.fixture(params=TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT.keys())
def example_complete_adjustment_graph_solution(request):
    """Build one general-covariate-adjustment test solution per entry of
    TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT."""
    solution_spec = TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT[request.param]
    return IdentificationTestGeneralCovariateAdjustmentGraphSolution(**solution_spec)
Loading

0 comments on commit 3114151

Please sign in to comment.