Skip to content

Commit

Permalink
Merge pull request #237 from QuantEcon/add_nodes_states
Browse files Browse the repository at this point in the history
Add states/nodes to MarkovChain/DiGraph
  • Loading branch information
jstac committed Apr 12, 2016
2 parents 3a4265c + 05e8ad8 commit 4eade8d
Show file tree
Hide file tree
Showing 4 changed files with 487 additions and 42 deletions.
99 changes: 87 additions & 12 deletions quantecon/graph_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@
from fractions import gcd


# Decorator for *_components properties
def annotate_nodes(func):
def new_func(self):
list_of_components = func(self)
if self.node_labels is not None:
return [self.node_labels[c] for c in list_of_components]
return list_of_components
return new_func


class DiGraph(object):
r"""
Class for a directed graph. It stores useful information about the
Expand All @@ -27,6 +37,11 @@ class DiGraph(object):
weighted : bool, optional(default=False)
Whether to treat `adj_matrix` as a weighted adjacency matrix.
node_labels : array_like(default=None)
Array_like of length n containing the labels associated with the
nodes, which must be homogeneous in type. If None, the labels
default to integers 0 through n-1.
Attributes
----------
csgraph : scipy.sparse.csr_matrix
Expand All @@ -38,16 +53,26 @@ class DiGraph(object):
num_strongly_connected_components : int
The number of the strongly connected components.
strongly_connected_components : list(ndarray(int))
strongly_connected_components_indices : list(ndarray(int))
List of numpy arrays containing the indices of the strongly
connected components.
strongly_connected_components : list(ndarray)
List of numpy arrays containing the strongly connected
components.
components, where the nodes are annotated with their labels (if
`node_labels` is not None).
num_sink_strongly_connected_components : int
The number of the sink strongly connected components.
sink_strongly_connected_components : list(ndarray(int))
sink_strongly_connected_components_indices : list(ndarray(int))
List of numpy arrays containing the indices of the sink strongly
connected components.
sink_strongly_connected_components : list(ndarray)
List of numpy arrays containing the sink strongly connected
components.
components, where the nodes are annotated with their labels (if
`node_labels` is not None).
is_aperiodic : bool
Indicate whether the digraph is aperiodic.
Expand All @@ -56,8 +81,14 @@ class DiGraph(object):
The period of the digraph. Defined only for a strongly connected
digraph.
cyclic_components : list(ndarray(int))
List of numpy arrays containing the cyclic components.
cyclic_components_indices : list(ndarray(int))
List of numpy arrays containing the indices of the cyclic
components.
cyclic_components : list(ndarray)
List of numpy arrays containing the cyclic components, where the
nodes are annotated with their labels (if `node_labels` is not
None).
References
----------
Expand All @@ -70,7 +101,7 @@ class DiGraph(object):
"""

def __init__(self, adj_matrix, weighted=False):
def __init__(self, adj_matrix, weighted=False, node_labels=None):
if weighted:
dtype = None
else:
Expand All @@ -83,6 +114,9 @@ def __init__(self, adj_matrix, weighted=False):

self.n = n # Number of nodes

# Call the setter method
self.node_labels = node_labels

self._num_scc = None
self._scc_proj = None
self._sink_scc_labels = None
Expand All @@ -95,6 +129,26 @@ def __repr__(self):
def __str__(self):
return "Directed Graph:\n - n(number of nodes): {n}".format(n=self.n)

@property
def node_labels(self):
return self._node_labels

@node_labels.setter
def node_labels(self, values):
if values is None:
self._node_labels = None
else:
values = np.asarray(values)
if (values.ndim < 1) or (values.shape[0] != self.n):
raise ValueError(
'node_labels must be an array_like of length n'
)
if np.issubdtype(values.dtype, np.object_):
raise ValueError(
'data in node_labels must be homogeneous in type'
)
self._node_labels = values

def _find_scc(self):
"""
Set ``self._num_scc`` and ``self._scc_proj``
Expand Down Expand Up @@ -170,21 +224,31 @@ def num_sink_strongly_connected_components(self):
return len(self.sink_scc_labels)

@property
def strongly_connected_components(self):
def strongly_connected_components_indices(self):
if self.is_strongly_connected:
return [np.arange(self.n)]
else:
return [np.where(self.scc_proj == k)[0]
for k in range(self.num_strongly_connected_components)]

@property
def sink_strongly_connected_components(self):
@annotate_nodes
def strongly_connected_components(self):
return self.strongly_connected_components_indices

@property
def sink_strongly_connected_components_indices(self):
if self.is_strongly_connected:
return [np.arange(self.n)]
else:
return [np.where(self.scc_proj == k)[0]
for k in self.sink_scc_labels.tolist()]

@property
@annotate_nodes
def sink_strongly_connected_components(self):
return self.sink_strongly_connected_components_indices

def _compute_period(self):
"""
Set ``self._period`` and ``self._cyclic_components_proj``.
Expand Down Expand Up @@ -256,13 +320,18 @@ def is_aperiodic(self):
return (self.period == 1)

@property
def cyclic_components(self):
def cyclic_components_indices(self):
if self.is_aperiodic:
return [np.arange(self.n)]
else:
return [np.where(self._cyclic_components_proj == k)[0]
for k in range(self.period)]

@property
@annotate_nodes
def cyclic_components(self,):
return self.cyclic_components_indices

def subgraph(self, nodes):
"""
Return the subgraph consisting of the given nodes and edges
Expand All @@ -271,7 +340,7 @@ def subgraph(self, nodes):
Parameters
----------
nodes : array_like(int, ndim=1)
Array of nodes.
Array of node indices.
Returns
-------
Expand All @@ -282,7 +351,13 @@ def subgraph(self, nodes):
adj_matrix = self.csgraph[nodes, :][:, nodes]

weighted = True # To copy the dtype
return DiGraph(adj_matrix, weighted=weighted)

if self.node_labels is not None:
node_labels = self.node_labels[nodes]
else:
node_labels = None

return DiGraph(adj_matrix, weighted=weighted, node_labels=node_labels)


def _csr_matrix_indices(S):
Expand Down
Loading

0 comments on commit 4eade8d

Please sign in to comment.