Merge pull request #35 from OxfordRSE/update_typing
Update typing
mihaeladuta authored Dec 2, 2024
2 parents 4a50d39 + 7685018 commit cef26e7
Showing 48 changed files with 636 additions and 526 deletions.
2 changes: 1 addition & 1 deletion .all-contributorsrc
@@ -21,4 +21,4 @@
"linkToUsage": true,
"skipCi": true,
"contributors": []
-}
+}
1 change: 0 additions & 1 deletion docs/source/conf.py
@@ -42,7 +42,6 @@
"torch_scatter",
"local2global",
"raphtory",
"local2global_embedding",
"networkx",
"matplotlib",
"nfts",
2 changes: 1 addition & 1 deletion l2gv2/anomaly_detection.py
@@ -2,7 +2,7 @@

from typing import Any
import numpy as np
-from l2gv2.patch.patch import Patch
+from .patch.patch import Patch


def raw_anomaly_score_node_patch(aligned_patch_emb, emb, node) -> np.floating[Any]:
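A note on the hunk above: the only change is import style. Both spellings resolve to the same module while anomaly_detection.py sits inside the l2gv2 package, but the relative form keeps working if the package is renamed or vendored. Side by side (illustrative, not new code):

# Absolute: hard-codes the top-level package name.
from l2gv2.patch.patch import Patch
# Relative: resolves against the containing package.
from .patch.patch import Patch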
41 changes: 27 additions & 14 deletions l2gv2/clustering.py
@@ -10,8 +10,8 @@
import numpy as np
import numba

-from l2gv2.network import TGraph, NPGraph
-from l2gv2 import progress
+from network import TGraph, NPGraph
+from . import progress


def distributed_clustering(graph: TGraph, beta, rounds=None, patience=3, min_samples=2):
@@ -98,7 +98,7 @@ def fennel_clustering(
num_iters=1,
clusters=None,
):
""" TODO: docstring for fennel_clustering. """
"""TODO: docstring for fennel_clustering."""
graph = graph.to(NPGraph)

if clusters is None:
Expand Down Expand Up @@ -126,6 +126,7 @@ def fennel_clustering(
)
return torch.as_tensor(clusters)


# pylint: disable=too-many-branches
@numba.njit
def _fennel_clustering(
@@ -148,18 +149,18 @@ def _fennel_clustering(
graph: input graph
num_clusters: target number of clusters
-load_limit: maximum cluster size is
+load_limit: maximum cluster size is
``load_limit * graph.num_nodes / num_clusters`` (default: ``1.1``)
alpha: :math:`\alpha` value for the algorithm (default as suggested in [#fennel]_)
gamma: :math:`\gamma` value for the algorithm (default: 1.5)
randomise_order: if ``True``, randomise order, else use breadth-first-search order.
clusters: input clustering to refine (optional)
num_iters: number of cluster assignment iterations (default: ``1``)
Returns:
@@ -240,8 +241,11 @@ def _fennel_clustering(
progress.close(pbar)

return clusters


# pylint: enable=too-many-branches


def louvain_clustering(graph: TGraph, *args, **kwargs):
r"""
Implements clustering using the Louvain [#l]_ algorithm for modularity optimisation
@@ -300,7 +304,7 @@ def metis_clustering(graph: TGraph, num_clusters):


def spread_clustering(graph, num_clusters, max_degree_init=True):
""" TODO: docstring for spread_clustering. """
"""TODO: docstring for spread_clustering."""

clusters = torch.full((graph.num_nodes,), -1, dtype=torch.long, device=graph.device)
if max_degree_init:
@@ -375,14 +379,22 @@ def spread_clustering(graph, num_clusters, max_degree_init=True):


def hierarchical_aglomerative_clustering(
-    graph, method=spread_clustering, levels=None, branch_factors=None
+    graph,
+    method=spread_clustering,
+    levels: int | None = None,
+    branch_factors: int | list | None = None,
):
""" TODO: docstring for hierarchical_aglomerative_clustering. """
"""TODO: docstring for hierarchical_aglomerative_clustering."""

if branch_factors is None:
if levels is None:
raise ValueError("both levels and branch_factors are None")

branch_factors = [graph.num_nodes ** (1 / (levels + 1)) for _ in range(levels)]
else:
+        if not isinstance(branch_factors, Iterable):
+            if levels is None:
+                raise ValueError("branch_factors is not Iterable and levels is None")
+            branch_factors = [branch_factors] * (levels)
else:
if levels is None:
@@ -400,7 +412,8 @@ def hierarchical_aglomerative_clustering(


class Partition(Sequence):
""" TODO: docstring for Partition. """
"""TODO: docstring for Partition."""

def __init__(self, partition_tensor):
partition_tensor = torch.as_tensor(partition_tensor)
counts = torch.bincount(partition_tensor)
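A note on the new hierarchical_aglomerative_clustering signature above: when branch_factors is omitted, every level defaults to num_nodes ** (1 / (levels + 1)), so the cluster count shrinks geometrically, and the new isinstance check broadcasts a single number across all levels. A quick sketch of that arithmetic (sizes are illustrative):

# Default branch factors, as computed in the diff above.
num_nodes, levels = 10_000, 3
branch_factors = [num_nodes ** (1 / (levels + 1)) for _ in range(levels)]
print(branch_factors)  # [10.0, 10.0, 10.0]: each level coarsens the graph ~10x

# A scalar is broadcast by the new validation branch.
branch_factors = [20] * levels  # equivalent to passing branch_factors=20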
9 changes: 3 additions & 6 deletions l2gv2/datasets/_base.py
@@ -3,7 +3,6 @@
"""

import datetime
-from typing import Union
from pathlib import Path

from tqdm import tqdm
@@ -198,7 +197,7 @@ def get_edge_list(

def get_networkx(
self, temp: bool = True
-) -> Union[nx.Graph, dict[datetime.datetime, nx.Graph]]:
+) -> nx.Graph | dict[datetime.datetime, nx.Graph]:
"""Returns networkx.DiGraph representation
Args:
@@ -223,7 +222,7 @@ def get_networkx(

def get_edge_index(
self, temp: bool = True
-) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
+) -> torch.Tensor | dict[str, torch.Tensor]:
"""Returns edge index as torch tensors
Args:
@@ -251,9 +250,7 @@ def get_edge_index(

def get_tgeometric(
self, temp: bool = True
-) -> Union[
-    torch_geometric.data.Data, dict[datetime.datetime, torch_geometric.data.Data]
-]:
+) -> torch_geometric.data.Data | dict[datetime.datetime, torch_geometric.data.Data]:
"""Returns torch_geometric representation
Args:
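The return-type rewrites in this file are the PR's core pattern: typing.Union gives way to the PEP 604 builtin X | Y syntax, which is evaluated at runtime only on Python 3.10+ (older interpreters need `from __future__ import annotations` to defer evaluation). A minimal before/after, adapted from a signature in this diff (the method is turned into a free function for the sketch):

import datetime
import networkx as nx

# Before: needed the `from typing import Union` that this diff deletes.
#   def get_networkx(self, temp: bool = True) -> Union[nx.Graph, dict[datetime.datetime, nx.Graph]]: ...

# After (Python 3.10+):
def get_networkx(temp: bool = True) -> nx.Graph | dict[datetime.datetime, nx.Graph]:
    ...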
@@ -248,4 +248,4 @@ Western Sahara,EH,ESH,732,ISO 3166-2:EH,Africa,Northern Africa,,002,015,
Yemen,YE,YEM,887,ISO 3166-2:YE,Asia,Western Asia,,142,145,
Zambia,ZM,ZMB,894,ISO 3166-2:ZM,Africa,Sub-Saharan Africa,Eastern Africa,002,202,014
Zimbabwe,ZW,ZWE,716,ISO 3166-2:ZW,Africa,Sub-Saharan Africa,Eastern Africa,002,202,014
-NA,ZZ,NA,0,NA,NA,NA,NA,0,0,0
+NA,ZZ,NA,0,NA,NA,NA,NA,0,0,0
2 changes: 1 addition & 1 deletion l2gv2/datasets/data/nas/country_codes.csv
@@ -248,4 +248,4 @@ Western Sahara,EH,ESH,732,ISO 3166-2:EH,Africa,Northern Africa,,002,015,
Yemen,YE,YEM,887,ISO 3166-2:YE,Asia,Western Asia,,142,145,
Zambia,ZM,ZMB,894,ISO 3166-2:ZM,Africa,Sub-Saharan Africa,Eastern Africa,002,202,014
Zimbabwe,ZW,ZWE,716,ISO 3166-2:ZW,Africa,Sub-Saharan Africa,Eastern Africa,002,202,014
-NA,ZZ,NA,0,NA,NA,NA,NA,0,0,0
+NA,ZZ,NA,0,NA,NA,NA,NA,0,0,0
5 changes: 4 additions & 1 deletion l2gv2/embedding/dgi/__init__.py
@@ -1,3 +1,6 @@
""" Imports for the DGI. """
"""Imports for the DGI."""

from .models import DGI
from .utils.loss import DGILoss

__all__ = ["DGI", "DGILoss"]
3 changes: 2 additions & 1 deletion l2gv2/embedding/dgi/execute.py
@@ -1,4 +1,5 @@
""" TODO: module docstring for dgi/execute.py. """
"""TODO: module docstring for dgi/execute.py."""

import argparse
import torch
from torch import nn
5 changes: 4 additions & 1 deletion l2gv2/embedding/dgi/layers/__init__.py
@@ -1,4 +1,7 @@
""" TODO: module docstring for dgi/layers/__init__.py. """
"""TODO: module docstring for dgi/layers/__init__.py."""

from .gcn import GCN
from .readout import AvgReadout
from .discriminator import Discriminator

__all__ = ["GCN", "AvgReadout", "Discriminator"]
10 changes: 6 additions & 4 deletions l2gv2/embedding/dgi/layers/discriminator.py
@@ -1,25 +1,27 @@
""" TODO: module docstring for dgi/layers/discriminator.py. """
"""TODO: module docstring for dgi/layers/discriminator.py."""

import torch
from torch import nn


class Discriminator(nn.Module):
""" TODO: class docstring for Discriminator. """
"""TODO: class docstring for Discriminator."""

def __init__(self, n_h):
super().__init__()
self.f_k = nn.Bilinear(n_h, n_h, 1)
self.reset_parameters()

def reset_parameters(self):
""" TODO: method docstring for Discriminator.reset_parameters. """
"""TODO: method docstring for Discriminator.reset_parameters."""
for m in self.modules():
if isinstance(m, nn.Bilinear):
torch.nn.init.xavier_uniform_(m.weight.data)
if m.bias is not None:
m.bias.data.fill_(0.0)

def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
""" TODO: method docstring for Discriminator.forward. """
"""TODO: method docstring for Discriminator.forward."""
c_x = torch.unsqueeze(c, 0)
c_x = c_x.expand_as(h_pl)

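The tail of Discriminator.forward is truncated in this diff, so the scoring step below is a sketch of the standard DGI bilinear discriminator assembled from the visible pieces (nn.Bilinear(n_h, n_h, 1) and the unsqueeze/expand_as of the summary c), not necessarily this repository's exact code:

import torch
from torch import nn

n_h, n_nodes = 64, 10                        # illustrative sizes
f_k = nn.Bilinear(n_h, n_h, 1)               # as in Discriminator.__init__

c = torch.randn(n_h)                         # graph summary from the readout
h_pl = torch.randn(n_nodes, n_h)             # real ("positive") node embeddings
h_mi = torch.randn(n_nodes, n_h)             # corrupted ("negative") node embeddings

c_x = torch.unsqueeze(c, 0).expand_as(h_pl)  # broadcast the summary to every node
sc_pl = f_k(h_pl, c_x).squeeze(-1)           # one logit per positive node
sc_mi = f_k(h_mi, c_x).squeeze(-1)           # one logit per negative node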
10 changes: 6 additions & 4 deletions l2gv2/embedding/dgi/layers/gcn.py
@@ -1,18 +1,20 @@
""" TODO: module docstring for dgi/layers/gcn.py. """
"""TODO: module docstring for dgi/layers/gcn.py."""

from torch import nn
import torch_geometric.nn as tg_nn


class GCN(nn.Module):
""" TODO: class docstring for GCN. """
"""TODO: class docstring for GCN."""

def __init__(self, in_ft, out_ft, act, bias=True):
super().__init__()
self.conv = tg_nn.GCNConv(in_channels=in_ft, out_channels=out_ft, bias=bias)
self.act = nn.PReLU() if act == "prelu" else act
self.reset_parameters()

def reset_parameters(self):
""" TODO: method docstring for GCN.reset_parameters. """
"""TODO: method docstring for GCN.reset_parameters."""
self.conv.reset_parameters()
if hasattr(self.act, "reset_parameters"):
self.act.reset_parameters()
@@ -21,7 +23,7 @@ def reset_parameters(self):

# Shape of seq: (batch, nodes, features)
def forward(self, seq, adj):
""" TODO: method docstring for GCN.forward. """
"""TODO: method docstring for GCN.forward."""
out = self.conv(seq, adj)

return self.act(out)
4 changes: 3 additions & 1 deletion l2gv2/embedding/dgi/layers/readout.py
@@ -3,19 +3,21 @@
import torch
from torch import nn


# TODO: fix too-few-public-methods for embedding/dgi/layers/readout.py AvgReadout
# pylint: disable=too-few-public-methods
# Applies an average on seq, of shape (batch, nodes, features)
# While taking into account the masking of msk
class AvgReadout(nn.Module):
"""TODO: class docstring for AvgReadout."""


def forward(self, seq, msk):
"""TODO: method docstring for AvgReadout.forward."""
if msk is None:
return torch.mean(seq, 0)

msk = torch.unsqueeze(msk, -1)
return torch.sum(seq * msk, 0) / torch.sum(msk)


# pylint: enable=too-few-public-methods
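AvgReadout.forward appears in full above, so a tiny numeric check is enough to show what the mask does: masked nodes are zeroed out of the sum and excluded from the denominator.

import torch

seq = torch.tensor([[1.0, 2.0],
                    [3.0, 4.0],
                    [5.0, 6.0]])         # (nodes, features) in this toy case
msk = torch.tensor([1.0, 1.0, 0.0])      # drop the third node

msk = torch.unsqueeze(msk, -1)           # (nodes, 1), broadcasts over features
out = torch.sum(seq * msk, 0) / torch.sum(msk)
print(out)                               # tensor([2., 3.]): mean of the two kept rows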
5 changes: 4 additions & 1 deletion l2gv2/embedding/dgi/models/__init__.py
@@ -1,3 +1,6 @@
""" TODO: module docstring for dgi/models/__init__.py. """
"""TODO: module docstring for dgi/models/__init__.py."""

from .dgi import DGI
from .logreg import LogReg

__all__ = ["DGI", "LogReg"]
12 changes: 7 additions & 5 deletions l2gv2/embedding/dgi/models/dgi.py
@@ -1,10 +1,12 @@
""" TODO: module docstring for dgi/models/dgi.py. """
"""TODO: module docstring for dgi/models/dgi.py."""

from torch import nn
from ..layers import GCN, AvgReadout, Discriminator


class DGI(nn.Module):
""" TODO: class docstring for DGI. """
"""TODO: class docstring for DGI."""

def __init__(self, n_in, n_h, activation="prelu"):
super().__init__()
self.gcn = GCN(n_in, n_h, activation)
@@ -15,13 +17,13 @@ def __init__(self, n_in, n_h, activation="prelu"):
self.disc = Discriminator(n_h)

def reset_parameters(self):
""" TODO: method docstring for DGI.reset_parameters. """
"""TODO: method docstring for DGI.reset_parameters."""
for m in self.children():
if hasattr(m, "reset_parameters"):
m.reset_parameters()

def forward(self, seq1, seq2, adj, msk, samp_bias1, samp_bias2):
""" TODO: method docstring for DGI.forward. """
"""TODO: method docstring for DGI.forward."""
h_1 = self.gcn(seq1, adj)

c = self.read(h_1, msk)
@@ -35,7 +37,7 @@ def forward(self, seq1, seq2, adj, msk, samp_bias1, samp_bias2):

# Detach the return variables
def embed(self, data):
""" TODO: method docstring for DGI.embed. """
"""TODO: method docstring for DGI.embed."""
h_1 = self.gcn(data.x, data.edge_index)

return h_1.detach()
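Putting the DGI pieces together, a hypothetical training step could look as follows. Only DGI(n_in, n_h, activation="prelu") and forward(seq1, seq2, adj, msk, samp_bias1, samp_bias2) come from the code above; the corruption-by-permutation trick is standard DGI, the logit shape is assumed, and plain binary cross-entropy stands in for DGILoss, whose interface this diff does not show:

import torch
from l2gv2.embedding.dgi import DGI          # exported by the new __init__ above

n_nodes, n_in, n_h = 100, 32, 64             # illustrative sizes
model = DGI(n_in, n_h)

x = torch.randn(n_nodes, n_in)               # real node features
edge_index = torch.randint(0, n_nodes, (2, 400))  # illustrative random edges
x_corrupt = x[torch.randperm(n_nodes)]       # negatives: row-shuffled features

# Assumption: forward returns one logit per positive and per negative node.
logits = model(x, x_corrupt, edge_index, None, None, None)
labels = torch.cat([torch.ones(n_nodes), torch.zeros(n_nodes)])
loss = torch.nn.functional.binary_cross_entropy_with_logits(logits.view(-1), labels)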
10 changes: 6 additions & 4 deletions l2gv2/embedding/dgi/models/logreg.py
@@ -1,10 +1,12 @@
""" TODO: module docstring for dgi/models/logreg.py. """
"""TODO: module docstring for dgi/models/logreg.py."""

import torch
from torch import nn


class LogReg(nn.Module):
""" TODO: class docstring for LogReg. """
"""TODO: class docstring for LogReg."""

def __init__(self, ft_in, nb_classes):
super().__init__()
self.fc = nn.Linear(ft_in, nb_classes)
@@ -13,13 +15,13 @@ def __init__(self, ft_in, nb_classes):
self.weights_init(m)

def weights_init(self, m):
""" TODO: method docstring for LogReg.weights_init. """
"""TODO: method docstring for LogReg.weights_init."""
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight.data)
if m.bias is not None:
m.bias.data.fill_(0.0)

def forward(self, seq):
""" TODO: method docstring for LogReg.forward """
"""TODO: method docstring for LogReg.forward"""
ret = self.fc(seq)
return ret
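LogReg is the usual linear probe for evaluating frozen embeddings. Its constructor arguments (ft_in, nb_classes) are visible above; every other value in this sketch is illustrative:

import torch
from l2gv2.embedding.dgi.models import LogReg

emb = torch.randn(100, 64)                   # frozen embeddings, e.g. from DGI.embed
y = torch.randint(0, 7, (100,))              # labels for 7 classes

probe = LogReg(ft_in=64, nb_classes=7)
opt = torch.optim.Adam(probe.parameters(), lr=0.01)

for _ in range(100):
    opt.zero_grad()
    loss = torch.nn.functional.cross_entropy(probe(emb), y)
    loss.backward()
    opt.step()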
5 changes: 4 additions & 1 deletion l2gv2/embedding/dgi/utils/__init__.py
@@ -1,2 +1,5 @@
""" TODO: module docstring for dgi/utils/__init__.py. """
"""TODO: module docstring for dgi/utils/__init__.py."""

from .loss import DGILoss

__all__ = ["DGILoss"]