Skip to content

Commit

Permalink
[Type Hints] datasets.ICEWS18 (pyg-team#5666)
Browse files Browse the repository at this point in the history
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
2 people authored and JakubPietrakIntel committed Nov 25, 2022
1 parent 041377b commit a849215
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 14 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
39 changes: 26 additions & 13 deletions torch_geometric/datasets/icews.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,33 @@
from typing import Callable, List, Optional

import torch

from torch_geometric.data import Data, InMemoryDataset, download_url
from torch_geometric.io import read_txt_array


class EventDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None,
pre_filter=None):
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
):
super().__init__(root, transform, pre_transform, pre_filter)

@property
def num_nodes(self):
def num_nodes(self) -> int:
raise NotImplementedError

@property
def num_rels(self):
def num_rels(self) -> int:
raise NotImplementedError

def process_events(self):
def process_events(self) -> int:
raise NotImplementedError

def process(self):
def process(self) -> List[Data]:
events = self.process_events()
events = events - events.min(dim=0, keepdim=True)[0]

Expand Down Expand Up @@ -64,34 +71,40 @@ class ICEWS18(EventDataset):
url = 'https://github.com/INK-USC/RE-Net/raw/master/data/ICEWS18'
splits = [0, 373018, 419013, 468558] # Train/Val/Test splits.

def __init__(self, root, split='train', transform=None, pre_transform=None,
pre_filter=None):
def __init__(
self,
root: str,
split: str = 'train',
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
):
assert split in ['train', 'val', 'test']
super().__init__(root, transform, pre_transform, pre_filter)
idx = self.processed_file_names.index(f'{split}.pt')
self.data, self.slices = torch.load(self.processed_paths[idx])

@property
def num_nodes(self):
def num_nodes(self) -> int:
return 23033

@property
def num_rels(self):
def num_rels(self) -> int:
return 256

@property
def raw_file_names(self):
def raw_file_names(self) -> List[str]:
return [f'{name}.txt' for name in ['train', 'valid', 'test']]

@property
def processed_file_names(self):
def processed_file_names(self) -> List[str]:
return ['train.pt', 'val.pt', 'test.pt']

def download(self):
for filename in self.raw_file_names:
download_url(f'{self.url}/{filename}', self.raw_dir)

def process_events(self):
def process_events(self) -> torch.Tensor:
events = []
for path in self.raw_paths:
data = read_txt_array(path, sep='\t', end=4, dtype=torch.long)
Expand Down

0 comments on commit a849215

Please sign in to comment.