diff --git a/l2gv2/clustering.py b/l2gv2/clustering.py
new file mode 100644
index 0000000..29d3d9e
--- /dev/null
+++ b/l2gv2/clustering.py
@@ -0,0 +1,416 @@
+"""Graph clustering algorithms"""
+
+from math import log
+from collections.abc import Iterable
+from typing import Sequence
+
+import community
+import torch
+import pymetis
+import numpy as np
+import numba
+
+from l2gv2.network import TGraph, NPGraph
+from l2gv2 import progress
+
+
+def distributed_clustering(graph: TGraph, beta, rounds=None, patience=3, min_samples=2):
+    r"""
+    Distributed clustering algorithm
+
+    Implements the algorithm of [#dist]_ with GPU support.
+
+    Args:
+        graph: input graph
+        beta: :math:`\beta` value of the algorithm (controls the number of seeds)
+        rounds: number of iteration rounds (default: ``3*int(log(graph.num_nodes))``)
+        patience: number of rounds without label changes before early stopping (default: ``3``)
+        min_samples: minimum number of seed nodes (default: ``2``)
+
+    .. rubric:: Reference
+
+    .. [#dist] H. Sun and L. Zanetti. “Distributed Graph Clustering and Sparsification”.
+        ACM Transactions on Parallel Computing 6.3 (2019), pp. 1–23.
+        doi: `10.1145/3364208 <https://doi.org/10.1145/3364208>`_.
+
+    """
+    if rounds is None:
+        rounds = 3 * int(log(graph.num_nodes))
+    strength = graph.strength
+
+    # sample seed nodes
+    index = (
+        torch.rand((graph.num_nodes,))
+        < 1 / beta * log(1 / beta) * graph.strength / graph.strength.sum()
+    )
+    while index.sum() < min_samples:
+        index = (
+            torch.rand((graph.num_nodes,))
+            < 1 / beta * log(1 / beta) * graph.strength / graph.strength.sum()
+        )
+    seeds = torch.nonzero(index).flatten()
+    n_samples = seeds.numel()
+
+    states = torch.zeros(
+        (graph.num_nodes, n_samples), dtype=torch.double, device=graph.device
+    )
+    states[index, torch.arange(n_samples, device=graph.device)] = 1 / torch.sqrt(
+        strength[index]
+    ).to(dtype=torch.double)
+    clusters = torch.argmax(states, dim=1)
+    weights = graph.weights / torch.sqrt(
+        strength[graph.edge_index[0]] * strength[graph.edge_index[1]]
+    )
+    weights = weights.to(dtype=torch.double)
+    r = 0
+    num_same = 0
+    while (
+        r < rounds and num_same < patience
+    ):  # keep iterating until clustering does not change for 'patience' rounds
+        r += 1
+        states *= 0.5
+        states.index_add_(
+            0,
+            graph.edge_index[0],
+            0.5 * states[graph.edge_index[1]] * weights.view(-1, 1),
+        )
+        # states = ts.scatter(out=states, dim=0, index=graph.edge_index[0],
+        #                     src=0.5*states[graph.edge_index[1]]*weights.view(-1, 1))
+        old_clusters = clusters
+        clusters = torch.argmax(states, dim=1)
+        if torch.equal(old_clusters, clusters):
+            num_same += 1
+        else:
+            num_same = 0
+    clusters[states[range(graph.num_nodes), clusters] == 0] = -1
+    uc, clusters = torch.unique(clusters, return_inverse=True)
+    if uc[0] == -1:
+        clusters -= 1
+    return clusters
+
+
+def fennel_clustering(
+    graph,
+    num_clusters,
+    load_limit=1.1,
+    alpha=None,
+    gamma=1.5,
+    num_iters=1,
+    clusters=None,
+):
+    """
+    FENNEL single-pass graph clustering
+
+    Wrapper that converts the input graph to :class:`NPGraph` and dispatches to the
+    numba-compiled implementation :func:`_fennel_clustering`, which documents the
+    algorithm and its parameters.
+
+    Returns:
+        cluster index tensor
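+
+    A minimal usage sketch (``g`` is a hypothetical, already-constructed
+    :class:`TGraph` or :class:`NPGraph` instance)::
+
+        clusters = fennel_clustering(g, num_clusters=4)
+        # clusters[i] is the cluster id assigned to node i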
""" + graph = graph.to(NPGraph) + + if clusters is None: + clusters = _fennel_clustering( + graph.edge_index, + graph.adj_index, + graph.num_nodes, + num_clusters, + load_limit, + alpha, + gamma, + num_iters, + ) + else: + clusters = _fennel_clustering( + graph.edge_index, + graph.adj_index, + graph.num_nodes, + num_clusters, + load_limit, + alpha, + gamma, + num_iters, + clusters, + ) + return torch.as_tensor(clusters) + +# pylint: disable=too-many-branches +@numba.njit +def _fennel_clustering( + edge_index, + adj_index, + num_nodes, + num_clusters, + load_limit=1.1, + alpha=None, + gamma=1.5, + num_iters=1, + clusters=np.empty(0, dtype=np.int64), +): + r""" + FENNEL single-pass graph clustering algorithm + + Implements the graph clustering algorithm of [#fennel]_. + + Args: + graph: input graph + + num_clusters: target number of clusters + + load_limit: maximum cluster size is + ``load_limit * graph.num_nodes / num_clusters`` (default: ``1.1``) + + alpha: :math:`\alpha` value for the algorithm (default as suggested in [#fennel]_) + + gamma: :math:`\gamma` value for the algorithm (default: 1.5) + + randomise_order: if ``True``, randomise order, else use breadth-first-search order. + + clusters: input clustering to refine (optional) + + num_iters: number of cluster assignment iterations (default: ``1``) + + Returns: + cluster index tensor + + References: + .. C. Tsourakakis et al. “FENNEL: Streaming Graph Partitioning for Massive Scale Graphs”. + In: Proceedings of the 7th ACM international conference on Web search and data mining. + WSDM'14 (2014) doi: `10.1145/2556195.2556213 `_. + + """ + if num_iters is None: + num_iters = 1 + + num_edges = edge_index.shape[1] + + if alpha is None: + alpha = num_edges * (num_clusters ** (gamma - 1)) / (num_nodes**gamma) + + partition_sizes = np.zeros(num_clusters, dtype=np.int64) + if clusters.size == 0: + clusters = np.full((num_nodes,), -1, dtype=np.int64) + else: + clusters = np.copy(clusters) + for index in clusters: + partition_sizes[index] += 1 + + load_limit *= num_nodes / num_clusters + + deltas = -alpha * gamma * (partition_sizes ** (gamma - 1)) + + with numba.objmode: + pbar = progress.reset(num_nodes) + + for it in range(num_iters): + not_converged = 0 + + progress_it = 0 + for i in range(num_nodes): + cluster_indices = np.empty( + (adj_index[i + 1] - adj_index[i],), dtype=np.int64 + ) + for ni, index in enumerate(range(adj_index[i], adj_index[i + 1])): + cluster_indices[ni] = clusters[edge_index[1, index]] + old_cluster = clusters[i] + if old_cluster >= 0: + partition_sizes[old_cluster] -= 1 + cluster_indices = cluster_indices[cluster_indices >= 0] + + if cluster_indices.size > 0: + c_size = np.zeros(num_clusters, dtype=np.int64) + for index in cluster_indices: + c_size[index] += 1 + ind = np.argmax(deltas + c_size) + else: + ind = np.argmax(deltas) + clusters[i] = ind + partition_sizes[ind] += 1 + if partition_sizes[ind] == load_limit: + deltas[ind] = -np.inf + else: + deltas[ind] = -alpha * gamma * (partition_sizes[ind] ** (gamma - 1)) + not_converged += ind != old_cluster + + if i % 10000 == 0 and i > 0: + progress_it = i + with numba.objmode: + progress.update(pbar, 10000) + with numba.objmode: + progress.update(pbar, num_nodes - progress_it) + + print("iteration: " + str(it) + ", not converged: " + str(not_converged)) + + if not_converged == 0: + print(f"converged after {it} iterations.") + break + with numba.objmode: + progress.close(pbar) + + return clusters +# pylint: enable=too-many-branches + +def louvain_clustering(graph: TGraph, 
*args, **kwargs):
+    r"""
+    Implements clustering using the Louvain [#l]_ algorithm for modularity optimisation
+
+    Args:
+        graph: input graph
+
+    Returns:
+        partition tensor
+
+    This is a minimal wrapper around :py:func:`community.best_partition` from the
+    `python-louvain <https://github.com/taynaud/python-louvain>`_ package. Any other
+    arguments provided are passed through.
+
+    References:
+        .. [#l] V. D. Blondel et al.
+           “Fast unfolding of communities in large networks”.
+           Journal of Statistical Mechanics: Theory and Experiment 2008.10 (2008), P10008.
+           doi: `10.1088/1742-5468/2008/10/P10008 <https://doi.org/10.1088/1742-5468/2008/10/P10008>`_.
+
+    """
+    # pylint: disable=no-member
+    clusters = community.best_partition(
+        graph.to_networkx().to_undirected(), *args, **kwargs
+    )
+    # pylint: enable=no-member
+    return torch.tensor([clusters[i] for i in range(graph.num_nodes)], dtype=torch.long)
+
+
+def metis_clustering(graph: TGraph, num_clusters):
+    """
+    Implements clustering using METIS
+
+    Args:
+        graph: input graph
+        num_clusters: number of clusters
+
+    Returns:
+        partition tensor
+
+    This uses the `pymetis <https://github.com/inducer/pymetis>`_ package.
+
+    References:
+        .. [#metis] G. Karypis and V. Kumar.
+           “A Fast and High Quality Multilevel Scheme for Partitioning Irregular Graphs”.
+           SIAM Journal on Scientific Computing 20.1 (1999), pp. 359–392.
+    """
+    graph = graph.to(NPGraph)
+    _, memberships = pymetis.part_graph(
+        num_clusters,
+        adjncy=graph.edge_index[1],
+        xadj=graph.adj_index,
+        eweights=graph.edge_attr,
+    )
+    return torch.as_tensor(memberships, dtype=torch.long, device=graph.device)
+
+
+def spread_clustering(graph, num_clusters, max_degree_init=True):
+    """
+    Cluster a graph by spreading cluster labels outwards from seed nodes
+
+    Seeds are the ``num_clusters`` highest-degree nodes if ``max_degree_init`` is
+    ``True``, otherwise they are sampled with probability proportional to degree.
+    Labels then spread to the most strongly connected unassigned nodes until all
+    nodes are assigned; if the graph is disconnected, additional clusters are
+    created as needed.
+
+    Args:
+        graph: input graph
+        num_clusters: target number of clusters
+        max_degree_init: if ``True``, use maximum-degree nodes as seeds
+
+    Returns:
+        partition tensor
+    """
+
+    clusters = torch.full((graph.num_nodes,), -1, dtype=torch.long, device=graph.device)
+    if max_degree_init:
+        seeds = torch.topk(torch.as_tensor(graph.degree), k=num_clusters).indices
+    else:
+        seeds = torch.multinomial(
+            torch.as_tensor(graph.degree), num_clusters, replacement=False
+        )
+
+    clusters[seeds] = torch.arange(num_clusters)
+    spread_weights = torch.zeros(
+        (num_clusters, graph.num_nodes), dtype=torch.double, device=graph.device
+    )
+    spread_weights[:, seeds] = -1
+    unassigned = clusters < 0
+    for seed in seeds:
+        c = clusters[seed]
+        inds, weights = graph.adj_weighted(seed)
+        keep = unassigned[inds]
+        spread_weights[c, inds[keep]] += weights[keep] / graph.strength[inds[keep]]
+
+    num_unassigned = graph.num_nodes - num_clusters
+
+    while num_unassigned > 0:
+        current_progress = False
+        for c in range(num_clusters):
+            node = torch.argmax(spread_weights[c])
+            if spread_weights[c, node] > 0:
+                # node has positive spread weight, so it is connected to cluster c
+                current_progress = True
+                clusters[node] = c
+                spread_weights[:, node] = -1  # should not be chosen again
+                unassigned[node] = False
+                num_unassigned -= 1
+                inds, weights = graph.adj_weighted(node)
+                keep = unassigned[inds]
+                spread_weights[c, inds[keep]] += (
+                    weights[keep] / graph.strength[inds[keep]]
+                )
+        if not current_progress:
+            print("increasing number of clusters due to disconnected components")
+            unassigned_nodes = torch.nonzero(unassigned).ravel()
+            if max_degree_init:
+                seed = unassigned_nodes[
+                    torch.argmax(torch.as_tensor(graph.degree[unassigned_nodes]))
+                ]
+            else:
+                seed = unassigned_nodes[
+                    torch.multinomial(
+                        torch.as_tensor(graph.degree[unassigned_nodes]), 1
+                    )
+                ]
+            clusters[seed] = num_clusters
+            spread_weights = torch.cat(
+                (
+                    spread_weights,
+                    torch.zeros(
+                        (1, graph.num_nodes), dtype=torch.double, device=graph.device
+                    ),
+                )
+            )
+            unassigned[seed] = False
+            spread_weights[:, seed] = -1
+            inds, weights = 
graph.adj_weighted(seed) + keep = unassigned[inds] + spread_weights[num_clusters, inds[keep]] += ( + weights[keep] / graph.strength[inds[keep]] + ) + num_clusters += 1 + num_unassigned -= 1 + return clusters + + +def hierarchical_aglomerative_clustering( + graph, method=spread_clustering, levels=None, branch_factors=None +): + """ TODO: docstring for hierarchical_aglomerative_clustering. """ + + if branch_factors is None: + branch_factors = [graph.num_nodes ** (1 / (levels + 1)) for _ in range(levels)] + else: + if not isinstance(branch_factors, Iterable): + branch_factors = [branch_factors] * (levels) + else: + if levels is None: + levels = len(branch_factors) + elif len(branch_factors) != levels: + raise ValueError(f"{levels=} does not match {len(branch_factors)=}") + num_clusters = np.cumprod(branch_factors)[::-1] + clusters = [] + rgraph = graph + for c in num_clusters: + cluster = method(rgraph, int(c)) + rgraph = rgraph.partition_graph(cluster) + clusters.append(cluster) + return clusters + + +class Partition(Sequence): + """ TODO: docstring for Partition. """ + def __init__(self, partition_tensor): + partition_tensor = torch.as_tensor(partition_tensor) + counts = torch.bincount(partition_tensor) + self.num_parts = len(counts) + self.nodes = torch.argsort(partition_tensor) + self.part_index = torch.zeros(self.num_parts + 1, dtype=torch.long) + self.part_index[1:] = torch.cumsum(counts, dim=0) + + def __getitem__(self, item): + return self.nodes[self.part_index[item] : self.part_index[item + 1]] + + def __len__(self): + return self.num_parts diff --git a/l2gv2/network/__init__.py b/l2gv2/network/__init__.py new file mode 100644 index 0000000..9e3bf95 --- /dev/null +++ b/l2gv2/network/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) 2021. Lucas G. S. Jeub +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" TODO: module docstring for network/__init__.py""" +from .npgraph import NPGraph +from .tgraph import TGraph +from .utils import * # TODO: this should be removed diff --git a/l2gv2/network/graph.py b/l2gv2/network/graph.py new file mode 100644 index 0000000..3de91a4 --- /dev/null +++ b/l2gv2/network/graph.py @@ -0,0 +1,299 @@ +# Copyright (c) 2021. Lucas G. S. 
Jeub
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+"""Abstract base class defining the common graph interface."""
+
+from typing import Sequence, Iterable
+from abc import abstractmethod
+import networkx as nx
+import numpy as np
+
+
+# pylint: disable=too-many-instance-attributes
+# pylint: disable=too-many-public-methods
+class Graph:
+    """
+    Base class defining the graph interface shared by :class:`TGraph` and :class:`NPGraph`
+    """
+
+    weights: Sequence
+    degree: Sequence
+    device = "cpu"
+
+    @staticmethod
+    def _convert_input(inp):
+        return inp
+
+    @classmethod
+    def from_tg(cls, data):
+        """Construct a graph from a torch-geometric data object."""
+        return cls(
+            edge_index=data.edge_index,
+            edge_attr=data.edge_attr,
+            x=data.x,
+            y=data.y,
+            num_nodes=data.num_nodes,
+        )
+
+    @classmethod
+    def from_networkx(cls, nx_graph: nx.Graph, weight=None):
+        """Construct a graph from a networkx graph.
+
+        Args:
+            nx_graph: input networkx graph
+            weight: name of the edge attribute to use as edge weight (optional)
+        """
+        undir = not nx_graph.is_directed()
+        if undir:
+            nx_graph = nx_graph.to_directed(as_view=True)
+        num_nodes = nx_graph.number_of_nodes()
+        num_edges = nx_graph.number_of_edges()
+        edge_index = np.empty((2, num_edges), dtype=np.int64)
+        weights = []
+        for i, (*e, w) in enumerate(nx_graph.edges(data=weight)):
+            edge_index[:, i] = e
+            if w is not None:
+                weights.append(w)
+        if weights and len(weights) != num_edges:
+            raise RuntimeError("some edges have missing weight")
+
+        if weight is not None:
+            weights = np.array(weights)
+        else:
+            weights = None
+
+        return cls(
+            edge_index, weights, num_nodes=num_nodes, ensure_sorted=True, undir=undir
+        )
+
+    @abstractmethod
+    def __init__(
+        self,
+        edge_index,
+        edge_attr=None,
+        x=None,
+        y=None,
+        num_nodes=None,
+        adj_index=None,
+        ensure_sorted=False,
+        undir=None,
+        nodes=None,
+    ):
+        """
+        Initialise graph
+
+        Args:
+            edge_index: edge index such that ``edge_index[0]`` lists the source
+                and ``edge_index[1]`` the target node for each edge
+
+            edge_attr: optionally provide edge weights
+
+            x: optional node feature matrix
+
+            y: optional node labels
+
+            num_nodes: specify number of nodes (default: ``max(edge_index)+1``)
+
+            adj_index: optional precomputed adjacency index
+
+            ensure_sorted: if ``False``, assume that the ``edge_index`` input is already sorted
+
+            undir: boolean indicating if the graph is undirected.
+                If not provided, the ``edge_index`` is checked to determine this value.
+
+            nodes: optional node labels
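+
+        A minimal construction sketch (via the concrete :class:`NPGraph` subclass;
+        the edge index below encodes a single undirected edge)::
+
+            import numpy as np
+            from l2gv2.network import NPGraph
+
+            g = NPGraph(np.array([[0, 1], [1, 0]]), num_nodes=2)
+            assert g.undir and g.num_edges == 2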
+ """ + self.edge_index = self._convert_input(edge_index) + self.edge_attr = self._convert_input(edge_attr) + self._nodes = self._convert_input(nodes) + self.x = self._convert_input(x) + self.y = self._convert_input(y) + self.num_nodes = num_nodes + if self.num_nodes is not None: + self.num_nodes = int(num_nodes) + self.undir = undir + self.adj_index = self._convert_input(adj_index) + + @property + def weighted(self): + """boolean indicating if graph is weighted""" + return self.edge_attr is not None + + @property + def num_edges(self): + """ TODO: docstring for num_edges.""" + return self.edge_index.shape[1] + + @property + def num_features(self): + """ TODO: docstring for num_features.""" + return 0 if self.x is None else self.x.shape[1] + + @property + def nodes(self): + """ TODO: docstring for nodes.""" + if self._nodes is None: + return range(self.num_nodes) + return self._nodes + + def has_node_labels(self): + """ TODO: docstring for has_node_labels.""" + return self._nodes is not None + + def adj(self, node: int): + """ + list neighbours of node + + Args: + node: source node + + Returns: + neighbours + + """ + return self.edge_index[1][self.adj_index[node] : self.adj_index[node + 1]] + + def adj_weighted(self, node: int): + """ + list neighbours of node and corresponding edge weight + Args: + node: source node + + Returns: + neighbours, weights + + """ + return self.adj(node), self.weights[ + self.adj_index[node] : self.adj_index[node + 1] + ] + + @abstractmethod + def edges(self): + """ + iterator over edges + """ + raise NotImplementedError + + @abstractmethod + def edges_weighted(self): + """ + iterator over weighted edges where each edge is a tuple ``(source, target, weight)`` + """ + raise NotImplementedError + + @abstractmethod + def is_edge(self, source, target): + """ TODO: docstring for is_edge.""" + raise NotImplementedError + + @abstractmethod + def neighbourhood(self, nodes, hops: int = 1): + """ + find the neighbourhood of a set of source nodes + + note that the neighbourhood includes the source nodes themselves + + Args: + nodes: indices of source nodes + hops: number of hops for neighbourhood + + Returns: + neighbourhood + + """ + raise NotImplementedError + + @abstractmethod + def subgraph(self, nodes: Iterable, relabel=False, keep_x=True, keep_y=True): + """ + find induced subgraph for a set of nodes + + Args: + nodes: node indeces + + Returns: + subgraph + + """ + raise NotImplementedError + + @abstractmethod + def connected_component_ids(self): + """ + return connected component ids where ids are sorted in decreasing order by component size + + Returns: + Sequence of node indeces + + """ + raise NotImplementedError + + def nodes_in_lcc(self): + """Iterator over nodes in the largest connected component""" + return (i for i, c in enumerate(self.connected_component_ids()) if c == 0) + + def lcc(self, relabel=False): + """ TODO: docstring for lcc.""" + return self.subgraph(self.nodes_in_lcc(), relabel) + + def to_networkx(self): + """convert graph to NetworkX format""" + if self.undir: + nxgraph = nx.Graph() + else: + nxgraph = nx.DiGraph() + nxgraph.add_nodes_from(range(self.num_nodes)) + if self.weighted: + nxgraph.add_weighted_edges_from(self.edges_weighted()) + else: + nxgraph.add_edges_from(self.edges()) + return nxgraph + + def to(self, graph_cls): + """ TODO: docstring for to.""" + if self.__class__ is graph_cls: + return self + + return graph_cls( + edge_index=self.edge_index, + edge_attr=self.edge_attr, + x=self.x, + y=self.y, + num_nodes=self.num_nodes, + 
adj_index=self.adj_index,
+            ensure_sorted=False,
+            undir=self.undir,
+            nodes=self._nodes,
+        )
+
+    @abstractmethod
+    def bfs_order(self, start=0):
+        """
+        return nodes in breadth-first-search order
+
+        Args:
+            start: index of starting node (default: 0)
+
+        Returns:
+            Sequence of node indices
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def partition_graph(self, partition, self_loops=True):
+        """
+        construct the quotient graph for a partition of the nodes, where each part
+        becomes a node and edge weights count the edges between parts
+
+        Args:
+            partition: partition tensor assigning a part id to each node
+            self_loops: if ``True``, keep within-part edges as self loops
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def sample_negative_edges(self, num_samples):
+        """sample ``num_samples`` node pairs that are not edges of the graph"""
+        raise NotImplementedError
+
+    @abstractmethod
+    def sample_positive_edges(self, num_samples):
+        """sample ``num_samples`` edges of the graph"""
+        raise NotImplementedError
+# pylint: enable=too-many-public-methods
+# pylint: enable=too-many-instance-attributes
diff --git a/l2gv2/network/npgraph.py b/l2gv2/network/npgraph.py
new file mode 100644
index 0000000..1739df4
--- /dev/null
+++ b/l2gv2/network/npgraph.py
@@ -0,0 +1,540 @@
+# Copyright (c) 2021. Lucas G. S. Jeub
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
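+#
+# A minimal usage sketch for this module (hypothetical toy data; a sketch only,
+# assuming numpy and numba are installed):
+#
+#     import numpy as np
+#     from l2gv2.network import NPGraph
+#
+#     g = NPGraph(np.array([[0, 1, 1, 2], [1, 0, 2, 1]]), num_nodes=3)
+#     g.adj(1)          # neighbours of node 1 -> array([0, 2])
+#     g.bfs_order(0)    # breadth-first order starting from node 0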
+
+"""Numpy-backed graph implementation with support for memmapped edge_index."""
+
+import json
+from pathlib import Path
+from random import randrange
+
+import numpy as np
+import torch
+import numba
+from numba.experimental import jitclass
+
+from l2gv2 import progress
+from .graph import Graph
+
+
+rng = np.random.default_rng()
+
+
+spec = [
+    ("edge_index", numba.int64[:, :]),
+    ("adj_index", numba.int64[:]),
+    ("degree", numba.int64[:]),
+]
+
+
+# pylint: disable=too-many-instance-attributes
+class NPGraph(Graph):
+    """
+    numpy backed graph class with support for memmapped edge_index
+    """
+
+    @staticmethod
+    def _convert_input(inp):
+        if inp is None:
+            return inp
+
+        if isinstance(inp, torch.Tensor):
+            return np.asanyarray(inp.cpu())
+
+        return np.asanyarray(inp)
+
+    @classmethod
+    def load(cls, folder, mmap_edges=None, mmap_features=None):
+        """
+        Load a graph from a folder of numpy files
+
+        Args:
+            folder: path of the folder
+            mmap_edges: memory-map mode for edge data (passed to :func:`numpy.load`)
+            mmap_features: memory-map mode for node features (passed to :func:`numpy.load`)
+        """
+        folder = Path(folder)
+        kwargs = {}
+
+        kwargs["edge_index"] = np.load(folder / "edge_index.npy", mmap_mode=mmap_edges)
+
+        attr_file = folder / "edge_attr.npy"
+        if attr_file.is_file():
+            kwargs["edge_attr"] = np.load(attr_file, mmap_mode=mmap_edges)
+
+        info_file = folder / "info.json"
+        if info_file.is_file():
+            with open(info_file, encoding="utf-8") as f:
+                info = json.load(f)
+            kwargs.update(info)
+
+        feat_file = folder / "node_feat.npy"
+        if feat_file.is_file():
+            kwargs["x"] = np.load(feat_file, mmap_mode=mmap_features)
+
+        label_file = folder / "node_label.npy"
+        if label_file.is_file():
+            kwargs["y"] = np.load(label_file)
+
+        index_file = folder / "adj_index.npy"
+        if index_file.is_file():
+            kwargs["adj_index"] = np.load(index_file)
+
+        return cls(**kwargs)
+
+    def save(self, folder):
+        """Save the graph as a folder of numpy files."""
+        folder = Path(folder)
+        np.save(folder / "edge_index.npy", self.edge_index)
+
+        if self.weighted:
+            np.save(folder / "edge_attr.npy", self.edge_attr)
+
+        np.save(folder / "adj_index.npy", self.adj_index)
+
+        info = {"num_nodes": self.num_nodes, "undir": self.undir}
+        with open(folder / "info.json", "w", encoding="utf-8") as f:
+            json.dump(info, f)
+
+        # np.save expects the file first and the array second
+        if self.y is not None:
+            np.save(folder / "node_label.npy", self.y)
+
+        if self.x is not None:
+            np.save(folder / "node_feat.npy", self.x)
+
+    def __init__(self, *args, ensure_sorted=False, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if self.num_nodes is None:
+            self.num_nodes = np.max(self.edge_index) + 1
+
+        if ensure_sorted:
+            if isinstance(self.edge_index, np.memmap):
+                raise NotImplementedError(
+                    "Sorting for memmapped arrays not yet implemented"
+                )
+            index = np.argsort(
+                self.edge_index[0] * self.num_nodes + self.edge_index[1]
+            )
+            self.edge_index = self.edge_index[:, index]
+            if self.edge_attr is not None:
+                self.edge_attr = self.edge_attr[index]
+        self._jitgraph = JitGraph(self.edge_index, self.num_nodes, self.adj_index, None)
+        self.adj_index = self._jitgraph.adj_index
+        self.degree = self._jitgraph.degree
+        self.num_nodes = self._jitgraph.num_nodes
+
+        if self.weighted:
+            self.weights = self.edge_attr
+            self.strength = np.zeros(self.num_nodes)  #: array of node strength
+            np.add.at(self.strength, self.edge_index[0], self.weights)
+        else:
+            self.weights = np.broadcast_to(
+                np.ones(1), (self.num_edges,)
+            )  # use broadcast_to to avoid actually allocating a large array
+            self.strength = self.degree
+        self.device = "cpu"
+
+        if self.undir is None:
+            if isinstance(self.edge_index, np.memmap):
+                raise NotImplementedError(
+                    "Checking directedness for memmapped arrays not yet implemented"
+                )
+
+            index = np.argsort(
+                self.edge_index[1] * self.num_nodes + self.edge_index[0]
+            )
+            edge_reverse = self.edge_index[::-1, index]
+            self.undir = np.array_equal(self.edge_index, edge_reverse)
+            if self.weighted:
+                self.undir = self.undir and np.array_equal(
+                    self.weights, self.weights[index]
+                )
+
+    def edges(self):
+        """
+        iterator over edges where each edge is a tuple ``(source, target)``
+        """
+        return ((e[0], e[1]) for e in self.edge_index.T)
+
+    def edges_weighted(self):
+        """
+        iterator over edges where each edge is a tuple ``(source, target, weight)``
+        """
+        return (
+            (e[0], e[1], w[0] if w.size > 1 else w)
+            for e, w in zip(self.edge_index.T, self.weights)
+        )
+
+    def is_edge(self, source, target):
+        return self._jitgraph.is_edge(source, target)
+
+    def neighbourhood(self, nodes, hops: int = 1):
+        """
+        find the neighbourhood of a set of source nodes
+
+        note that the neighbourhood includes the source nodes themselves
+
+        Args:
+            nodes: indices of source nodes
+            hops: number of hops for neighbourhood
+
+        Returns:
+            neighbourhood
+
+        """
+        explore = np.ones(self.num_nodes, dtype=bool)  # np.bool is removed in recent numpy
+        explore[nodes] = False
+        all_nodes = nodes
+        new_nodes = nodes
+        for _ in range(hops):
+            new_nodes = np.concatenate([self.adj(node) for node in new_nodes])
+            new_nodes = np.unique(new_nodes[explore[new_nodes]])
+            explore[new_nodes] = False
+            all_nodes = np.concatenate((all_nodes, new_nodes))
+        return all_nodes
+
+    def subgraph(self, nodes: torch.Tensor, relabel=False, keep_x=True, keep_y=True):
+        """
+        find induced subgraph for a set of nodes
+
+        Args:
+            nodes: node indices
+
+        Returns:
+            subgraph
+
+        """
+        nodes = np.asanyarray(nodes)
+        edge_index, index = self._jitgraph.subgraph_edges(nodes)
+        edge_attr = self.edge_attr
+        if relabel:
+            node_labels = None
+        else:
+            node_labels = [self.nodes[n] for n in nodes]
+        if self.x is not None and keep_x:
+            x = self.x[nodes]
+        else:
+            x = None
+        if self.y is not None and keep_y:
+            y = self.y[nodes]
+        else:
+            y = None
+        return self.__class__(
+            edge_index=edge_index,
+            edge_attr=edge_attr[index] if edge_attr is not None else None,
+            num_nodes=len(nodes),
+            ensure_sorted=False,
+            undir=self.undir,
+            nodes=node_labels,
+            x=x,
+            y=y,
+        )
+
+    def connected_component_ids(self):
+        """
+        return the connected component id for each node, where ids are sorted in
+        decreasing order by component size such that ``id=0`` is the largest component
+
+        Returns:
+            array of component ids
+
+        """
+        return self._jitgraph.connected_component_ids()
+
+    def nodes_in_lcc(self):
+        """List all nodes in the largest connected component"""
+        return np.flatnonzero(self.connected_component_ids() == 0)
+
+    def bfs_order(self, start=0):
+        """
+        return nodes in breadth-first-search order
+
+        Args:
+            start: index of starting node (default: 0)
+
+        Returns:
+            array of node indices
+
+        """
+        bfs_list = np.full((self.num_nodes,), -1, dtype=np.int64)
+        not_visited = np.ones(self.num_nodes, dtype=bool)  # must be bool for mask indexing
+        bfs_list[0] = start
+        not_visited[start] = False
+        append_pointer = 1
+        i = 0
+        restart = 0
+        while append_pointer < self.num_nodes:
+            node = bfs_list[i]
+            if node < 0:
+                for node in range(restart, self.num_nodes):
+                    if not_visited[node]:
+                        break
+                restart = node
+                bfs_list[i] = node
+                not_visited[node] = False
+                append_pointer += 1
+            i += 1
+            new_nodes = self.adj(node)
+            new_nodes = new_nodes[not_visited[new_nodes]]
+            number_new_nodes = len(new_nodes)
+            not_visited[new_nodes] = False
+            bfs_list[append_pointer : append_pointer + number_new_nodes] = new_nodes
+            append_pointer += number_new_nodes
+        return bfs_list
+
+    def partition_graph(self, partition, self_loops=True):
+        partition = 
np.asanyarray(partition) + partition_edges, weights = self._jitgraph.partition_graph_edges( + partition, self_loops + ) + return self.__class__( + edge_index=partition_edges, edge_attr=weights, undir=self.undir + ) + + def sample_negative_edges(self, num_samples): + return self._jitgraph.sample_negative_edges(num_samples) + + def sample_positive_edges(self, num_samples): + index = rng.integers(self.num_edges, size=(num_samples,)) + return self.edge_index[:, index] + + +# pylint: enable=too-many-instance-attributes + + +@numba.njit +def _subgraph_edges(edge_index, adj_index, degs, num_nodes, sources): + max_edges = degs[sources].sum() + subgraph_edge_index = np.empty((2, max_edges), dtype=np.int64) + index = np.empty((max_edges,), dtype=np.int64) + target_index = np.full((num_nodes,), -1, np.int64) + target_index[sources] = np.arange(len(sources)) + count = 0 + + for s, source in enumerate(sources): + for i in range(adj_index[source], adj_index[source + 1]): + t = target_index[edge_index[1, i]] + if t >= 0: + subgraph_edge_index[0, count] = s + subgraph_edge_index[1, count] = t + index[count] = i + count += 1 + return subgraph_edge_index[:, :count], index[:count] + + +@numba.njit +def _memmap_degree(edge_index, num_nodes): + degree = np.zeros(num_nodes, dtype=np.int64) + with numba.objmode: + print("computing degrees") + pbar = progress.reset(edge_index.shape[1]) + for it, source in enumerate(edge_index[0]): + degree[source] += 1 + if it % 1000000 == 0 and it > 0: + with numba.objmode: + progress.update(pbar, 1000000) + with numba.objmode: + progress.close(pbar) + return degree + + +@jitclass( + [ + ("edge_index", numba.int64[:, :]), + ("adj_index", numba.int64[:]), + ("degree", numba.int64[:]), + ("num_nodes", numba.int64), + ] +) +class JitGraph: + """TODO: docstring for JitGraph.""" + + def __init__(self, edge_index, num_nodes=None, adj_index=None, degree=None): + if num_nodes is None: + num_nodes_int = edge_index.max() + 1 + else: + num_nodes_int = num_nodes + + if adj_index is None: + adj_index_ar = np.zeros((num_nodes_int + 1,), dtype=np.int64) + else: + adj_index_ar = adj_index + + if degree is None: + if adj_index is None: + degree = np.zeros((num_nodes_int,), dtype=np.int64) + for s in edge_index[0]: + degree[s] += 1 + adj_index_ar[1:] = degree.cumsum() + else: + degree = adj_index_ar[1:] - adj_index_ar[:-1] + + self.edge_index = edge_index + self.adj_index = adj_index_ar + self.degree = degree + self.num_nodes = num_nodes_int + + def is_edge(self, source, target): + """TODO: docstring for is_edge.""" + if source not in range(self.num_nodes) or target not in range(self.num_nodes): + return False + index = np.searchsorted( + self.edge_index[1, self.adj_index[source] : self.adj_index[source + 1]], + target, + ) + if ( + index < self.degree[source] + and self.edge_index[1, self.adj_index[source] + index] == target + ): + return True + + return False + + def sample_negative_edges(self, num_samples): + """TODO: docstring for sample_negative_edges.""" + i = 0 + sampled_edges = np.empty((2, num_samples), dtype=np.int64) + while i < num_samples: + source = randrange(self.num_nodes) + target = randrange(self.num_nodes) + if not self.is_edge(source, target): + sampled_edges[0, i] = source + sampled_edges[1, i] = target + i += 1 + return sampled_edges + + def adj(self, node): + """TODO: docstring for adj.""" + return self.edge_index[1][self.adj_index[node] : self.adj_index[node + 1]] + + def neighbours(self, nodes): + """TODO: docstring for neighbours.""" + size = self.degree[nodes].sum() + 
out = np.empty((size,), dtype=np.int64)
+        it = 0
+        for node in nodes:
+            out[it : it + self.degree[node]] = self.adj(node)
+            it += self.degree[node]
+        return np.unique(out)
+
+    def sample_positive_edges(self, num_samples):
+        """Sample ``num_samples`` edges uniformly with replacement."""
+        # np.random.randint takes (low, high, size); passing the shape as
+        # the second argument would use it as the upper bound
+        index = np.random.randint(0, self.num_edges, (num_samples,))
+        return self.edge_index[:, index]
+
+    def subgraph_edges(self, sources):
+        """Find the edges of the subgraph induced by ``sources``.
+
+        Returns:
+            subgraph edge index, index of the selected edges in the original edge list
+        """
+        max_edges = self.degree[sources].sum()
+        subgraph_edge_index = np.empty((2, max_edges), dtype=np.int64)
+        index = np.empty((max_edges,), dtype=np.int64)
+        target_index = np.full((self.num_nodes,), -1, np.int64)
+        target_index[sources] = np.arange(len(sources))
+        count = 0
+
+        for s, source in enumerate(sources):
+            for ei in range(self.adj_index[source], self.adj_index[source + 1]):
+                t = target_index[self.edge_index[1][ei]]
+                if t >= 0:
+                    subgraph_edge_index[0, count] = s
+                    subgraph_edge_index[1, count] = t
+                    index[count] = ei
+                    count += 1
+        return subgraph_edge_index[:, :count], index[:count]
+
+    def subgraph(self, sources):
+        """Return the subgraph induced by ``sources`` as a new :class:`JitGraph`."""
+        edge_index, _ = self.subgraph_edges(sources)
+        return JitGraph(edge_index, len(sources), None, None)
+
+    def partition_graph_edges(self, partition, self_loops):
+        """Compute edge index and edge counts of the quotient graph for ``partition``."""
+        num_edges = self.num_edges
+        with numba.objmode:
+            print("finding partition edges")
+            pbar = progress.reset(num_edges)
+        num_clusters = partition.max() + 1
+        edge_counts = np.zeros((num_clusters, num_clusters), dtype=np.int64)
+        for i, (source, target) in enumerate(self.edge_index.T):
+            source = partition[source]
+            target = partition[target]
+            if self_loops or (source != target):
+                edge_counts[source, target] += 1
+            if i % 1000000 == 0 and i > 0:
+                with numba.objmode:
+                    progress.update(pbar, 1000000)
+        with numba.objmode:
+            progress.close(pbar)
+        index = np.nonzero(edge_counts)
+        partition_edges = np.vstack(index)
+        weights = np.empty((len(index[0]),), dtype=np.int64)
+        for it, (i, j) in enumerate(zip(*index)):
+            weights[it] = edge_counts[i][j]
+        return partition_edges, weights
+
+    def partition_graph(self, partition, self_loops):
+        """Return the quotient graph for ``partition`` as a new :class:`JitGraph`."""
+        edge_index, _ = self.partition_graph_edges(partition, self_loops)
+        return JitGraph(edge_index, None, None, None)
+
+    def connected_component_ids(self):
+        """
+        return the connected component id for each node, where ids are sorted in
+        decreasing order by component size such that ``id=0`` is the largest component
+
+        Returns:
+            array of component ids
+
+        """
+        components = np.full((self.num_nodes,), -1, dtype=np.int64)
+        not_visited = np.ones(self.num_nodes, dtype=np.bool_)  # np.bool is removed in recent numpy
+        component_id = 0
+        components[0] = component_id
+        not_visited[0] = False
+        bfs_list = [0]
+        i = 0
+        for _ in range(self.num_nodes):
+            if bfs_list:
+                node = bfs_list.pop()
+            else:
+                component_id += 1
+                for i in range(i, self.num_nodes):
+                    if not_visited[i]:
+                        break
+                node = i
+                not_visited[node] = False
+            components[node] = component_id
+            new_nodes = self.adj(node)
+            new_nodes = new_nodes[not_visited[new_nodes]]
+            not_visited[new_nodes] = False
+            bfs_list.extend(new_nodes)
+
+        num_components = components.max() + 1
+        component_size = np.zeros((num_components,), dtype=np.int64)
+        for i in components:
+            component_size[i] += 1
+        new_id = np.argsort(component_size)[::-1]
+        inverse = np.empty_like(new_id)
+        inverse[new_id] = np.arange(num_components)
+        return inverse[components]
+
+    def nodes_in_lcc(self):
+        """List all nodes in the largest connected component"""
+        return 
np.flatnonzero(self.connected_component_ids() == 0) + + @property + def num_edges(self): + """TODO: docstring for num_edges.""" + return self.edge_index.shape[1] diff --git a/l2gv2/network/tgraph.py b/l2gv2/network/tgraph.py new file mode 100644 index 0000000..7fbf4d3 --- /dev/null +++ b/l2gv2/network/tgraph.py @@ -0,0 +1,365 @@ +# Copyright (c) 2021. Lucas G. S. Jeub +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""TODO: module docstring for network/tgraph.py""" + +import networkx as nx +import torch +import torch_scatter as ts +import torch_geometric as tg + +from .graph import Graph + +# pylint: disable=too-many-instance-attributes +class TGraph(Graph): + """Wrapper class for pytorch-geometric edge_index providing fast adjacency look-up.""" + + @staticmethod + def _convert_input(inp): + if inp is None: + return None + + return torch.as_tensor(inp) + + def __init__(self, *args, ensure_sorted=False, **kwargs): + super().__init__(*args, **kwargs) + + if self.num_nodes is None: + self.num_nodes = int(torch.max(self.edge_index) + 1) #: number of nodes + + if ensure_sorted: + index = torch.argsort( + self.edge_index[0] * self.num_nodes + self.edge_index[1] + ) + self.edge_index = self.edge_index[:, index] + if self.edge_attr is not None: + self.edge_attr = self.edge_attr[index] + + if self.adj_index is None: + self.degree = torch.zeros( + self.num_nodes, dtype=torch.long, device=self.device + ) #: tensor of node degrees + self.degree.index_add_( + 0, + self.edge_index[0], + torch.ones(1, dtype=torch.long, device=self.device).expand( + self.num_edges + ), + ) # use expand to avoid actually allocating large array + self.adj_index = torch.zeros( + self.num_nodes + 1, dtype=torch.long + ) + #: adjacency index such that edges starting at node ``i`` + # are given by ``edge_index[:, adj_index[i]:adj_index[i+1]]`` + self.adj_index[1:] = torch.cumsum(self.degree, 0) + else: + self.degree = self.adj_index[1:] - self.adj_index[:-1] + + if self.weighted: + self.weights = self.edge_attr + self.strength = torch.zeros( + self.num_nodes, device=self.device, dtype=self.weights.dtype + ) #: tensor of node strength + self.strength.index_add_(0, self.edge_index[0], self.weights) + else: + # use expand to avoid actually allocating large array + self.weights = torch.ones(1, device=self.device).expand(self.num_edges) + self.strength = self.degree + + if self.undir is None: + index = torch.argsort( + self.edge_index[1] * self.num_nodes + self.edge_index[0] + ) + self.undir = torch.equal( + self.edge_index, 
self.edge_index[:, index].flip((0,)) + ) + if self.weighted: + self.undir = self.undir and torch.equal( + self.weights, self.weights[index] + ) + + @property + def device(self): + """device holding graph data""" + return self.edge_index.device + + def edges(self): + """ + return list of edges where each edge is a tuple ``(source, target)`` + """ + return ( + (self.edge_index[0, e].item(), self.edge_index[1, e].item()) + for e in range(self.num_edges) + ) + + def edges_weighted(self): + """ + return list of edges where each edge is a tuple ``(source, target, weight)`` + """ + return ( + ( + self.edge_index[0, e].item(), + self.edge_index[1, e].item(), + self.weights[e].cpu().numpy() + if self.weights.ndim > 1 + else self.weights[e].item(), + ) + for e in range(self.num_edges) + ) + + def is_edge(self, source, target): + index = torch.bucketize( + target, + self.edge_index[1, self.adj_index[source] : self.adj_index[source + 1]], + ) + if ( + index < self.degree[source] + and self.edge_index[1, self.adj_index[source] + index] == target + ): + return True + + return False + + def neighbourhood(self, nodes: torch.Tensor, hops: int = 1): + """ + find the neighbourhood of a set of source nodes + + note that the neighbourhood includes the source nodes themselves + + Args: + nodes: indices of source nodes + hops: number of hops for neighbourhood + + Returns: + neighbourhood + + """ + explore = torch.ones(self.num_nodes, dtype=torch.bool, device=self.device) + explore[nodes] = False + all_nodes = [nodes] + new_nodes = nodes + for _ in range(hops): + new_nodes = torch.cat([self.adj(node) for node in new_nodes]) + new_nodes = torch.unique(new_nodes[explore[new_nodes]]) + explore[new_nodes] = False + all_nodes.append(new_nodes) + return torch.cat(all_nodes) + + def subgraph(self, nodes: torch.Tensor, relabel=False, keep_x=True, keep_y=True): + """ + find induced subgraph for a set of nodes + + Args: + nodes: node indeces + + Returns: + subgraph + + """ + index = torch.cat( + [ + torch.arange( + self.adj_index[node], self.adj_index[node + 1], dtype=torch.long + ) + for node in nodes + ] + ) + node_mask = torch.zeros(self.num_nodes, dtype=torch.bool, device=self.device) + node_mask[nodes] = True + node_ids = torch.zeros(self.num_nodes, dtype=torch.long, device=self.device) + node_ids[nodes] = torch.arange(len(nodes), device=self.device) + index = index[node_mask[self.edge_index[1][index]]] + edge_attr = self.edge_attr + if relabel: + node_labels = None + else: + node_labels = [self.nodes[n] for n in nodes] + + if self.x is not None and keep_x: + x = self.x[nodes, :] + else: + x = None + + if self.y is not None and keep_y: + y = self.y[nodes] + else: + y = None + + return self.__class__( + edge_index=node_ids[self.edge_index[:, index]], + edge_attr=edge_attr[index] if edge_attr is not None else None, + num_nodes=len(nodes), + ensure_sorted=True, + undir=self.undir, + x=x, + y=y, + nodes=node_labels, + ) + + def connected_component_ids(self): + """ Find the (weakly)-connected components. 
+ Component ids are sorted by size, such that id=0 corresponds + to the largest connected component + """ + edge_index = self.edge_index + is_undir = self.undir + last_components = torch.full( + (self.num_nodes,), self.num_nodes, dtype=torch.long, device=self.device + ) + components = torch.arange(self.num_nodes, dtype=torch.long, device=self.device) + while not torch.equal(last_components, components): + last_components[:] = components + components = ts.scatter( + last_components[edge_index[0]], + edge_index[1], + out=components, + reduce="min", + ) + if not is_undir: + components = ts.scatter( + last_components[edge_index[1]], + edge_index[0], + out=components, + reduce="min", + ) + _, inverse, component_size = torch.unique( + components, return_counts=True, return_inverse=True + ) + new_id = torch.argsort(component_size, descending=True) + return new_id[inverse] + + def nodes_in_lcc(self): + """List all nodes in the largest connected component""" + return torch.nonzero(self.connected_component_ids() == 0).flatten() + + def to_networkx(self): + """convert graph to NetworkX format""" + if self.undir: + nxgraph = nx.Graph() + else: + nxgraph = nx.DiGraph() + nxgraph.add_nodes_from(range(self.num_nodes)) + if self.x is not None: + for i in range(self.num_nodes): + nxgraph.nodes[i]["x"] = self.x[i, :] + if self.y is not None: + for i in range(self.num_nodes): + nxgraph.nodes[i]["y"] = self.y[i] + if self.weighted: + nxgraph.add_weighted_edges_from(self.edges_weighted()) + else: + nxgraph.add_edges_from(self.edges()) + return nxgraph + + def to(self, *args, graph_cls=None, **kwargs): + """ + Convert to different graph type or move to device + + Args: + graph_cls: convert to graph class + device: convert to device + + Can only specify one argument. If positional, type of move is determined automatically. + + """ + if args: + if graph_cls is not None: + raise ValueError( + "Both positional and graph_cls keyword argument specified." + ) + if len(args) == 1: + arg = args[0] + if isinstance(arg, type) and issubclass(arg, Graph): + graph_cls = arg + if kwargs: + raise ValueError( + "Cannot specify additional keyword arguments " + "when converting between graph classes." 
+ ) + + if graph_cls is not None: + return super().to(graph_cls) + + for key, value in self.__dict__.items(): + if isinstance(value, torch.Tensor): + self.__dict__[key] = value.to(*args, **kwargs) + return self + + def bfs_order(self, start=0): + """ + return nodes in breadth-first-search order + + Args: + start: index of starting node (default: 0) + + Returns: + tensor of node indeces + + """ + bfs_list = torch.full( + (self.num_nodes,), -1, dtype=torch.long, device=self.device + ) + not_visited = torch.ones(self.num_nodes, dtype=torch.bool, device=self.device) + bfs_list[0] = start + not_visited[start] = False + append_pointer = 1 + i = 0 + while append_pointer < self.num_nodes: + node = bfs_list[i] + if node < 0: + node = torch.nonzero(not_visited)[0] + bfs_list[i] = node + not_visited[node] = False + append_pointer += 1 + i += 1 + new_nodes = self.adj(node) + new_nodes = new_nodes[not_visited[new_nodes]] + number_new_nodes = len(new_nodes) + not_visited[new_nodes] = False + bfs_list[append_pointer : append_pointer + number_new_nodes] = new_nodes + append_pointer += number_new_nodes + return bfs_list + + def partition_graph(self, partition, self_loops=True): + num_clusters = torch.max(partition) + 1 + pe_index = ( + partition[self.edge_index[0]] * num_clusters + partition[self.edge_index[1]] + ) + partition_edges, weights = torch.unique(pe_index, return_counts=True) + partition_edges = torch.stack( + (partition_edges // num_clusters, partition_edges % num_clusters), dim=0 + ) + if not self_loops: + valid = partition_edges[0] != partition_edges[1] + partition_edges = partition_edges[:, valid] + weights = weights[valid] + return self.__class__( + edge_index=partition_edges, + edge_attr=weights, + num_nodes=num_clusters, + undir=self.undir, + ) + + def sample_negative_edges(self, num_samples): + return tg.utils.negative_sampling(self.edge_index, self.num_nodes, num_samples) + + def sample_positive_edges(self, num_samples): + index = torch.randint(self.num_edges, (num_samples,), dtype=torch.long) + return self.edge_index[:, index] +# pylint: enable=too-many-instance-attributes diff --git a/l2gv2/network/utils.py b/l2gv2/network/utils.py new file mode 100644 index 0000000..0fe94db --- /dev/null +++ b/l2gv2/network/utils.py @@ -0,0 +1,203 @@ +"""Graph data handling""" + +# Copyright (c) 2021. Lucas G. S. Jeub +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+
+import torch
+import numpy as np
+import numba
+from numba.experimental import jitclass
+
+from .npgraph import NPGraph
+from .tgraph import TGraph
+from .graph import Graph
+
+
+@jitclass
+class UnionFind:
+    """Union-find data structure.
+
+    A UnionFind instance maintains a family of disjoint sets over the integers
+    ``0 .. size-1``, supporting two methods:
+
+    - ``find(i)`` returns a name for the set containing ``i``. Each set is named
+      by an arbitrarily-chosen one of its members; as long as the set remains
+      unchanged it will keep the same name.
+
+    - ``union(i, j)`` merges the sets containing ``i`` and ``j`` into a single
+      larger set.
+
+    Based on Josiah Carlson's code,
+    https://code.activestate.com/recipes/215912/
+    with significant additional changes by D. Eppstein,
+    http://www.ics.uci.edu/~eppstein/PADS/UnionFind.py
+
+    """
+
+    parents: numba.int64[:]
+    weights: numba.int64[:]
+
+    def __init__(self, size):
+        """Create a new union-find structure with ``size`` initial singleton sets."""
+        self.parents = np.arange(size, dtype=np.int64)
+        self.weights = np.ones(size, dtype=np.int64)
+
+    def find(self, i):
+        """Find and return the name of the set containing the object."""
+
+        # find path of objects leading to the root
+        path = [i]
+        root = self.parents[i]
+        while root != path[-1]:
+            path.append(root)
+            root = self.parents[root]
+
+        # compress the path and return
+        for ancestor in path:
+            self.parents[ancestor] = root
+        return root
+
+    def union(self, i, j):
+        """Find the sets containing ``i`` and ``j`` and merge them."""
+        # Find the heaviest root according to its weight.
+        roots = (self.find(i), self.find(j))
+        if self.weights[roots[0]] < self.weights[roots[1]]:
+            # heaviest root first
+            roots = roots[::-1]
+
+        self.weights[roots[0]] += self.weights[roots[1]]
+        self.parents[roots[1]] = roots[0]
+
+
+def conductance(graph: Graph, source, target=None):
+    """
+    compute conductance between source and target nodes
+
+    Args:
+        graph: input graph
+
+        source: set of source nodes
+
+        target: set of target nodes (if ``target=None``,
+            consider all nodes that are not in ``source`` as target)
+
+    Returns:
+        conductance
+
+    """
+    if target is None:
+        target_mask = torch.ones(graph.num_nodes, dtype=torch.bool, device=graph.device)
+        target_mask[source] = False
+    else:
+        target_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
+        target_mask[target] = True
+    out = torch.cat([graph.adj(node) for node in source])
+    cond = torch.sum(target_mask[out]).float()
+    s_deg = graph.degree[source].sum()
+    t_deg = graph.num_edges - s_deg if target is None else graph.degree[target].sum()
+    cond /= torch.minimum(s_deg, t_deg)
+    return cond
+
+
+def spanning_tree(graph: TGraph, maximise=False):
+    """Implements Kruskal's algorithm for finding a minimum or maximum spanning tree.
+ + Args: + graph: input graph + maximise: if ``True``, find maximum spanning tree (default: ``False``) + + Returns: + spanning tree + """ + edge_mask = spanning_tree_mask(graph, maximise) + + edge_index = graph.edge_index[:, edge_mask] + if graph.edge_attr is not None: + weights = graph.edge_attr[edge_mask] + else: + weights = None + return TGraph( + edge_index=edge_index, + edge_attr=weights, + num_nodes=graph.num_nodes, + ensure_sorted=False, + ) + + +def spanning_tree_mask(graph: Graph, maximise=False): + """Return an edge mask for minimum or maximum spanning tree edges. + + Args: + graph: input graph + maximise: if ``True``, find maximum spanning tree (default: ``False``) + """ + + convert_to_tensor = isinstance(graph, TGraph) + graph = graph.to(NPGraph) + + # find positions of reverse edges + if graph.undir: + reverse_edge_index = np.argsort( + graph.edge_index[1] * graph.num_nodes + graph.edge_index[0] + ) + forward_edge_index = np.flatnonzero(graph.edge_index[0] < graph.edge_index[1]) + edges = graph.edge_index[:, forward_edge_index] + weights = graph.weights[forward_edge_index] + reverse_edge_index = reverse_edge_index[forward_edge_index] + else: + edges = graph.edge_index + forward_edge_index = np.arange(edges.shape[1]) + weights = graph.weights + reverse_edge_index = None + + index = np.argsort(weights) + if maximise: + index = index[::-1] + + edge_mask = np.zeros(graph.num_edges, dtype=bool) + edge_mask = _spanning_tree_mask( + edge_mask, edges, index, graph.num_nodes, forward_edge_index, reverse_edge_index + ) + if convert_to_tensor: + edge_mask = torch.as_tensor(edge_mask) + return edge_mask + + +@numba.njit +def _spanning_tree_mask( + edge_mask, edges, index, num_nodes, forward_edge_index, reverse_edge_index +): + subtrees = UnionFind(num_nodes) + for i in index: + u = edges[0, i] + v = edges[1, i] + if subtrees.find(u) != subtrees.find(v): + edge_mask[forward_edge_index[i]] = True + if reverse_edge_index is not None: + edge_mask[reverse_edge_index[i]] = True + subtrees.union(u, v) + return edge_mask diff --git a/l2gv2/progress.py b/l2gv2/progress.py new file mode 100644 index 0000000..604d4fd --- /dev/null +++ b/l2gv2/progress.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021. Lucas G. S. Jeub +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" TODO: module docstring for progress. 
""" +from tqdm.auto import tqdm + + +def reset(total): + """ TODO: docstring for reset""" + return tqdm(total=total) + + +def update(pbar, iterations): + """ TODO: docstring for update""" + pbar.update(iterations) + + +def close(pbar): + """ TODO: docstring for close""" + pbar.update(pbar.total - pbar.n) + pbar.close() diff --git a/l2gv2/sparsify.py b/l2gv2/sparsify.py new file mode 100644 index 0000000..b468a50 --- /dev/null +++ b/l2gv2/sparsify.py @@ -0,0 +1,479 @@ +"""Graph sparsification""" +# pylint: disable=invalid-name + +import math +import warnings + +import numpy as np +import scipy as sc +import torch +import numba +from tqdm.auto import tqdm + +from l2gv2.network import TGraph, spanning_tree_mask, spanning_tree +from l2gv2.clustering import Partition + + +rg = np.random.default_rng() + + +def _gumbel_topk(weights, k, log_weights=False): + """ + sampling without replacement from potentially large set of values + + see arXiv:1903.06059v2 + + Args: + weights: sampling weights (not necessarily normalised) + + Returns: + sampled indices + """ + if k >= len(weights): + return torch.arange(len(weights)) + + if not log_weights: + weights = torch.log(weights) + + dist = torch.distributions.Gumbel(0, 1) + + perturbed = weights + dist.sample(weights.shape) + return torch.topk(perturbed, k, sorted=False)[1] + + +def _sample_edges(graph, n_desired_edges, ensure_connected=True): + if ensure_connected: + edge_mask = spanning_tree_mask(graph, maximise=True) + n_desired_edges -= edge_mask.sum() + unselected_edges = edge_mask.logical_not().nonzero().flatten() + else: + edge_mask = torch.zeros(graph.num_edges, dtype=torch.bool, device=graph.device) + unselected_edges = torch.arange(graph.num_edges, device=graph.device) + if n_desired_edges > 0: # check whether we have sufficiently many edges already + unselected_edge_index = graph.edge_index[:, unselected_edges] + reversed_index = torch.argsort( + unselected_edge_index[1] * graph.num_nodes + unselected_edge_index[0] + ) + forward_unselected = unselected_edges[ + unselected_edge_index[0] < unselected_edge_index[1] + ] + reverse_unselected = unselected_edges[ + reversed_index[unselected_edge_index[0] < unselected_edge_index[1]] + ] + index = _gumbel_topk(graph.weights[forward_unselected], n_desired_edges // 2) + edge_mask[forward_unselected[index]] = True + edge_mask[reverse_unselected[index]] = True + return edge_mask + + +@numba.njit +def _multi_arange(start, stop): + count = np.sum(stop - start) + out = np.empty((count,), dtype=np.int64) + i = 0 + for s, t in zip(start, stop): + out[i : i + (t - s)] = np.arange(s, t) + i += t - s + return out + + +def resistance_sparsify( + graph: TGraph, target_mean_degree, ensure_connected=True, epsilon=1e-2 +): + """ + Sparsify a graph to have a target mean degree using effective resistance based sampling + + + Args: + graph: input graph + target_mean_degree: desired mean degree after sparsification + ensure_connected: if ``True``, first add edges of a maximum spanning tree + based on the resistance weights to ensure that the + sparsified graph remains connected if the input graph is connected + epsilon: tolerance for effective resistance computation + + Returns: + sparsified graph + + This algorithm is based on the method of + + D. A. Spielman and N. Srivastava. + “Graph sparsification by effective resistances”. + SIAM Journal on Computing 40.6 (2011), pp. 1913–1926. 
+def _sample_edges(graph, n_desired_edges, ensure_connected=True):
+    # sample edge pairs (forward and reverse direction together) with
+    # probability proportional to their weight, optionally seeding the
+    # selection with a maximum spanning tree to preserve connectivity
+    if ensure_connected:
+        edge_mask = spanning_tree_mask(graph, maximise=True)
+        n_desired_edges -= edge_mask.sum()
+        unselected_edges = edge_mask.logical_not().nonzero().flatten()
+    else:
+        edge_mask = torch.zeros(graph.num_edges, dtype=torch.bool, device=graph.device)
+        unselected_edges = torch.arange(graph.num_edges, device=graph.device)
+    if n_desired_edges > 0:  # check whether we have sufficiently many edges already
+        unselected_edge_index = graph.edge_index[:, unselected_edges]
+        reversed_index = torch.argsort(
+            unselected_edge_index[1] * graph.num_nodes + unselected_edge_index[0]
+        )
+        forward_unselected = unselected_edges[
+            unselected_edge_index[0] < unselected_edge_index[1]
+        ]
+        reverse_unselected = unselected_edges[
+            reversed_index[unselected_edge_index[0] < unselected_edge_index[1]]
+        ]
+        index = _gumbel_topk(graph.weights[forward_unselected], n_desired_edges // 2)
+        edge_mask[forward_unselected[index]] = True
+        edge_mask[reverse_unselected[index]] = True
+    return edge_mask
+
+
+@numba.njit
+def _multi_arange(start, stop):
+    # concatenation of np.arange(s, t) for each pair (s, t) in zip(start, stop)
+    count = np.sum(stop - start)
+    out = np.empty((count,), dtype=np.int64)
+    i = 0
+    for s, t in zip(start, stop):
+        out[i : i + (t - s)] = np.arange(s, t)
+        i += t - s
+    return out
+
+
+def resistance_sparsify(
+    graph: TGraph, target_mean_degree, ensure_connected=True, epsilon=1e-2
+):
+    """
+    Sparsify a graph to a target mean degree using effective-resistance-based sampling
+
+    Args:
+        graph: input graph
+        target_mean_degree: desired mean degree after sparsification
+        ensure_connected: if ``True``, first add the edges of a maximum spanning tree
+            based on the resistance weights so that the sparsified graph
+            remains connected if the input graph is connected
+        epsilon: tolerance for the effective resistance computation
+
+    Returns:
+        sparsified graph
+
+    This algorithm is based on the method of
+
+        D. A. Spielman and N. Srivastava.
+        “Graph sparsification by effective resistances”.
+        SIAM Journal on Computing 40.6 (2011), pp. 1913–1926.
+
+    However, a fixed number of edges are sampled without replacement,
+    and optionally a maximum spanning tree is kept
+    to ensure the connectedness of the sparsified graph.
+    """
+    n_desired_edges = (
+        int(target_mean_degree * graph.num_nodes / 2) * 2
+    )  # round down to an even number of edges
+    if n_desired_edges >= graph.num_edges:
+        # graph is already sufficiently sparse
+        return graph
+
+    rgraph = resistance_weighted_graph(graph, epsilon=epsilon)
+
+    edge_mask = _sample_edges(rgraph, n_desired_edges, ensure_connected)
+    edge_index = graph.edge_index[:, edge_mask]
+    edge_attr = None if graph.edge_attr is None else graph.edge_attr[edge_mask]
+    return TGraph(
+        edge_index=edge_index,
+        edge_attr=edge_attr,
+        num_nodes=graph.num_nodes,
+        ensure_sorted=False,
+        undir=graph.undir,
+    )
+
+
+def conductance_weighted_graph(graph: TGraph):
+    """Reweight each edge by its conductance.
+
+    The conductance weight of an edge is its weight divided by the smaller
+    of the two endpoint strengths.
+    """
+    weights = graph.weights / torch.minimum(
+        graph.strength[graph.edge_index[0]], graph.strength[graph.edge_index[1]]
+    )
+    return TGraph(
+        edge_index=graph.edge_index,
+        edge_attr=weights,
+        num_nodes=graph.num_nodes,
+        adj_index=graph.adj_index,
+        ensure_sorted=False,
+        undir=graph.undir,
+    )
+
+
+def resistance_weighted_graph(graph: TGraph, **args):
+    """
+    Modify the edge weights of a graph by multiplying by their effective resistance
+
+    Args:
+        graph: input graph
+        epsilon: tolerance for effective resistance computation (default: ``1e-2``)
+
+    Returns:
+        copy of the input graph with reweighted edges
+    """
+    resistances = effective_resistances(graph, **args)
+    if graph.edge_attr is None:
+        edge_attr = resistances
+    else:
+        edge_attr = graph.edge_attr * resistances
+    return TGraph(
+        graph.edge_index,
+        edge_attr,
+        num_nodes=graph.num_nodes,
+        ensure_sorted=False,
+        undir=graph.undir,
+        adj_index=graph.adj_index,
+    )
+
+
+def effective_resistances(graph: TGraph, **args):
+    """
+    Compute the effective resistance for each edge of the graph
+
+    Args:
+        graph: input graph
+        epsilon: tolerance for effective resistance computation (default: ``1e-2``)
+
+    Returns:
+        effective resistance for each edge
+    """
+    Z = _compute_Z(graph, **args)
+    Z = torch.from_numpy(Z)
+    resistances = (
+        torch.pairwise_distance(Z[graph.edge_index[0], :], Z[graph.edge_index[1], :])
+        ** 2
+    )
+    return resistances
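+
+
+# Background for the helpers below (Spielman–Srivastava sketch): with
+# Y = W^(1/2) B the weighted incidence matrix and L = Y^T Y the Laplacian,
+# each column of Z solves L z_i = (q_i Y)^T for a random +-1 row q_i, so that
+# Z approximates (Q Y L^+)^T and ||Z[u] - Z[v]||^2 approximates the effective
+# resistance between u and v, which is the identity ``effective_resistances``
+# relies on.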
+def _edge_node_incidence_matrix(graph: TGraph):
+    indices = np.empty(2 * graph.num_edges, dtype=int)
+    values = np.empty(2 * graph.num_edges, dtype=int)
+    indptr = 2 * np.arange(graph.num_edges + 1, dtype=np.int64)
+    indices[::2] = graph.edge_index[0]
+    indices[1::2] = graph.edge_index[1]
+    values[::2] = 1
+    values[1::2] = -1
+
+    return sc.sparse.csr_matrix(
+        (values, indices, indptr), shape=(graph.num_edges, graph.num_nodes)
+    )
+
+
+def _edge_weight_matrix(graph: TGraph):
+    weight = graph.weights.cpu().numpy()
+    W = sc.sparse.dia_matrix((np.sqrt(weight), 0), shape=(len(weight), len(weight)))
+    return W
+
+
+def _compute_Z(graph: TGraph, epsilon=10.0**-2.0):
+    W = _edge_weight_matrix(graph)
+    B = _edge_node_incidence_matrix(graph)
+    Y = W.dot(B)
+    L = Y.transpose().dot(Y)
+
+    n = graph.num_nodes
+    m = graph.num_edges
+    k = math.floor(24.0 * math.log(n) / (epsilon**2.0))
+    delta = (
+        epsilon
+        / 3.0
+        * math.sqrt(
+            (2.0 * (1.0 - epsilon) * min(W.diagonal()))
+            / ((1.0 + epsilon) * (n**3.0) * max(W.diagonal()))
+        )
+    )
+
+    LU = sc.sparse.linalg.spilu(L + epsilon * sc.sparse.eye(n))
+    P = sc.sparse.linalg.LinearOperator((n, n), matvec=LU.solve)
+    Z = np.zeros((n, min(m, k)))
+
+    for i in range(Z.shape[1]):
+        if k < m:
+            # random +-1 projection of the weighted incidence matrix Y
+            # (projecting B here, as previously done, would drop the edge weights);
+            # use the module-level Generator rather than the legacy np.random API
+            q = (2 * rg.integers(0, 2, size=(1, m)) - 1) / math.sqrt(k)
+            y = q * Y
+            y = y.transpose()
+        else:
+            y = Y.getrow(i).transpose().toarray()
+
+        # note: ``tol`` was renamed to ``rtol`` in SciPy 1.12 and removed in 1.14
+        Z[:, i], flag = sc.sparse.linalg.lgmres(L, y, M=P, rtol=delta)
+
+        if flag > 0:
+            warnings.warn(f"LGMRES not converged after {flag} iterations")
+
+        if flag < 0:
+            warnings.warn(f"LGMRES error {flag}")
+
+    return Z
+
+
+def relaxed_spanning_tree(graph: TGraph, maximise=False, gamma=1):
+    r"""Compute the relaxed minimum or maximum spanning tree
+
+    This implements the relaxed minimum spanning tree algorithm of
+
+        M. Beguerisse-Díaz, B. Vangelov, and M. Barahona.
+        “Finding role communities in directed networks using Role-Based Similarity,
+        Markov Stability and the Relaxed Minimum Spanning Tree”.
+        In: 2013 IEEE Global Conference on Signal and Information Processing
+        (GlobalSIP). IEEE, 2013, pp. 937–940. isbn: 978-1-4799-0248-4.
+
+    Args:
+        graph: input graph
+        maximise: if ``True``, start with the maximum spanning tree
+        gamma: :math:`\gamma` value for adding edges
+    """
+    mst = spanning_tree(graph, maximise=maximise)
+    rmst_edges = [mst.edge_index]
+    rmst_weights = [mst.edge_attr]
+    if maximise:
+        reduce_fun = torch.minimum
+        d = torch.tensor(
+            [torch.max(graph.adj_weighted(node)[1]) for node in range(graph.num_nodes)],
+            device=graph.device,
+        )
+    else:
+        reduce_fun = torch.maximum
+        d = torch.tensor(
+            [torch.min(graph.adj_weighted(node)[1]) for node in range(graph.num_nodes)],
+            device=graph.device,
+        )
+    target_mask = torch.full(
+        (graph.num_nodes,), -1, dtype=torch.long, device=graph.device
+    )
+    for i in range(graph.num_nodes):
+        neighbours, weights = graph.adj_weighted(i)
+        # store indices into ``neighbours`` so we can look up weights easily
+        target_mask[neighbours] = torch.arange(neighbours.numel())
+        # breadth-first search over the mst to find mst path weights
+        # (note mst edges are already added)
+        mst_neighbours, mst_weights = mst.adj_weighted(i)
+        target_mask[mst_neighbours] = -1
+        not_visited = torch.ones(graph.num_nodes, dtype=torch.bool, device=graph.device)
+        not_visited[mst_neighbours] = False
+        not_visited[i] = False
+        while torch.any(target_mask[neighbours] >= 0):
+            next_neighbours = []
+            next_weights = []
+            for node, weight in zip(mst_neighbours, mst_weights):
+                n, w = mst.adj_weighted(node)
+                new = not_visited[n]
+                n = n[new]
+                w = w[new]
+                not_visited[n] = False
+                next_neighbours.append(n)
+                next_weights.append(reduce_fun(weight, w))
+            mst_neighbours = torch.cat(next_neighbours)
+            mst_weights = torch.cat(next_weights)
+            index = target_mask[mst_neighbours]
+            selected = mst_neighbours[index >= 0]
+            target_mask[selected] = -1
+            selected_w = mst_weights[index >= 0]
+            index = index[index >= 0]
+            if maximise:
+                add = selected_w - gamma * (d[i] + d[selected]) < weights[index]
+            else:
+                add = selected_w + gamma * (d[i] + d[selected]) > weights[index]
+            rmst_edges.append(
+                torch.stack(
+                    (
+                        torch.full((add.sum().item(),), i, dtype=torch.long),
+                        selected[add],
+                    ),
+                    dim=0,
+                )
+            )
+            rmst_weights.append(weights[index[add]])
+    edge_index = torch.cat(rmst_edges, dim=1)
+    edge_attr = torch.cat(rmst_weights)
+    return TGraph(
+        edge_index, edge_attr, graph.num_nodes, ensure_sorted=True, undir=graph.undir
+    )
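+
+
+# Illustrative call (assumes a weighted ``TGraph`` named ``graph``): a small
+# ``gamma`` keeps the result close to the spanning tree, while a larger
+# ``gamma`` re-adds more of the original edges.
+#
+#     backbone = relaxed_spanning_tree(graph, maximise=True, gamma=0.5)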
+def edge_sampling_sparsify(graph: TGraph, target_degree, ensure_connected=True):
+    """Sparsify a graph by sampling edges with probability proportional to conductance.
+
+    Args:
+        graph: input graph
+        target_degree: desired mean degree after sparsification
+        ensure_connected: if ``True``, keep a maximum spanning tree of the
+            conductance-weighted graph (default: ``True``)
+
+    Returns:
+        sparsified graph
+    """
+    n_desired_edges = (
+        int(target_degree * graph.num_nodes / 2) * 2
+    )  # round down to an even number of edges
+    if n_desired_edges >= graph.num_edges:
+        # graph is already sufficiently sparse
+        return graph
+
+    weights = graph.weights / torch.minimum(
+        graph.strength[graph.edge_index[0]], graph.strength[graph.edge_index[1]]
+    )
+    cgraph = TGraph(
+        graph.edge_index,
+        edge_attr=weights,
+        adj_index=graph.adj_index,
+        num_nodes=graph.num_nodes,
+        ensure_sorted=False,
+        undir=graph.undir,
+    )  # convert weights to conductance values
+    edge_mask = _sample_edges(cgraph, n_desired_edges, ensure_connected)
+    edge_attr = graph.edge_attr[edge_mask] if graph.edge_attr is not None else None
+    return TGraph(
+        edge_index=graph.edge_index[:, edge_mask],
+        edge_attr=edge_attr,
+        num_nodes=graph.num_nodes,
+        ensure_sorted=False,
+        undir=graph.undir,
+    )
+
+
+def nearest_neighbor_sparsify(graph: TGraph, target_degree, ensure_connected=True):
+    """Sparsify a graph by keeping the ``target_degree`` strongest edges of each node.
+
+    Only edges kept by both endpoints survive. If ``ensure_connected`` is ``True``,
+    the edges of a maximum spanning tree are kept as well.
+    """
+    if ensure_connected:
+        edge_mask = spanning_tree_mask(graph, maximise=True)
+    else:
+        # mask over edges (the original allocated ``num_nodes`` entries here,
+        # which under-sizes the mask)
+        edge_mask = torch.zeros(
+            (graph.num_edges,), dtype=torch.bool, device=graph.device
+        )
+    for n in range(graph.num_nodes):
+        count = graph.adj_index[n + 1] - graph.adj_index[n]
+        if count > target_degree:
+            neighbour_index = (
+                graph.adj_index[n]
+                + torch.topk(
+                    graph.weights[graph.adj_index[n] : graph.adj_index[n + 1]],
+                    target_degree,
+                ).indices
+            )
+        else:
+            neighbour_index = torch.arange(
+                graph.adj_index[n],
+                graph.adj_index[n + 1],
+                dtype=torch.long,
+                device=graph.device,
+            )
+        edge_mask[neighbour_index] = True
+    reverse = torch.argsort(graph.edge_index[1] * graph.num_nodes + graph.edge_index[0])
+    edge_mask *= edge_mask[reverse]  # only keep edges that exist in both directions
+    edge_attr = graph.edge_attr[edge_mask] if graph.edge_attr is not None else None
+    return TGraph(
+        edge_index=graph.edge_index[:, edge_mask],
+        edge_attr=edge_attr,
+        num_nodes=graph.num_nodes,
+        ensure_sorted=False,
+        undir=True,
+    )
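+
+
+# Hypothetical comparison of the two samplers above (``graph`` assumed given):
+#
+#     g1 = edge_sampling_sparsify(graph, target_degree=4)
+#     g2 = nearest_neighbor_sparsify(graph, target_degree=4)
+#     # g1 samples globally by conductance weight; g2 keeps each node's
+#     # strongest incident edges, so degrees tend to be more uniform in g2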
+def hierarchical_sparsify(
+    graph: TGraph,
+    clusters,
+    target_level_degree,
+    ensure_connected=True,
+    sparsifier=edge_sampling_sparsify,
+):
+    """Apply ``sparsifier`` within each cluster at every level of a cluster hierarchy.
+
+    Args:
+        graph: input graph
+        clusters: list of cluster assignments, one per level, finest level first
+        target_level_degree: target mean degree passed to ``sparsifier`` at each level
+        ensure_connected: passed through to ``sparsifier`` (default: ``True``)
+        sparsifier: sparsification function to apply per cluster
+            (default: :func:`edge_sampling_sparsify`)
+
+    Returns:
+        sparsified graph
+    """
+    rgraph = graph
+    edge_mask = torch.zeros(graph.num_edges, dtype=torch.bool, device=graph.device)
+    node_map = np.array(graph.nodes)
+    reverse_index = (
+        torch.argsort(graph.edge_index[1] * graph.num_nodes + graph.edge_index[0])
+        .cpu()
+        .numpy()
+    )
+    edges = graph.edge_index.cpu().numpy()
+    clusters = list(clusters)  # avoid mutating the caller's list below
+    final_num_clusters = clusters[-1].max() + 1
+    if final_num_clusters > 1:
+        clusters.append(
+            torch.zeros(final_num_clusters, dtype=torch.long, device=graph.device)
+        )
+    for cluster in clusters:
+        expanded_cluster = cluster[node_map]
+        parts = Partition(cluster)
+        expanded_parts = Partition(expanded_cluster)
+        for p, ep in tqdm(
+            zip(parts, expanded_parts), total=len(parts), desc="sparsifying clusters"
+        ):
+            sgraph = sparsifier(
+                rgraph.subgraph(p), target_level_degree, ensure_connected
+            )
+            s_edges = p[sgraph.edge_index]
+            s_edges = s_edges[0] * rgraph.num_nodes + s_edges[1]
+            s_edges = s_edges.cpu().numpy()
+            index = _multi_arange(
+                graph.adj_index[ep].cpu().numpy(), graph.adj_index[ep + 1].cpu().numpy()
+            )
+            index = index[edges[0, index] < edges[1, index]]  # only forward direction
+            mapped_edges = node_map[edges[:, index]]
+            mapped_edges = mapped_edges[0] * rgraph.num_nodes + mapped_edges[1]
+
+            # np.in1d is deprecated in favour of np.isin
+            valid = np.flatnonzero(np.isin(mapped_edges, s_edges))
+            mapped_edges = mapped_edges[valid]
+            index = index[valid]
+            u_vals, edge_index = np.unique(mapped_edges, return_inverse=True)
+            if len(u_vals) < len(valid):
+                edge_partition = Partition(edge_index)
+                for e_part in edge_partition:
+                    if len(e_part) > int(target_level_degree):
+                        # sample within this group of duplicate coarse edges
+                        # (the original sampled from the full ``index`` array,
+                        # which could select edges outside the group)
+                        r = e_part[
+                            np.asarray(
+                                _gumbel_topk(
+                                    graph.weights[index[e_part]],
+                                    int(target_level_degree),
+                                )
+                            )
+                        ]
+                    else:
+                        r = e_part
+                    edge_mask[index[r]] = True
+                    edge_mask[reverse_index[index[r]]] = True
+            else:
+                edge_mask[index] = True
+                edge_mask[reverse_index[index]] = True
+
+        rgraph = rgraph.partition_graph(cluster, self_loops=False)
+        node_map = expanded_cluster.cpu().numpy()
+    edge_attr = graph.edge_attr[edge_mask] if graph.edge_attr is not None else None
+    return TGraph(
+        edge_index=graph.edge_index[:, edge_mask],
+        edge_attr=edge_attr,
+        num_nodes=graph.num_nodes,
+        ensure_sorted=False,
+        undir=graph.undir,
+    )
diff --git a/l2gv2/utils.py b/l2gv2/utils.py
new file mode 100644
index 0000000..373ebed
--- /dev/null
+++ b/l2gv2/utils.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2021. Lucas G. S. Jeub
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+"""Shared utilities: sparse identity, device handling, early stopping and timing."""
+
+from tempfile import TemporaryFile
+from time import perf_counter
+import torch
+import torch.nn
+
+
+def speye(n, dtype=torch.float):
+    """Identity matrix of dimension ``n`` as a ``torch.sparse_coo_tensor``."""
+    return torch.sparse_coo_tensor(
+        torch.tile(torch.arange(n, dtype=torch.long), (2, 1)),
+        torch.ones(n, dtype=dtype),
+        (n, n),
+    )
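+
+
+# For example, ``speye(3)`` stores only the diagonal: indices
+# [[0, 1, 2], [0, 1, 2]] with values [1., 1., 1.], so
+# ``speye(3).to_dense()`` equals ``torch.eye(3)``.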
+
+
+def get_device(model: torch.nn.Module):
+    """Return the device on which ``model``'s first parameter is stored."""
+    return next(model.parameters()).device
+
+
+def set_device(device):
+    """Resolve ``device`` to a :class:`torch.device`.
+
+    If ``device`` is ``None``, use CUDA when available and fall back to the CPU.
+    """
+    if device is None:
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+        else:
+            device = torch.device("cpu")
+    else:
+        device = torch.device(device)
+    return device
+
+
+class EarlyStopping:
+    """
+    Context manager for early stopping
+    """
+
+    def __init__(self, patience, delta=0):
+        """
+        Initialise early stopping context manager
+
+        Args:
+            patience: wait ``patience`` number of epochs without loss improvement
+                before stopping
+            delta: minimum improvement to consider significant (default: 0)
+        """
+        self.patience = patience
+        self.delta = delta
+        self.best_loss = float("inf")
+        self.count = 0
+        self._file = TemporaryFile()
+
+    def __enter__(self):
+        self.best_loss = float("inf")
+        self.count = 0
+        if self._file.closed:
+            self._file = TemporaryFile()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._file.close()
+
+    def _save_model(self, model):
+        self._file.seek(0)
+        torch.save(model.state_dict(), self._file)
+
+    def _load_model(self, model: torch.nn.Module):
+        self._file.seek(0)
+        model.load_state_dict(torch.load(self._file))
+
+    def __call__(self, loss, model):
+        """
+        Check the stopping criterion and save or restore model state as appropriate
+
+        Args:
+            loss: loss value for stopping
+            model: model whose state is saved on improvement and restored on stopping
+
+        Returns:
+            ``True`` if training should be stopped, ``False`` otherwise
+        """
+        loss = float(
+            loss
+        )  # make sure no tensors are used here to avoid propagating gradients
+        if loss >= self.best_loss - self.delta:
+            self.count += 1
+        else:
+            self.count = 0
+
+        if loss < self.best_loss:
+            self.best_loss = loss
+            self._save_model(model)
+        if self.count > self.patience:
+            self._load_model(model)
+            return True
+
+        return False
+
+
+class Timer:
+    """
+    Context manager for accumulating execution time
+
+    Adds the time taken within the block to a running total.
+    """
+
+    def __init__(self):
+        self.total = 0.0
+        self.tic = None
+
+    def __enter__(self):
+        self.tic = perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        self.total += perf_counter() - self.tic
+
+
+def flatten(l, ltypes=(list, tuple)):
+    """Flatten arbitrarily nested lists or tuples into a flat sequence of the outer type."""
+    if isinstance(l, ltypes):
+        ltype = type(l)
+        l = list(l)
+        i = 0
+        while i < len(l):
+            while isinstance(l[i], ltypes):
+                if not l[i]:
+                    l.pop(i)
+                    i -= 1
+                    break
+
+                l[i : i + 1] = l[i]
+            i += 1
+        return ltype(l)
+
+    return l
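+
+
+# A sketch of the intended training-loop usage (``model``, ``max_epochs`` and
+# ``train_step`` are hypothetical placeholders):
+#
+#     timer = Timer()
+#     with EarlyStopping(patience=10) as stop:
+#         for epoch in range(max_epochs):
+#             with timer:
+#                 loss = train_step(model)
+#             if stop(loss, model):
+#                 break  # best model state has been restored
+#     print(f"total training time: {timer.total:.1f}s")
+#
+# ``flatten`` collapses nested sequences, e.g.
+# ``flatten([1, [2, (3, 4)], []])`` returns ``[1, 2, 3, 4]``.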