Changes in Temporal Data to support a new Temporal Data Loader #3985

Merged on Mar 5, 2022 (48 commits)
Changes from 37 commits
Commits
55f73e0
Refactor TemporalData class to inherit from BaseData
otaviocx Jan 7, 2022
4a81ed7
Merge branch 'master' of https://github.com/pyg-team/pytorch_geometri…
otaviocx Jan 15, 2022
2656004
Fixes to get TemporalData working
otaviocx Jan 15, 2022
1ec6216
Small fixes in __delitem__ of TemporalData
otaviocx Jan 15, 2022
30f6457
Add batch, __cat_dim__ and __inc__ to TemporalData
otaviocx Jan 16, 2022
eb041d5
Add Docs to TemporalData
otaviocx Jan 16, 2022
28de06d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2022
5bcd8bb
Lint fixes
otaviocx Jan 16, 2022
fe87992
Merge branch 'refactor-temporal-data' of https://github.com/otaviocx/…
otaviocx Jan 16, 2022
be4cf94
Add removed method TemporalData.seq_batches
otaviocx Jan 16, 2022
5d7b265
Update torch_geometric/data/temporal.py
otaviocx Jan 17, 2022
a75ccab
Update torch_geometric/data/temporal.py
otaviocx Jan 17, 2022
218d783
Changes requested in review
otaviocx Jan 17, 2022
281fec7
Removing trailing whitespace
otaviocx Jan 17, 2022
9d40b9e
fix doc + some inheritance issues
rusty1s Jan 18, 2022
75416fc
fix iter
rusty1s Jan 18, 2022
785023f
Add the new TemporalDataset class and refactor and
otaviocx Feb 2, 2022
b512d48
Merge branch 'pyg-master' into temporal-dataset-and-data-loader
otaviocx Feb 2, 2022
bf3816e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2022
76a0bda
Lint fixes
otaviocx Feb 2, 2022
da08a16
Lint fixes
otaviocx Feb 2, 2022
1cc55f0
Fix tests
otaviocx Feb 2, 2022
0e989ed
Add TemporalDataLoader
otaviocx Feb 3, 2022
e56b7ca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 3, 2022
9c459ba
Merge branch 'master' into temporal-dataset-and-data-loader
otaviocx Feb 3, 2022
28d81e9
Fix imports
otaviocx Feb 3, 2022
7b8e5ec
Merge branch 'temporal-dataset-and-data-loader' of https://github.com…
otaviocx Feb 3, 2022
f60fce5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 3, 2022
cfdc860
Merge branch 'master' into temporal-dataset-and-data-loader
otaviocx Feb 7, 2022
160b58e
Merge branch 'master' into temporal-dataset-and-data-loader
otaviocx Feb 25, 2022
eaae1d0
Fixes suggested in code review.
otaviocx Feb 26, 2022
8c6cade
Merge branch 'temporal-dataset-and-data-loader' of https://github.com…
otaviocx Feb 26, 2022
c9bdd1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 26, 2022
e40fc20
Lint fix
otaviocx Feb 26, 2022
879f729
Merge branch 'temporal-dataset-and-data-loader' of https://github.com…
otaviocx Feb 26, 2022
2327db7
Update docs
otaviocx Feb 26, 2022
70f6bce
Merge branch 'master' into temporal-dataset-and-data-loader
rusty1s Feb 26, 2022
90481be
Update torch_geometric/data/temporal.py
otaviocx Mar 3, 2022
85d017b
Update torch_geometric/loader/temporal_dataloader.py
otaviocx Mar 3, 2022
7ae93b0
Add __init_ for TemporalDataLoader and update docs.
otaviocx Mar 3, 2022
2548a22
update example
rusty1s Mar 5, 2022
e35650d
update dataloader
rusty1s Mar 5, 2022
3876223
update data
rusty1s Mar 5, 2022
bc9a458
update data (part 2)
rusty1s Mar 5, 2022
e3cff36
update data (part 3)
rusty1s Mar 5, 2022
4689007
bugfix
rusty1s Mar 5, 2022
fd2196a
temporal dataloader test
rusty1s Mar 5, 2022
46bf628
Merge branch 'master' into temporal-dataset-and-data-loader
rusty1s Mar 5, 2022
26 changes: 17 additions & 9 deletions examples/tgn.py
@@ -18,6 +18,7 @@
from torch.nn import Linear

from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.nn.models.tgn import (IdentityMessage, LastAggregator,
LastNeighborLoader)
@@ -26,14 +27,17 @@

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'JODIE')
dataset = JODIEDataset(path, name='wikipedia')
data = dataset[0].to(device)
data = dataset.data

# Ensure to only sample actual destination nodes as negatives.
min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

train_data, val_data, test_data = data.train_val_test_split(
val_ratio=0.15, test_ratio=0.15)

train_loader = TemporalDataLoader(train_data, batch_size=200)
val_loader = TemporalDataLoader(val_data, batch_size=200)
test_loader = TemporalDataLoader(test_data, batch_size=200)

neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)


@@ -103,7 +107,8 @@ def train():
neighbor_loader.reset_state() # Start with an empty graph.

total_loss = 0
for batch in train_data.seq_batches(batch_size=200):
for batch in train_loader:
batch = batch.to(device)
optimizer.zero_grad()

src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
@@ -118,7 +123,8 @@ def train():

# Get updated memory of all nodes involved in the computation.
z, last_update = memory(n_id)
z = gnn(z, last_update, edge_index, data.t[e_id], data.msg[e_id])
z = gnn(z, last_update, edge_index, data.t[e_id].to(device),
data.msg[e_id].to(device))

pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])
neg_out = link_pred(z[assoc[src]], z[assoc[neg_dst]])
@@ -139,15 +145,16 @@ def train():


@torch.no_grad()
def test(inference_data):
def test(inference_data_loader):
memory.eval()
gnn.eval()
link_pred.eval()

torch.manual_seed(12345) # Ensure deterministic sampling across epochs.

aps, aucs = [], []
for batch in inference_data.seq_batches(batch_size=200):
for batch in inference_data_loader:
batch = batch.to(device)
src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

neg_dst = torch.randint(min_dst_idx, max_dst_idx + 1, (src.size(0), ),
@@ -158,7 +165,8 @@ def test(inference_data):
assoc[n_id] = torch.arange(n_id.size(0), device=device)

z, last_update = memory(n_id)
z = gnn(z, last_update, edge_index, data.t[e_id], data.msg[e_id])
z = gnn(z, last_update, edge_index, data.t[e_id].to(device),
data.msg[e_id].to(device))

pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])
neg_out = link_pred(z[assoc[src]], z[assoc[neg_dst]])
@@ -180,7 +188,7 @@ def test(inference_data):
for epoch in range(1, 51):
loss = train()
print(f' Epoch: {epoch:02d}, Loss: {loss:.4f}')
val_ap, val_auc = test(val_data)
test_ap, test_auc = test(test_data)
val_ap, val_auc = test(val_loader)
test_ap, test_auc = test(test_loader)
print(f' Val AP: {val_ap:.4f}, Val AUC: {val_auc:.4f}')
print(f'Test AP: {test_ap:.4f}, Test AUC: {test_auc:.4f}')
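The negative-sampling step in this script draws one random destination per event from the observed destination ID range, with a fixed seed during evaluation. A minimal pure-Python sketch of that idea follows, assuming plain lists in place of tensors; the helper name `sample_negative_dst` and the sample IDs are hypothetical and not part of PyG.

```python
import random


def sample_negative_dst(dst, num_samples, seed=None):
    """Draw negative destination IDs from the observed [min, max] range of dst."""
    rng = random.Random(seed)
    min_dst_idx, max_dst_idx = min(dst), max(dst)
    # randrange excludes the upper bound, hence the +1 (as with torch.randint).
    return [
        rng.randrange(min_dst_idx, max_dst_idx + 1) for _ in range(num_samples)
    ]


dst = [8227, 8228, 8230, 9226, 8229]  # made-up destination node IDs
neg_dst = sample_negative_dst(dst, num_samples=len(dst), seed=12345)
```

Bounding the draw this way only restricts negatives to the destination ID range; an ID inside the range that never occurs as a destination can still be sampled.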
11 changes: 5 additions & 6 deletions torch_geometric/data/batch.py
@@ -6,8 +6,8 @@
import torch
from torch import Tensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.collate import collate
from torch_geometric.data.data import BaseData, Data
from torch_geometric.data.dataset import IndexType
from torch_geometric.data.separate import separate

@@ -54,7 +54,7 @@ class Batch(metaclass=DynamicInheritance):
:obj:`batch`, which maps each node to its respective graph identifier.
"""
@classmethod
def from_data_list(cls, data_list: Union[List[Data], List[HeteroData]],
def from_data_list(cls, data_list: List[BaseData],
follow_batch: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None):
r"""Constructs a :class:`~torch_geometric.data.Batch` object from a
@@ -80,7 +80,7 @@ def from_data_list(cls, data_list: Union[List[Data], List[HeteroData]],

return batch

def get_example(self, idx: int) -> Union[Data, HeteroData]:
def get_example(self, idx: int) -> BaseData:
r"""Gets the :class:`~torch_geometric.data.Data` or
:class:`~torch_geometric.data.HeteroData` object at index :obj:`idx`.
The :class:`~torch_geometric.data.Batch` object must have been created
@@ -103,8 +103,7 @@ def get_example(self, idx: int) -> Union[Data, HeteroData]:

return data

def index_select(self,
idx: IndexType) -> Union[List[Data], List[HeteroData]]:
def index_select(self, idx: IndexType) -> List[BaseData]:
r"""Creates a subset of :class:`~torch_geometric.data.Data` or
:class:`~torch_geometric.data.HeteroData` objects from specified
indices :obj:`idx`.
@@ -152,7 +151,7 @@ def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any:
else:
return self.index_select(idx)

def to_data_list(self) -> Union[List[Data], List[HeteroData]]:
def to_data_list(self) -> List[BaseData]:
r"""Reconstructs the list of :class:`~torch_geometric.data.Data` or
:class:`~torch_geometric.data.HeteroData` objects from the
:class:`~torch_geometric.data.Batch` object.
42 changes: 24 additions & 18 deletions torch_geometric/data/temporal.py
@@ -111,17 +111,21 @@ def __prepare_non_str_idx(idx):
f'{type(idx).__name__}).')
return idx

def __generate_item(self, idx):
data = {}
num_events = self.num_events
for key, item in self._store.items():
if item.size(0) == num_events:
data[key] = item[idx]
return TemporalData(**data)

def __getitem__(self, idx: Any) -> Any:
if isinstance(idx, str):
return self._store[idx]

prepared_idx = self.__prepare_non_str_idx(idx)

data = copy.copy(self)
for key, item in data:
if item.size(0) == self.num_events:
data[key] = item[prepared_idx]
return data
return self.__generate_item(prepared_idx)

def __setitem__(self, key, value):
"""Sets the attribute :obj:`key` to :obj:`value`."""
@@ -133,7 +137,7 @@ def __delitem__(self, idx):

prepared_idx = self.__prepare_non_str_idx(idx)

for key, item in self:
for key, item in self._store.items():
if item.shape[0] == self.num_events:
del item[prepared_idx]

@@ -153,8 +157,11 @@ def __delattr__(self, key: str):
delattr(self._store, key)

def __iter__(self) -> Iterable:
for key, value in self._store.items():
yield key, value
for idx, _ in enumerate(self.src):
yield self.__generate_item(torch.tensor([idx]))

def __len__(self):
return len(self.src)

def __call__(self, *args: List[str]) -> Iterable:
for key, value in self._store.items(*args):
@@ -231,12 +238,7 @@ def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
return 0

def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
if 'batch' in key:
return int(value.max()) + 1
elif key in ['src', 'dst']:
return self.num_nodes
else:
return 0
return 0

def __repr__(self) -> str:
cls = self.__class__.__name__
@@ -247,6 +249,14 @@ def __repr__(self) -> str:

def train_val_test_split(self, val_ratio: float = 0.15,
test_ratio: float = 0.15):
r"""Splits the data into training, validation and test sets
according to time.

Args:
val_ratio (float, optional): The fraction (between 0 and 1) of the
dataset to include in the validation split. (default: :obj:`0.15`)
test_ratio (float, optional): The fraction (between 0 and 1) of the
dataset to include in the test split. (default: :obj:`0.15`)
"""
val_time, test_time = np.quantile(
self.t.cpu().numpy(),
[1. - val_ratio - test_ratio, 1. - test_ratio])
@@ -256,10 +266,6 @@ def train_val_test_split(self, val_ratio: float = 0.15,

return self[:val_idx], self[val_idx:test_idx], self[test_idx:]

def seq_batches(self, batch_size: int):
for start in range(0, self.num_events, batch_size):
yield self[start:start + batch_size]

###########################################################################

def coalesce(self):
Expand Down
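`train_val_test_split` derives two time thresholds from quantiles of the timestamps and slices the events at the matching positions. The sketch below illustrates this in pure Python under stated assumptions: the `quantile` helper is a hypothetical stand-in for `np.quantile` with linear interpolation, and the boundary indices are assumed to be the number of events at or before each threshold, since that part of the function is collapsed in the diff.

```python
def quantile(sorted_values, q):
    """Linearly interpolated quantile of an already sorted list (0 <= q <= 1)."""
    pos = q * (len(sorted_values) - 1)
    lower = int(pos)
    upper = min(lower + 1, len(sorted_values) - 1)
    frac = pos - lower
    return sorted_values[lower] * (1 - frac) + sorted_values[upper] * frac


def split_indices(t, val_ratio=0.15, test_ratio=0.15):
    """Return (val_idx, test_idx) for time-sorted timestamps t."""
    val_time = quantile(t, 1.0 - val_ratio - test_ratio)
    test_time = quantile(t, 1.0 - test_ratio)
    # Assumption: each boundary is the count of events at or before the threshold.
    val_idx = sum(1 for x in t if x <= val_time)
    test_idx = sum(1 for x in t if x <= test_time)
    return val_idx, test_idx


t = list(range(100))  # 100 events with timestamps 0..99
val_idx, test_idx = split_indices(t)
train, val, test = t[:val_idx], t[val_idx:test_idx], t[test_idx:]
```

Because the split follows time rather than random sampling, every validation event occurs after all training events, and every test event after all validation events.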
2 changes: 2 additions & 0 deletions torch_geometric/loader/__init__.py
@@ -9,6 +9,7 @@
from .data_list_loader import DataListLoader
from .dense_data_loader import DenseDataLoader
from .neighbor_sampler import NeighborSampler
from .temporal_dataloader import TemporalDataLoader

__all__ = [
'DataLoader',
@@ -25,6 +26,7 @@
'DataListLoader',
'DenseDataLoader',
'NeighborSampler',
'TemporalDataLoader',
]

classes = __all__
5 changes: 3 additions & 2 deletions torch_geometric/loader/data_list_loader.py
@@ -2,7 +2,8 @@

import torch

from torch_geometric.data import Data, Dataset, HeteroData
from torch_geometric.data import Dataset
from torch_geometric.data.data import BaseData


def collate_fn(data_list):
@@ -30,7 +31,7 @@ class DataListLoader(torch.utils.data.DataLoader):
:class:`torch.utils.data.DataLoader`, such as :obj:`drop_last` or
:obj:`num_workers`.
"""
def __init__(self, dataset: Union[Dataset, List[Data], List[HeteroData]],
def __init__(self, dataset: Union[Dataset, List[BaseData]],
batch_size: int = 1, shuffle: bool = False, **kwargs):
if 'collate_fn' in kwargs:
del kwargs['collate_fn']
7 changes: 4 additions & 3 deletions torch_geometric/loader/dataloader.py
@@ -4,7 +4,8 @@
import torch.utils.data
from torch.utils.data.dataloader import default_collate

from torch_geometric.data import Batch, Data, Dataset, HeteroData
from torch_geometric.data import Batch, Dataset
from torch_geometric.data.data import BaseData


class Collater:
@@ -14,7 +15,7 @@ def __init__(self, follow_batch, exclude_keys):

def __call__(self, batch):
elem = batch[0]
if isinstance(elem, (Data, HeteroData)):
if isinstance(elem, BaseData):
return Batch.from_data_list(batch, self.follow_batch,
self.exclude_keys)
elif isinstance(elem, torch.Tensor):
@@ -59,7 +60,7 @@ class DataLoader(torch.utils.data.DataLoader):
"""
def __init__(
self,
dataset: Union[Dataset, List[Data], List[HeteroData]],
dataset: Union[Dataset, List[BaseData]],
batch_size: int = 1,
shuffle: bool = False,
follow_batch: Optional[List[str]] = None,
Expand Down
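The `Collater` change swaps a tuple of concrete types for a single base-class check, so any container deriving from `BaseData` is routed to `Batch.from_data_list` without touching the collater again. A small sketch of that dispatch pattern follows; the classes here are empty hypothetical stand-ins that only mirror the names in the diff, and `is_batchable` is an illustrative helper.

```python
class BaseData:
    """Hypothetical common base class for all data containers."""


class Data(BaseData):
    pass


class HeteroData(BaseData):
    pass


class TemporalData(BaseData):
    pass


def is_batchable(elem):
    # One base-class check replaces an explicit tuple of concrete types.
    return isinstance(elem, BaseData)


results = [is_batchable(x) for x in (Data(), HeteroData(), TemporalData(), 3.14)]
```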
43 changes: 43 additions & 0 deletions torch_geometric/loader/temporal_dataloader.py
@@ -0,0 +1,43 @@
from torch.utils.data import _utils
from torch.utils.data.dataloader import _BaseDataLoaderIter

from torch_geometric.loader import DataLoader


class _SingleProcessTemporalDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessTemporalDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0

def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset[index]
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data


class TemporalDataLoader(DataLoader):
r"""A data loader which merges events of a
:class:`torch_geometric.data.TemporalData` object into a mini-batch.

Args:
dataset (TemporalData): The :obj:`TemporalData` from which to load
the data.
batch_size (int, optional): How many samples per batch to load.
(default: :obj:`1`)
shuffle (bool, optional): If set to :obj:`True`, the data will be
reshuffled at every epoch. (default: :obj:`False`)
follow_batch (List[str], optional): Creates assignment batch
vectors for each key in the list. (default: :obj:`None`)
exclude_keys (List[str], optional): Will exclude each key in the
list. (default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch.utils.data.DataLoader`.
"""
def __iter__(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessTemporalDataLoaderIter(self)
else:
return super().__iter__()
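In the single-process path above, `_next_data` appears to index the dataset once with the whole list of indices produced by the sampler (`self._dataset[index]`) rather than fetching events one by one and collating them. The sketch below mimics that behaviour in pure Python; `ToyTemporalData` and `iter_batches` are hypothetical names, and plain lists stand in for tensors.

```python
class ToyTemporalData:
    """Hypothetical stand-in for an event store with per-event columns."""
    def __init__(self, **columns):
        self.columns = columns

    def __len__(self):
        return len(self.columns['src'])

    def __getitem__(self, indices):
        # Slice every column with the whole list of indices at once.
        return ToyTemporalData(**{
            key: [values[i] for i in indices]
            for key, values in self.columns.items()
        })


def iter_batches(data, batch_size):
    """Yield mini-batches of consecutive events, in order."""
    for start in range(0, len(data), batch_size):
        indices = list(range(start, min(start + batch_size, len(data))))
        yield data[indices]


data = ToyTemporalData(src=[0, 1, 2, 3, 4], dst=[5, 6, 7, 8, 9],
                       t=[10, 11, 12, 13, 14])
batches = list(iter_batches(data, batch_size=2))
```

With shuffling disabled this reproduces the order of the removed `seq_batches` method: consecutive windows of `batch_size` events, with a shorter final batch.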
6 changes: 3 additions & 3 deletions torch_geometric/profile/utils.py
@@ -5,13 +5,13 @@
import subprocess as sp
import sys
from collections.abc import Mapping, Sequence
from typing import Any, Tuple, Union
from typing import Any, Tuple

import torch
from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.data import BaseData


def count_parameters(model: torch.nn.Module) -> int:
@@ -36,7 +36,7 @@ def get_model_size(model: torch.nn.Module) -> int:
return model_size


def get_data_size(data: Union[Data, HeteroData]) -> int:
def get_data_size(data: BaseData) -> int:
r"""Given a :class:`torch_geometric.data.Data` object, get its theoretical
memory usage in bytes.
