Skip to content

Commit

Permalink
datasets: add module and initial tests
Browse files Browse the repository at this point in the history
  • Loading branch information
abhidg committed Nov 12, 2024
1 parent 87d6427 commit 5857bac
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 45 deletions.
6 changes: 6 additions & 0 deletions l2gv2/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
"""
Datasets loader for l2gv2
"""

from ._base import DataLoader

__all__ = ["DataLoader"]
135 changes: 90 additions & 45 deletions l2gv2/datasets/_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
Datasets loader for l2gv2, main file
"""

import datetime
from pathlib import Path

from tqdm import tqdm
Expand All @@ -9,7 +14,6 @@

DATA_PATH = Path(__file__).parent / "data"
DATASETS = [x.stem for x in DATA_PATH.glob("*") if x.is_dir()]
FORMATS = ["networkx", "tgeometric", "raphtory", "polars", "edge_list"]
EDGE_COLUMNS = {"source", "dest"} # required columns

EdgeList = list[tuple[str, str]]
Expand All @@ -20,25 +24,30 @@ def is_graph_dataset(p: Path) -> bool:
return (p / (p.stem + "_edges.parquet")).exists()


class DataLoader:
class DataLoader: # pylint: disable=too-many-instance-attributes
"""Take a dataframe representing a (temporal) graph and provide
methods for loading the data in different formats.
"""

def __init__(self, dset: str):
def __init__(self, dset: str | Path, timestamp_fmt: str = "%Y-%m-%d"):
if is_graph_dataset(Path(dset)):
self.path = Path(dset)
elif is_graph_dataset(DATA_PATH / dset):
self.path = DATA_PATH / dset
else:
raise FileNotFoundError(f"Dataset not found: {dset}")
raise ValueError(f"Dataset either invalid or not found: {dset}")

self.timestamp_fmt = timestamp_fmt
self.paths = {"edges": self.path / (self.path.stem + "_edges.parquet")}
if (nodes_path := self.path / (self.path.stem + "_nodes.parquet")).exists():
self.paths["nodes"] = nodes_path

self._load_files()

def timestamp_from_string(self, ts: str) -> datetime.datetime:
"Returns timestamp from string using `timestamp_fmt`"
return datetime.datetime.strptime(ts, self.timestamp_fmt)

def _load_files(self):
"Loads dataset into memory"

Expand All @@ -49,6 +58,11 @@ def _load_files(self):
self.temporal = "timestamp" in self.edges.columns
if not self.temporal:
self.edges = self.edges.with_columns(pl.lit(0).alias("timestamp"))
else: # convert timestamp to datetime format
self.edges = self.edges.with_columns(
pl.col("timestamp").str.to_datetime(self.timestamp_fmt)
)

self.datelist = self.edges.select("timestamp").to_series().unique()

# Process nodes
Expand All @@ -57,7 +71,17 @@ def _load_files(self):
assert (
"nodes" in self.nodes.columns
), "Required node columns not found: 'nodes'"
if self.temporal:
if "timestamp" not in self.nodes.columns:
raise ValueError(
"Nodes dataset missing 'timestamp' column, required"
" when edges dataset has 'timestamp'"
)
self.nodes = self.nodes.with_columns(
pl.col("timestamp").str.to_datetime(self.timestamp_fmt)
)
else:
# build nodes from edges dataset
self.nodes = (
pl.concat(
[
Expand All @@ -83,31 +107,52 @@ def _load_files(self):
]

def get_dates(self) -> list[str]:
"Returns list of dates"
return self.datelist.to_list()

def get_edges(self) -> pl.DataFrame:
"Returns edges as a polars DataFrame"
return self.edges

def get_nodes(self, ts: str | None = None) -> pl.DataFrame:
"""Returns node data as a polars DataFrame
Args:
ts (str, optional): if specified, only return nodes with this timestamp
Returns:
polars.DataFrame
"""
if ts is None:
return self.nodes
else:
return self.nodes.filter(pl.col("timestamp") == ts)
return self.nodes.filter(pl.col("timestamp") == self.timestamp_from_string(ts))

def get_node_list(self, ts: str | None = None) -> list[str]:
"""Returns node list
Args:
ts (str, optional): if specified, only return nodes with this timestamp
Returns:
list of str
"""
nodes = self.nodes
if ts is not None:
nodes = nodes.filter(pl.col("timestamp") == ts)
nodes = nodes.filter(pl.col("timestamp") == self.timestamp_from_string(ts))
return nodes.select("nodes").unique(maintain_order=True).to_series().to_list()

def get_node_features(self) -> list[str]:
"Returns node features as a list of strings"
return self.node_features

def get_edge_features(self) -> list[str]:
"Returns edge features as a list of strings"
return self.edge_features

def get_graph(self) -> rp.Graph:
g = rp.Graph()
def get_graph(self) -> rp.Graph: # pylint: disable=no-member
"Returns a raphtory.Graph representation"
g = rp.Graph() # pylint: disable=no-member

g.load_edges_from_pandas(
df=self.edges.to_pandas(),
time="timestamp",
Expand All @@ -121,9 +166,17 @@ def get_graph(self) -> rp.Graph:
id="nodes",
properties=self.node_features,
)

return g

def get_edge_list(self, temp: bool = True) -> EdgeList | dict[str, EdgeList]:
def get_edge_list(self, temp: bool = True) -> EdgeList | dict[datetime.datetime, EdgeList]:
"""Returns edge list
Args:
temp (bool, optional, default=True): If true, then returns a dictionary of
timestamps to edge lists (list of string tuples), if false, returns
edge list for the entire graph
"""
if self.temporal and temp:
edge_list = {}
for d in tqdm(self.datelist):
Expand All @@ -138,7 +191,15 @@ def get_edge_list(self, temp: bool = True) -> EdgeList | dict[str, EdgeList]:
edge_list = [tuple(x) for x in edges]
return edge_list

def get_networkx(self, temp: bool = True) -> nx.Graph | dict[str, nx.Graph]:
def get_networkx(self, temp: bool = True) -> nx.DiGraph | dict[datetime.datetime, nx.DiGraph]:
"""Returns networkx.DiGraph representation
Args:
temp (bool, optional, default=True): If true, then returns a dictionary of
timestamps to networkx digraphs, if false, returns a networkx digraph
"""


if self.temporal and temp:
nx_graphs = {}
for d in tqdm(self.datelist):
Expand All @@ -148,16 +209,23 @@ def get_networkx(self, temp: bool = True) -> nx.Graph | dict[str, nx.Graph]:
.to_numpy()
)
edge_list = [tuple(x) for x in edges]
nx_graphs[d] = nx.from_edgelist(edge_list)
nx_graphs[d] = nx.from_edgelist(edge_list, create_using=nx.DiGraph)
else:
edges = self.edges.select("source", "dest").unique().to_numpy()
edge_list = [tuple(x) for x in edges]
nx_graphs = nx.from_edgelist(edge_list)
nx_graphs = nx.from_edgelist(edge_list, create_using=nx.DiGraph)
return nx_graphs

def get_edge_index(
self, temp: bool = True
) -> torch.Tensor | dict[str, torch.Tensor]:
"""Returns edge index as torch tensors
Args:
temp (bool, optional, default=True): If true, then returns a dictionary of
timestamps to torch tensors (list of string tuples), if false, returns
a torch tensor.
"""
if self.temporal and temp:
edge_index = {}
for d in tqdm(self.datelist):
Expand All @@ -179,8 +247,15 @@ def get_edge_index(
def get_tgeometric(
self, temp: bool = True
) -> torch_geometric.data.Data | dict[str, torch_geometric.data.Data]:
"""Returns torch_geometric representation
Args:
temp (bool, optional, default=True): If true, then returns a dictionary of
timestamps to torch_geometric representations, if false, returns
a torch_geometric representation.
"""
nodes = self.nodes.select("nodes").unique().to_numpy()
features = self.nodes.select([c for c in self.node_features]).to_numpy()
features = self.nodes.select(self.node_features).to_numpy()
if self.temporal and temp:
tg_graphs = {}
for d in tqdm(self.datelist):
Expand All @@ -204,34 +279,4 @@ def get_tgeometric(
tg_graphs.x = torch.from_numpy(features).float()
return tg_graphs

def summary(self):
pass


# TODO: uncomment and integrate into DataLoader?
# def summary(data: nx.Graph | list[nx.Graph]):
# if not isinstance(data, list):
# data = [data]
#
# if isinstance(data[0], nx.Graph):
# number_of_nodes = []
# number_of_edges = []
# avg_degree = []
# for G in data:
# number_of_nodes.append(G.number_of_nodes())
# number_of_edges.append(G.number_of_edges())
#
# avg_degree.append(sum(dict(G.degree()).values()) / G.number_of_nodes())
# avg_n = int(np.mean(number_of_nodes))
# avg_e = int(np.mean(number_of_edges))
# avg_d = int(np.mean(avg_degree))
#
# print(
# "The dataset consists of {} graphs.\n Average number of nodes : {} \n Average number of edges : {} \n Average degree: {} ".format(
# len(data), avg_n, avg_e, avg_d
# )
# )
# else:
# print("Converting into networkx graphs")
# data = [to_networkx(d, to_undirected=False) for d in tqdm(data)]
# summary(data)
# TODO: integrate summary() into DataLoader
Binary file added tests/datasets/invalid/invalid_nodes.parquet
Binary file not shown.
Binary file added tests/datasets/social/social_edges.parquet
Binary file not shown.
Binary file added tests/datasets/social/social_nodes.parquet
Binary file not shown.
87 changes: 87 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1 +1,88 @@
"""
Test functions for l2gv2.datasets
"""

from pathlib import Path

import pytest

from l2gv2.datasets import DataLoader

TEST_DATASETS = Path(__file__).parent / "datasets"
SOCIAL_NODES = [
"amy",
"anil",
"charlie",
"john",
"maria",
"peter",
]
SOCIAL_EDGES = [
("amy", "charlie"),
("amy", "peter"),
("anil", "maria"),
("john", "amy"),
("maria", "peter"),
("peter", "anil"),
]

# disable missing-function-docstring, redefined-outer-name (pytest fixtures)
# pylint: disable=C0116,W0621

def test_dataloader_invalid_dataset():
with pytest.raises(ValueError):
DataLoader(TEST_DATASETS / "invalid")


@pytest.fixture
def social_dataset():
return DataLoader(TEST_DATASETS / "social")


def test_is_temporal(social_dataset):
assert social_dataset.temporal is True


def test_features(social_dataset):
assert social_dataset.node_features == ["height_cm"]
assert social_dataset.edge_features == ["distance_house_km"]


def test_get_dates(social_dataset):
assert sorted(dt.date().isoformat() for dt in social_dataset.get_dates()) == [
"2024-05-02",
"2024-05-03",
]


def test_get_edges(social_dataset):
edges = social_dataset.get_edges()
source_nodes = edges.select("source").to_series().to_list()
target_nodes = edges.select("dest").to_series().to_list()
assert sorted(zip(source_nodes, target_nodes)) == SOCIAL_EDGES


def test_get_nodes(social_dataset):
nodes = social_dataset.get_nodes()
assert sorted(nodes.select("nodes").to_series().to_list()) == SOCIAL_NODES


def test_get_graph(social_dataset):
rpgraph = social_dataset.get_graph() # raphtory graph format
assert sorted(rpgraph.nodes.to_df().name.to_list()) == SOCIAL_NODES
edge_df = rpgraph.edges.to_df()[["src", "dst"]]
assert sorted(edge_df.itertuples(index=False)) == SOCIAL_EDGES


def test_get_edge_list(social_dataset):
edges = sorted(social_dataset.get_edge_list(temp=False))
assert edges == SOCIAL_EDGES


def test_get_networkx(social_dataset):
nxgraph = social_dataset.get_networkx(temp=False)
assert sorted(nxgraph.nodes) == SOCIAL_NODES
assert sorted(nxgraph.edges) == SOCIAL_EDGES


# TODO: add node and edge testing for torch_geometric

0 comments on commit 5857bac

Please sign in to comment.