-
Notifications
You must be signed in to change notification settings - Fork 33
/
pcalg.py
313 lines (276 loc) · 11.3 KB
/
pcalg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""A graph generator based on the PC algorithm [Kalisch2007].
[Kalisch2007] Markus Kalisch and Peter Bühlmann. Estimating
high-dimensional directed acyclic graphs with the pc-algorithm. In The
Journal of Machine Learning Research, Vol. 8, pp. 613-636, 2007.
License: BSD
"""
from __future__ import print_function
from itertools import combinations, permutations
import logging
import networkx as nx
_logger = logging.getLogger(__name__)
def _create_complete_graph(node_ids):
    """Build an undirected graph in which every pair of node ids is joined.

    Args:
        node_ids: the node ids to use. The iterable is traversed twice
            (once for the nodes, once for the node pairs), so pass a
            re-iterable sequence such as a list or a range.
    Returns:
        The complete graph over node_ids (as a networkx.Graph).
    """
    complete = nx.Graph()
    complete.add_nodes_from(node_ids)
    # Every unordered pair (i, j), i before j in node_ids order, is an edge.
    complete.add_edges_from(combinations(node_ids, 2))
    return complete
def estimate_skeleton(indep_test_func, data_matrix, alpha, **kwargs):
    """Estimate a skeleton graph from the statistical information.

    Starting from a complete graph (or from 'init_graph'), edges are
    removed level by level.  At level ``level`` the edge i-j is dropped as
    soon as some subset of size ``level`` of the other neighbors of i
    yields a conditional-independence p-value above ``alpha``; that subset
    is recorded as the separation set of (i, j).

    Args:
        indep_test_func: the function for a conditional independency test.
            It is called as
            ``indep_test_func(data_matrix, i, j, set(k), **kwargs)``
            and its return value is compared against ``alpha``.
        data_matrix: data (as a numpy array); one node per column.
        alpha: the significance level.
        kwargs:
            'max_reach': maximum size of the conditioning sets (the
                level). The value depends on the underlying distribution.
            'method': if 'stable' given, use stable-PC algorithm
                (see [Colombo2014]).
            'init_graph': initial structure of skeleton graph
                (as a networkx.Graph). If not specified,
                a complete graph is used. The graph is copied; the
                caller's object is not modified.
            'fixed_edges': Undirected edges marked here are not changed
                (as a networkx.Graph). If not specified,
                an empty graph is used.
            other parameters may be passed depending on the
                indep_test_func()s (all kwargs are forwarded to it).

    Returns:
        g: a skeleton graph (as a networkx.Graph).
        sep_set: a separation set (as a 2D-array of set()). Entries for
            pairs that are non-adjacent in 'init_graph' are None, since
            those pairs are never tested.

    Raises:
        ValueError: if 'init_graph' or 'fixed_edges' is not a
            networkx.Graph, or its node count differs from the number of
            columns of data_matrix.

    [Colombo2014] Diego Colombo and Marloes H Maathuis. Order-independent
    constraint-based causal structure learning. In The Journal of Machine
    Learning Research, Vol. 15, pp. 3741-3782, 2014.
    """
    stable = kwargs.get('method') == 'stable'
    node_size = data_matrix.shape[1]
    node_ids = range(node_size)
    sep_set = [[set() for _ in range(node_size)] for _ in range(node_size)]

    if 'init_graph' in kwargs:
        init_graph = kwargs['init_graph']
        if not isinstance(init_graph, nx.Graph):
            raise ValueError('init_graph must be a networkx.Graph')
        if init_graph.number_of_nodes() != node_size:
            raise ValueError('init_graph not matching data_matrix shape')
        # Prune a copy so the caller's graph is left untouched.
        g = init_graph.copy()
        for (i, j) in combinations(node_ids, 2):
            if not g.has_edge(i, j):
                # Never tested, so no separation set is known.
                sep_set[i][j] = None
                sep_set[j][i] = None
    else:
        g = _create_complete_graph(node_ids)

    fixed_edges = set()
    if 'fixed_edges' in kwargs:
        _fixed_edges = kwargs['fixed_edges']
        if not isinstance(_fixed_edges, nx.Graph):
            raise ValueError('fixed_edges must be a networkx.Graph')
        if _fixed_edges.number_of_nodes() != node_size:
            raise ValueError('fixed_edges not matching data_matrix shape')
        for (i, j) in _fixed_edges.edges:
            # Store both orientations so the ordered-pair loop skips them.
            fixed_edges.add((i, j))
            fixed_edges.add((j, i))

    level = 0
    while True:
        cont = False
        # Stable mode only: pairs found independent at this level.  Their
        # removal is deferred to the end of the level so that neighbor
        # sets stay fixed while the level is processed.
        remove_edges = set()
        for (i, j) in permutations(node_ids, 2):
            if (i, j) in fixed_edges:
                continue
            if (j, i) in remove_edges:
                # Already separated at this level via (j, i).  Testing again
                # could merge a second, different separating set into
                # sep_set (and repeats work), so treat it as removed.
                continue
            adj_i = list(g.neighbors(i))
            if j not in adj_i:
                continue
            adj_i.remove(j)
            if len(adj_i) < level:
                continue
            _logger.debug('testing %s and %s', i, j)
            _logger.debug('neighbors of %s are %s', i, adj_i)
            # Some pair could still be tested at this level, so keep going.
            cont = True
            for k in combinations(adj_i, level):
                _logger.debug('indep prob of %s and %s with subset %s',
                              i, j, k)
                p_val = indep_test_func(data_matrix, i, j, set(k),
                                        **kwargs)
                _logger.debug('p_val is %s', p_val)
                if p_val > alpha:
                    _logger.debug('p: remove edge (%s, %s)', i, j)
                    if stable:
                        remove_edges.add((i, j))
                    else:
                        g.remove_edge(i, j)
                    sep_set[i][j] |= set(k)
                    sep_set[j][i] |= set(k)
                    break
        level += 1
        if stable:
            g.remove_edges_from(remove_edges)
        if not cont:
            break
        if 'max_reach' in kwargs and level > kwargs['max_reach']:
            break

    return (g, sep_set)
def estimate_cpdag(skel_graph, sep_set):
    """Estimate a CPDAG from the skeleton graph and separation sets
    returned by the estimate_skeleton() function.

    First, each non-adjacent pair i, j with a common neighbor k that is
    not in sep_set[i][j] is oriented as a v-structure i -> k <- j.  Then
    Meek's rules R1-R3 are applied repeatedly until a sweep orients no
    further edge.

    Args:
        skel_graph: A skeleton graph (an undirected networkx.Graph).
        sep_set: A 2D-array of separation sets, indexed by node id.
            The contents look something like below.
                sep_set[i][j] = set([k, l, m])
            Entries that are None are skipped.

    Returns:
        An estimated CPDAG (as a networkx.DiGraph).  An oriented edge
        i->j is stored as the single arc (i, j); an edge left undirected
        is stored as both arcs (i, j) and (j, i).
    """
    def _has_both_edges(dag, i, j):
        # i - j: the edge is still undirected (both arcs present).
        return dag.has_edge(i, j) and dag.has_edge(j, i)

    def _has_any_edge(dag, i, j):
        # i and j are adjacent in either direction.
        return dag.has_edge(i, j) or dag.has_edge(j, i)

    dag = skel_graph.to_directed()
    node_ids = skel_graph.nodes()

    # Orient v-structures.
    # NOTE(review): adjacency is read from the partially oriented dag, so
    # an orientation made for an earlier pair (which drops k from
    # successors(k's former neighbor)... i.e. removes arcs out of k) can
    # hide k's neighbors for later pairs; results may depend on node
    # order. Confirm whether skeleton adjacency was intended.
    for (i, j) in combinations(node_ids, 2):
        adj_i = set(dag.successors(i))
        if j in adj_i:
            continue
        adj_j = set(dag.successors(j))
        if i in adj_j:
            continue
        if sep_set[i][j] is None:
            continue
        common_k = adj_i & adj_j
        for k in common_k:
            if k not in sep_set[i][j]:
                # Drop k->i and k->j, leaving i->k<-j.
                if dag.has_edge(k, i):
                    _logger.debug('S: remove edge (%s, %s)', k, i)
                    dag.remove_edge(k, i)
                if dag.has_edge(k, j):
                    _logger.debug('S: remove edge (%s, %s)', k, j)
                    dag.remove_edge(k, j)

    # For all ordered pairs of nodes i and j, apply the rules below until
    # nothing changes.
    while True:
        n_arcs = dag.number_of_edges()
        for (i, j) in permutations(node_ids, 2):
            # Rule 1: Orient i-j into i->j whenever there is an arrow k->i
            # such that k and j are nonadjacent.
            if _has_both_edges(dag, i, j):
                for k in dag.predecessors(i):
                    # Skip if i-k is undirected (need a directed k->i).
                    if dag.has_edge(i, k):
                        continue
                    # Skip if k and j are adjacent.
                    if _has_any_edge(dag, k, j):
                        continue
                    # Make i-j into i->j.  Break right away: the removal
                    # mutates predecessors(i), which is being iterated.
                    _logger.debug('R1: remove edge (%s, %s)', j, i)
                    dag.remove_edge(j, i)
                    break

            # Rule 2: Orient i-j into i->j whenever there is a chain
            # i->k->j.
            if _has_both_edges(dag, i, j):
                # Nodes k with a directed arrow i->k.
                succs_i = {k for k in dag.successors(i)
                           if not dag.has_edge(k, i)}
                # Nodes k with a directed arrow k->j.
                preds_j = {k for k in dag.predecessors(j)
                           if not dag.has_edge(j, k)}
                if succs_i & preds_j:
                    # Make i-j into i->j.
                    _logger.debug('R2: remove edge (%s, %s)', j, i)
                    dag.remove_edge(j, i)

            # Rule 3: Orient i-j into i->j whenever there are two chains
            # i-k->j and i-m->j such that k and m are nonadjacent.
            if _has_both_edges(dag, i, j):
                # Nodes k with an undirected edge i-k.
                adj_i = {k for k in dag.successors(i) if dag.has_edge(k, i)}
                for (k, m) in combinations(adj_i, 2):
                    # Skip if k and m are adjacent.
                    if _has_any_edge(dag, k, m):
                        continue
                    # Skip if not k->j.
                    if dag.has_edge(j, k) or (not dag.has_edge(k, j)):
                        continue
                    # Skip if not m->j.
                    if dag.has_edge(j, m) or (not dag.has_edge(m, j)):
                        continue
                    # Make i-j into i->j.
                    _logger.debug('R3: remove edge (%s, %s)', j, i)
                    dag.remove_edge(j, i)
                    break

            # Rule 4: Orient i-j into i->j whenever there are two chains
            # i-k->l and k->l->j such that k and j are nonadjacent.
            #
            # However, this rule is not necessary when the PC-algorithm
            # is used to estimate a DAG.

        # The rules only ever remove arcs, so an unchanged arc count means
        # this sweep oriented nothing: a fixed point has been reached.
        # (Equivalent to the previous isomorphism test, without the graph
        # copy and VF2 search on every sweep.)
        if dag.number_of_edges() == n_arcs:
            break
    return dag
if __name__ == '__main__':
    import numpy as np
    from gsq.ci_tests import ci_test_bin, ci_test_dis
    from gsq.gsq_testdata import bin_data, dis_data

    # To trace every independence test, enable debug output, e.g.:
    #     ch = logging.StreamHandler()
    #     ch.setLevel(logging.DEBUG)
    #     _logger.setLevel(logging.DEBUG)
    #     _logger.addHandler(ch)

    def _report(g, g_answer):
        """Print g's edges and whether they match g_answer's exactly."""
        print('Edges are:', g.edges(), end='')
        # Compare labelled arc sets; isomorphism alone would also accept a
        # graph whose edges are attached to the wrong variables.
        if set(g.edges()) == set(g_answer.edges()):
            print(' => GOOD')
        else:
            print(' => WRONG')
            print('True edges should be:', g_answer.edges())

    # Binary data: 5000 samples of 5 variables.
    dm = np.array(bin_data).reshape((5000, 5))
    (g, sep_set) = estimate_skeleton(indep_test_func=ci_test_bin,
                                     data_matrix=dm,
                                     alpha=0.01)
    g = estimate_cpdag(skel_graph=g, sep_set=sep_set)
    g_answer = nx.DiGraph()
    g_answer.add_nodes_from([0, 1, 2, 3, 4])
    g_answer.add_edges_from([(0, 1), (2, 3), (3, 2), (3, 1),
                             (2, 4), (4, 2), (4, 1)])
    _report(g, g_answer)

    # Discrete data: 10000 samples of 5 variables with the given levels.
    dm = np.array(dis_data).reshape((10000, 5))
    (g, sep_set) = estimate_skeleton(indep_test_func=ci_test_dis,
                                     data_matrix=dm,
                                     alpha=0.01,
                                     levels=[3, 2, 3, 4, 2])
    g = estimate_cpdag(skel_graph=g, sep_set=sep_set)
    g_answer = nx.DiGraph()
    g_answer.add_nodes_from([0, 1, 2, 3, 4])
    g_answer.add_edges_from([(0, 2), (1, 2), (1, 3), (4, 3)])
    _report(g, g_answer)