Skip to content

Commit 4eade8d

Browse files
committed
Merge pull request #237 from QuantEcon/add_nodes_states
Add states/nodes to MarkovChain/DiGraph
2 parents 3a4265c + 05e8ad8 commit 4eade8d

File tree

4 files changed

+487
-42
lines changed

4 files changed

+487
-42
lines changed

quantecon/graph_tools.py

Lines changed: 87 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@
1212
from fractions import gcd
1313

1414

15+
# Decorator for *_components properties
16+
def annotate_nodes(func):
17+
def new_func(self):
18+
list_of_components = func(self)
19+
if self.node_labels is not None:
20+
return [self.node_labels[c] for c in list_of_components]
21+
return list_of_components
22+
return new_func
23+
24+
1525
class DiGraph(object):
1626
r"""
1727
Class for a directed graph. It stores useful information about the
@@ -27,6 +37,11 @@ class DiGraph(object):
2737
weighted : bool, optional(default=False)
2838
Whether to treat `adj_matrix` as a weighted adjacency matrix.
2939
40+
node_labels : array_like(default=None)
41+
Array_like of length n containing the labels associated with the
42+
nodes, which must be homogeneous in type. If None, the labels
43+
default to integers 0 through n-1.
44+
3045
Attributes
3146
----------
3247
csgraph : scipy.sparse.csr_matrix
@@ -38,16 +53,26 @@ class DiGraph(object):
3853
num_strongly_connected_components : int
3954
The number of the strongly connected components.
4055
41-
strongly_connected_components : list(ndarray(int))
56+
strongly_connected_components_indices : list(ndarray(int))
57+
List of numpy arrays containing the indices of the strongly
58+
connected components.
59+
60+
strongly_connected_components : list(ndarray)
4261
List of numpy arrays containing the strongly connected
43-
components.
62+
components, where the nodes are annotated with their labels (if
63+
`node_labels` is not None).
4464
4565
num_sink_strongly_connected_components : int
4666
The number of the sink strongly connected components.
4767
48-
sink_strongly_connected_components : list(ndarray(int))
68+
sink_strongly_connected_components_indices : list(ndarray(int))
69+
List of numpy arrays containing the indices of the sink strongly
70+
connected components.
71+
72+
sink_strongly_connected_components : list(ndarray)
4973
List of numpy arrays containing the sink strongly connected
50-
components.
74+
components, where the nodes are annotated with their labels (if
75+
`node_labels` is not None).
5176
5277
is_aperiodic : bool
5378
Indicate whether the digraph is aperiodic.
@@ -56,8 +81,14 @@ class DiGraph(object):
5681
The period of the digraph. Defined only for a strongly connected
5782
digraph.
5883
59-
cyclic_components : list(ndarray(int))
60-
List of numpy arrays containing the cyclic components.
84+
cyclic_components_indices : list(ndarray(int))
85+
List of numpy arrays containing the indices of the cyclic
86+
components.
87+
88+
cyclic_components : list(ndarray)
89+
List of numpy arrays containing the cyclic components, where the
90+
nodes are annotated with their labels (if `node_labels` is not
91+
None).
6192
6293
References
6394
----------
@@ -70,7 +101,7 @@ class DiGraph(object):
70101
71102
"""
72103

73-
def __init__(self, adj_matrix, weighted=False):
104+
def __init__(self, adj_matrix, weighted=False, node_labels=None):
74105
if weighted:
75106
dtype = None
76107
else:
@@ -83,6 +114,9 @@ def __init__(self, adj_matrix, weighted=False):
83114

84115
self.n = n # Number of nodes
85116

117+
# Call the setter method
118+
self.node_labels = node_labels
119+
86120
self._num_scc = None
87121
self._scc_proj = None
88122
self._sink_scc_labels = None
@@ -95,6 +129,26 @@ def __repr__(self):
95129
def __str__(self):
96130
return "Directed Graph:\n - n(number of nodes): {n}".format(n=self.n)
97131

132+
@property
133+
def node_labels(self):
134+
return self._node_labels
135+
136+
@node_labels.setter
137+
def node_labels(self, values):
138+
if values is None:
139+
self._node_labels = None
140+
else:
141+
values = np.asarray(values)
142+
if (values.ndim < 1) or (values.shape[0] != self.n):
143+
raise ValueError(
144+
'node_labels must be an array_like of length n'
145+
)
146+
if np.issubdtype(values.dtype, np.object_):
147+
raise ValueError(
148+
'data in node_labels must be homogeneous in type'
149+
)
150+
self._node_labels = values
151+
98152
def _find_scc(self):
99153
"""
100154
Set ``self._num_scc`` and ``self._scc_proj``
@@ -170,21 +224,31 @@ def num_sink_strongly_connected_components(self):
170224
return len(self.sink_scc_labels)
171225

172226
@property
173-
def strongly_connected_components(self):
227+
def strongly_connected_components_indices(self):
174228
if self.is_strongly_connected:
175229
return [np.arange(self.n)]
176230
else:
177231
return [np.where(self.scc_proj == k)[0]
178232
for k in range(self.num_strongly_connected_components)]
179233

180234
@property
181-
def sink_strongly_connected_components(self):
235+
@annotate_nodes
236+
def strongly_connected_components(self):
237+
return self.strongly_connected_components_indices
238+
239+
@property
240+
def sink_strongly_connected_components_indices(self):
182241
if self.is_strongly_connected:
183242
return [np.arange(self.n)]
184243
else:
185244
return [np.where(self.scc_proj == k)[0]
186245
for k in self.sink_scc_labels.tolist()]
187246

247+
@property
248+
@annotate_nodes
249+
def sink_strongly_connected_components(self):
250+
return self.sink_strongly_connected_components_indices
251+
188252
def _compute_period(self):
189253
"""
190254
Set ``self._period`` and ``self._cyclic_components_proj``.
@@ -256,13 +320,18 @@ def is_aperiodic(self):
256320
return (self.period == 1)
257321

258322
@property
259-
def cyclic_components(self):
323+
def cyclic_components_indices(self):
260324
if self.is_aperiodic:
261325
return [np.arange(self.n)]
262326
else:
263327
return [np.where(self._cyclic_components_proj == k)[0]
264328
for k in range(self.period)]
265329

330+
@property
331+
@annotate_nodes
332+
def cyclic_components(self,):
333+
return self.cyclic_components_indices
334+
266335
def subgraph(self, nodes):
267336
"""
268337
Return the subgraph consisting of the given nodes and edges
@@ -271,7 +340,7 @@ def subgraph(self, nodes):
271340
Parameters
272341
----------
273342
nodes : array_like(int, ndim=1)
274-
Array of nodes.
343+
Array of node indices.
275344
276345
Returns
277346
-------
@@ -282,7 +351,13 @@ def subgraph(self, nodes):
282351
adj_matrix = self.csgraph[nodes, :][:, nodes]
283352

284353
weighted = True # To copy the dtype
285-
return DiGraph(adj_matrix, weighted=weighted)
354+
355+
if self.node_labels is not None:
356+
node_labels = self.node_labels[nodes]
357+
else:
358+
node_labels = None
359+
360+
return DiGraph(adj_matrix, weighted=weighted, node_labels=node_labels)
286361

287362

288363
def _csr_matrix_indices(S):

0 commit comments

Comments
 (0)