Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Summary for PyG datasets #5438

Merged
merged 12 commits into from
Sep 16, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Added `print_summary` method for the `torch_geometric.data.Dataset` interface ([#5438](https://github.com/pyg-team/pytorch_geometric/pull/5438))
- Added `sampler` support to `LightningDataModule` ([#5456](https://github.com/pyg-team/pytorch_geometric/pull/5456), [#5457](https://github.com/pyg-team/pytorch_geometric/pull/5457))
- Added official splits to `MalNetTiny` dataset ([#5078](https://github.com/pyg-team/pytorch_geometric/pull/5078))
- Added `IndexToMask` and `MaskToIndex` transforms ([#5375](https://github.com/pyg-team/pytorch_geometric/pull/5375), [#5455](https://github.com/pyg-team/pytorch_geometric/pull/5455))
Expand Down
45 changes: 45 additions & 0 deletions test/data/test_summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch

from torch_geometric.data.summary import Summary
from torch_geometric.datasets import FakeDataset, FakeHeteroDataset
from torch_geometric.testing import withPackage


def test_summary():
dataset = FakeDataset(num_graphs=10)
num_nodes = torch.Tensor([data.num_nodes for data in dataset])
num_edges = torch.Tensor([data.num_edges for data in dataset])

summary = dataset.get_summary()

assert summary.name == 'FakeDataset'
assert summary.num_graphs == 10

assert summary.num_nodes.mean == num_nodes.mean().item()
assert summary.num_nodes.std == num_nodes.std().item()
assert summary.num_nodes.min == num_nodes.min().item()
assert summary.num_nodes.quantile25 == num_nodes.quantile(0.25).item()
assert summary.num_nodes.median == num_nodes.median().item()
assert summary.num_nodes.quantile75 == num_nodes.quantile(0.75).item()
assert summary.num_nodes.max == num_nodes.max().item()

assert summary.num_edges.mean == num_edges.mean().item()
assert summary.num_edges.std == num_edges.std().item()
assert summary.num_edges.min == num_edges.min().item()
assert summary.num_edges.quantile25 == num_edges.quantile(0.25).item()
assert summary.num_edges.median == num_edges.median().item()
assert summary.num_edges.quantile75 == num_edges.quantile(0.75).item()
assert summary.num_edges.max == num_edges.max().item()


@withPackage('tabulate')
def test_hetero_summary():
dataset1 = FakeHeteroDataset(num_graphs=10)
summary1 = Summary.from_dataset(dataset1)

dataset2 = [data.to_homogeneous() for data in dataset1]
summary2 = Summary.from_dataset(dataset2)
summary2.name = 'FakeHeteroDataset'

assert summary1 == summary2
assert str(summary1) == str(summary2)
9 changes: 9 additions & 0 deletions torch_geometric/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,15 @@ def __repr__(self) -> str:
arg_repr = str(len(self)) if len(self) > 1 else ''
return f'{self.__class__.__name__}({arg_repr})'

def get_summary(self):
r"""Collects summary statistics for the dataset."""
from torch_geometric.data.summary import Summary
return Summary.from_dataset(self)

def print_summary(self):
r"""Prints summary statistics of the dataset to the console."""
return str(self.get_summary())


def to_list(value: Any) -> Sequence:
if isinstance(value, Sequence) and not isinstance(value, str):
Expand Down
90 changes: 90 additions & 0 deletions torch_geometric/data/summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from dataclasses import dataclass
from typing import List, Optional, Union

import torch
from tqdm import tqdm

from torch_geometric.data import Dataset


@dataclass
class Stats:
mean: float
std: float
min: float
quantile25: float
median: float
quantile75: float
max: float

@classmethod
def from_data(cls, data: Union[List[int], List[float], torch.Tensor]):
if not isinstance(data, torch.Tensor):
data = torch.tensor(data)
data = data.to(torch.float)

return cls(
mean=data.mean().item(),
std=data.std().item(),
min=data.min().item(),
quantile25=data.quantile(0.25).item(),
median=data.median().item(),
quantile75=data.quantile(0.75).item(),
max=data.max().item(),
)


@dataclass(repr=False)
class Summary:
name: str
num_graphs: int
num_nodes: Stats
num_edges: Stats

@classmethod
def from_dataset(
cls,
dataset: Dataset,
progress_bar: Optional[bool] = None,
):
r"""Creates a summary of a :class:`~torch_geometric.data.Dataset`
object.

Args:
dataset (Dataset): The dataset.
progress_bar (bool, optional). If set to :obj:`True`, will show a
progress bar during stats computation. If set to :obj:`None`,
will automatically decide whether to show a progress bar based
on dataset size. (default: :obj:`None`)
"""
if progress_bar is None:
progress_bar = len(dataset) >= 10000

if progress_bar:
dataset = tqdm(dataset)

num_nodes_list, num_edges_list = [], []
for data in dataset:
num_nodes_list.append(data.num_nodes)
num_edges_list.append(data.num_edges)

return cls(
name=dataset.__class__.__name__,
num_graphs=len(dataset),
num_nodes=Stats.from_data(num_nodes_list),
num_edges=Stats.from_data(num_edges_list),
)

def __repr__(self) -> str:
from tabulate import tabulate

prefix = f'{self.name} (#graphs={self.num_graphs}):\n'

content = [['', '#nodes', '#edges']]
stats = [self.num_nodes, self.num_edges]
for field in Stats.__dataclass_fields__:
row = [field] + [f'{getattr(s, field):.1f}' for s in stats]
content.append(row)
body = tabulate(content, headers='firstrow', tablefmt='psql')

return prefix + body