Skip to content

Commit

Permalink
Fix linting.
Browse files Browse the repository at this point in the history
  • Loading branch information
mihaeladuta committed Nov 22, 2024
1 parent d0ba9d7 commit 33dc63c
Show file tree
Hide file tree
Showing 4 changed files with 357 additions and 178 deletions.
103 changes: 72 additions & 31 deletions l2gv2/network/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,42 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import Sequence, Collection, Iterable
"""TODO: module docstring for network/graph.py"""

from typing import Sequence, Iterable
from abc import abstractmethod
import networkx as nx
from abc import ABC, abstractmethod
import numpy as np


# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-public-methods
class Graph:
"""
numpy backed graph class with support for memmapped edge_index
"""

weights: Sequence
degree: Sequence
device = 'cpu'
device = "cpu"

@staticmethod
def _convert_input(input):
return input
def _convert_input(inp):
return inp

@classmethod
def from_tg(cls, data):
return cls(edge_index=data.edge_index,
edge_attr=data.edge_attr,
x=data.x,
y=data.y,
num_nodes=data.num_nodes)
""" TODO: docstring for from_tg."""
return cls(
edge_index=data.edge_index,
edge_attr=data.edge_attr,
x=data.x,
y=data.y,
num_nodes=data.num_nodes,
)

@classmethod
def from_networkx(cls, nx_graph: nx.Graph, weight=None):
""" TODO: docstring for from_networkx."""
undir = not nx_graph.is_directed()
if undir:
nx_graph = nx_graph.to_directed(as_view=True)
Expand All @@ -57,27 +65,45 @@ def from_networkx(cls, nx_graph: nx.Graph, weight=None):
if w is not None:
weights.append(w)
if weights and len(weights) != num_edges:
raise RuntimeError('some edges have missing weight')
raise RuntimeError("some edges have missing weight")

if weight is not None:
weights = np.array(weights)
else:
weights = None

return cls(edge_index, weights, num_nodes=num_nodes, ensure_sorted=True, undir=undir)
return cls(
edge_index, weights, num_nodes=num_nodes, ensure_sorted=True, undir=undir
)

@abstractmethod
def __init__(self, edge_index, edge_attr=None, x=None, y=None, num_nodes=None, adj_index=None,
ensure_sorted=False, undir=None, nodes=None):
def __init__(
self,
edge_index,
edge_attr=None,
x=None,
y=None,
num_nodes=None,
adj_index=None,
ensure_sorted=False,
undir=None,
nodes=None,
):
"""
Initialise graph
Args:
edge_index: edge index such that ``edge_index[0]`` lists the source and ``edge_index[1]`` the target node for each edge
edge_index: edge index such that ``edge_index[0]`` lists the source
and ``edge_index[1]`` the target node for each edge
edge_attr: optionally provide edge weights
num_nodes: specify number of nodes (default: ``max(edge_index)+1``)
ensure_sorted: if ``False``, assume that the ``edge_index`` input is already sorted
undir: boolean indicating if graph is directed. If not provided, the ``edge_index`` is checked to determine this value.
undir: boolean indicating if graph is directed.
If not provided, the ``edge_index`` is checked to determine this value.
"""
self.edge_index = self._convert_input(edge_index)
self.edge_attr = self._convert_input(edge_attr)
Expand All @@ -97,20 +123,23 @@ def weighted(self):

@property
def num_edges(self):
""" TODO: docstring for num_edges."""
return self.edge_index.shape[1]

@property
def num_features(self):
""" TODO: docstring for num_features."""
return 0 if self.x is None else self.x.shape[1]

@property
def nodes(self):
""" TODO: docstring for nodes."""
if self._nodes is None:
return range(self.num_nodes)
else:
return self._nodes
return self._nodes

def has_node_labels(self):
""" TODO: docstring for has_node_labels."""
return self._nodes is not None

def adj(self, node: int):
Expand All @@ -124,7 +153,7 @@ def adj(self, node: int):
neighbours
"""
return self.edge_index[1][self.adj_index[node]:self.adj_index[node + 1]]
return self.edge_index[1][self.adj_index[node] : self.adj_index[node + 1]]

def adj_weighted(self, node: int):
"""
Expand All @@ -136,7 +165,9 @@ def adj_weighted(self, node: int):
neighbours, weights
"""
return self.adj(node), self.weights[self.adj_index[node]:self.adj_index[node + 1]]
return self.adj(node), self.weights[
self.adj_index[node] : self.adj_index[node + 1]
]

@abstractmethod
def edges(self):
Expand All @@ -154,6 +185,7 @@ def edges_weighted(self):

@abstractmethod
def is_edge(self, source, target):
""" TODO: docstring for is_edge."""
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -203,6 +235,7 @@ def nodes_in_lcc(self):
return (i for i, c in enumerate(self.connected_component_ids()) if c == 0)

def lcc(self, relabel=False):
""" TODO: docstring for lcc."""
return self.subgraph(self.nodes_in_lcc(), relabel)

def to_networkx(self):
Expand All @@ -219,18 +252,21 @@ def to_networkx(self):
return nxgraph

def to(self, graph_cls):
""" TODO: docstring for to."""
if self.__class__ is graph_cls:
return self
else:
return graph_cls(edge_index=self.edge_index,
edge_attr=self.edge_attr,
x=self.x,
y=self.y,
num_nodes=self.num_nodes,
adj_index=self.adj_index,
ensure_sorted=False,
undir=self.undir,
nodes=self._nodes)

return graph_cls(
edge_index=self.edge_index,
edge_attr=self.edge_attr,
x=self.x,
y=self.y,
num_nodes=self.num_nodes,
adj_index=self.adj_index,
ensure_sorted=False,
undir=self.undir,
nodes=self._nodes,
)

@abstractmethod
def bfs_order(self, start=0):
Expand All @@ -248,11 +284,16 @@ def bfs_order(self, start=0):

@abstractmethod
def partition_graph(self, partition, self_loops=True):
""" TODO: docstring for partition_graph."""
raise NotImplementedError

@abstractmethod
def sample_negative_edges(self, num_samples):
""" TODO: docstring for sample_negative_edges."""
raise NotImplementedError

def sample_positive_edges(self, num_samples):
""" TODO: docstring for sample_positive_edges."""
raise NotImplementedError
# pylint: enable=too-many-public-methods
# pylint: enable=too-many-instance-attributes
Loading

0 comments on commit 33dc63c

Please sign in to comment.