12
12
from fractions import gcd
13
13
14
14
15
+ # Decorator for *_components properties
16
+ def annotate_nodes (func ):
17
+ def new_func (self ):
18
+ list_of_components = func (self )
19
+ if self .node_labels is not None :
20
+ return [self .node_labels [c ] for c in list_of_components ]
21
+ return list_of_components
22
+ return new_func
23
+
24
+
15
25
class DiGraph (object ):
16
26
r"""
17
27
Class for a directed graph. It stores useful information about the
@@ -27,6 +37,11 @@ class DiGraph(object):
27
37
weighted : bool, optional(default=False)
28
38
Whether to treat `adj_matrix` as a weighted adjacency matrix.
29
39
40
+ node_labels : array_like(default=None)
41
+ Array_like of length n containing the labels associated with the
42
+ nodes, which must be homogeneous in type. If None, the labels
43
+ default to integers 0 through n-1.
44
+
30
45
Attributes
31
46
----------
32
47
csgraph : scipy.sparse.csr_matrix
@@ -38,16 +53,26 @@ class DiGraph(object):
38
53
num_strongly_connected_components : int
39
54
The number of the strongly connected components.
40
55
41
- strongly_connected_components : list(ndarray(int))
56
+ strongly_connected_components_indices : list(ndarray(int))
57
+ List of numpy arrays containing the indices of the strongly
58
+ connected components.
59
+
60
+ strongly_connected_components : list(ndarray)
42
61
List of numpy arrays containing the strongly connected
43
- components.
62
+ components, where the nodes are annotated with their labels (if
63
+ `node_labels` is not None).
44
64
45
65
num_sink_strongly_connected_components : int
46
66
The number of the sink strongly connected components.
47
67
48
- sink_strongly_connected_components : list(ndarray(int))
68
+ sink_strongly_connected_components_indices : list(ndarray(int))
69
+ List of numpy arrays containing the indices of the sink strongly
70
+ connected components.
71
+
72
+ sink_strongly_connected_components : list(ndarray)
49
73
List of numpy arrays containing the sink strongly connected
50
- components.
74
+ components, where the nodes are annotated with their labels (if
75
+ `node_labels` is not None).
51
76
52
77
is_aperiodic : bool
53
78
Indicate whether the digraph is aperiodic.
@@ -56,8 +81,14 @@ class DiGraph(object):
56
81
The period of the digraph. Defined only for a strongly connected
57
82
digraph.
58
83
59
- cyclic_components : list(ndarray(int))
60
- List of numpy arrays containing the cyclic components.
84
+ cyclic_components_indices : list(ndarray(int))
85
+ List of numpy arrays containing the indices of the cyclic
86
+ components.
87
+
88
+ cyclic_components : list(ndarray)
89
+ List of numpy arrays containing the cyclic components, where the
90
+ nodes are annotated with their labels (if `node_labels` is not
91
+ None).
61
92
62
93
References
63
94
----------
@@ -70,7 +101,7 @@ class DiGraph(object):
70
101
71
102
"""
72
103
73
- def __init__ (self , adj_matrix , weighted = False ):
104
+ def __init__ (self , adj_matrix , weighted = False , node_labels = None ):
74
105
if weighted :
75
106
dtype = None
76
107
else :
@@ -83,6 +114,9 @@ def __init__(self, adj_matrix, weighted=False):
83
114
84
115
self .n = n # Number of nodes
85
116
117
+ # Call the setter method
118
+ self .node_labels = node_labels
119
+
86
120
self ._num_scc = None
87
121
self ._scc_proj = None
88
122
self ._sink_scc_labels = None
@@ -95,6 +129,26 @@ def __repr__(self):
95
129
def __str__ (self ):
96
130
return "Directed Graph:\n - n(number of nodes): {n}" .format (n = self .n )
97
131
132
+ @property
133
+ def node_labels (self ):
134
+ return self ._node_labels
135
+
136
+ @node_labels .setter
137
+ def node_labels (self , values ):
138
+ if values is None :
139
+ self ._node_labels = None
140
+ else :
141
+ values = np .asarray (values )
142
+ if (values .ndim < 1 ) or (values .shape [0 ] != self .n ):
143
+ raise ValueError (
144
+ 'node_labels must be an array_like of length n'
145
+ )
146
+ if np .issubdtype (values .dtype , np .object_ ):
147
+ raise ValueError (
148
+ 'data in node_labels must be homogeneous in type'
149
+ )
150
+ self ._node_labels = values
151
+
98
152
def _find_scc (self ):
99
153
"""
100
154
Set ``self._num_scc`` and ``self._scc_proj``
@@ -170,21 +224,31 @@ def num_sink_strongly_connected_components(self):
170
224
return len (self .sink_scc_labels )
171
225
172
226
@property
173
- def strongly_connected_components (self ):
227
+ def strongly_connected_components_indices (self ):
174
228
if self .is_strongly_connected :
175
229
return [np .arange (self .n )]
176
230
else :
177
231
return [np .where (self .scc_proj == k )[0 ]
178
232
for k in range (self .num_strongly_connected_components )]
179
233
180
234
@property
181
- def sink_strongly_connected_components (self ):
235
+ @annotate_nodes
236
+ def strongly_connected_components (self ):
237
+ return self .strongly_connected_components_indices
238
+
239
+ @property
240
+ def sink_strongly_connected_components_indices (self ):
182
241
if self .is_strongly_connected :
183
242
return [np .arange (self .n )]
184
243
else :
185
244
return [np .where (self .scc_proj == k )[0 ]
186
245
for k in self .sink_scc_labels .tolist ()]
187
246
247
+ @property
248
+ @annotate_nodes
249
+ def sink_strongly_connected_components (self ):
250
+ return self .sink_strongly_connected_components_indices
251
+
188
252
def _compute_period (self ):
189
253
"""
190
254
Set ``self._period`` and ``self._cyclic_components_proj``.
@@ -256,13 +320,18 @@ def is_aperiodic(self):
256
320
return (self .period == 1 )
257
321
258
322
@property
259
- def cyclic_components (self ):
323
+ def cyclic_components_indices (self ):
260
324
if self .is_aperiodic :
261
325
return [np .arange (self .n )]
262
326
else :
263
327
return [np .where (self ._cyclic_components_proj == k )[0 ]
264
328
for k in range (self .period )]
265
329
330
+ @property
331
+ @annotate_nodes
332
+ def cyclic_components (self ,):
333
+ return self .cyclic_components_indices
334
+
266
335
def subgraph (self , nodes ):
267
336
"""
268
337
Return the subgraph consisting of the given nodes and edges
@@ -271,7 +340,7 @@ def subgraph(self, nodes):
271
340
Parameters
272
341
----------
273
342
nodes : array_like(int, ndim=1)
274
- Array of nodes .
343
+ Array of node indices .
275
344
276
345
Returns
277
346
-------
@@ -282,7 +351,13 @@ def subgraph(self, nodes):
282
351
adj_matrix = self .csgraph [nodes , :][:, nodes ]
283
352
284
353
weighted = True # To copy the dtype
285
- return DiGraph (adj_matrix , weighted = weighted )
354
+
355
+ if self .node_labels is not None :
356
+ node_labels = self .node_labels [nodes ]
357
+ else :
358
+ node_labels = None
359
+
360
+ return DiGraph (adj_matrix , weighted = weighted , node_labels = node_labels )
286
361
287
362
288
363
def _csr_matrix_indices (S ):
0 commit comments