From fc68ed8a8039f111f723d2500a265e5b8ad0af00 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 18 Oct 2022 11:29:44 +0000 Subject: [PATCH 1/6] update --- test/loader/test_link_neighbor_loader.py | 2 +- torch_geometric/data/lightning_datamodule.py | 9 +- torch_geometric/loader/link_loader.py | 127 +++++------------- .../loader/link_neighbor_loader.py | 19 +-- torch_geometric/loader/neighbor_loader.py | 20 ++- torch_geometric/loader/node_loader.py | 59 ++++---- torch_geometric/loader/utils.py | 26 +++- torch_geometric/sampler/base.py | 32 ++--- torch_geometric/sampler/neighbor_sampler.py | 20 ++- 9 files changed, 137 insertions(+), 177 deletions(-) diff --git a/test/loader/test_link_neighbor_loader.py b/test/loader/test_link_neighbor_loader.py index 9399b3016bd8..e09649c09847 100644 --- a/test/loader/test_link_neighbor_loader.py +++ b/test/loader/test_link_neighbor_loader.py @@ -195,7 +195,7 @@ def test_temporal_heterogeneous_link_neighbor_loader(): data['paper', 'author'].edge_index = get_edge_index(100, 200, 1000) data['author', 'paper'].edge_index = get_edge_index(200, 100, 1000) - with pytest.raises(ValueError, match=r"'edge_label_time' was not set.*"): + with pytest.raises(ValueError, match=r"'edge_label_time' is not set"): loader = LinkNeighborLoader( data, num_neighbors=[-1] * 2, diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index fce51019d0ef..225c17577fa0 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -8,16 +8,14 @@ from torch_geometric.data import Data, Dataset, HeteroData from torch_geometric.data.feature_store import FeatureStore from torch_geometric.data.graph_store import GraphStore +from torch_geometric.loader import NeighborLoader from torch_geometric.loader.dataloader import DataLoader from torch_geometric.loader.link_neighbor_loader import ( LinkNeighborLoader, get_edge_label_index, ) -from torch_geometric.loader.neighbor_loader import ( - NeighborLoader, - NeighborSampler, - get_input_nodes, -) +from torch_geometric.loader.utils import get_input_nodes +from torch_geometric.sampler import NeighborSampler from torch_geometric.typing import InputEdges, InputNodes try: @@ -313,6 +311,7 @@ def __init__( directed=kwargs.get('directed', True), input_type=get_input_nodes(data, input_train_nodes)[0], time_attr=kwargs.get('time_attr', None), + seed_time=kwargs.get('seed_time', None), is_sorted=kwargs.get('is_sorted', False), share_memory=num_workers > 0, ) diff --git a/torch_geometric/loader/link_loader.py b/torch_geometric/loader/link_loader.py index 157267b917e5..1399d9486928 100644 --- a/torch_geometric/loader/link_loader.py +++ b/torch_geometric/loader/link_loader.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Iterator, Tuple, Union +from typing import Any, Callable, Iterator, List, Tuple, Union import torch @@ -7,6 +7,7 @@ from torch_geometric.data.graph_store import GraphStore from torch_geometric.loader.base import DataLoaderIterator from torch_geometric.loader.utils import ( + InputData, filter_custom_store, filter_data, filter_hetero_data, @@ -83,59 +84,56 @@ def __init__( filter_per_worker: bool = False, **kwargs, ): - # Remove for PyTorch Lightning: - if 'dataset' in kwargs: - del kwargs['dataset'] - if 'collate_fn' in kwargs: - del kwargs['collate_fn'] + # Get edge type (or `None` for homogeneous graphs): + edge_type, edge_label_index = get_edge_label_index( + data, edge_label_index) + if edge_label is None: + edge_label = torch.zeros(edge_label_index.size(1), + device=edge_label_index.device) self.data = data - - # Initialize sampler with keyword arguments: - # NOTE sampler is an attribute of 'DataLoader', so we use link_sampler - # here: + self.edge_type = edge_type self.link_sampler = link_sampler - - # Store additional arguments: - self.edge_label = edge_label - self.edge_label_index = edge_label_index - self.edge_label_time = edge_label_time + self.input_data = InputData(edge_label_index[0], edge_label_index[1], + edge_label, edge_label_time) + self.neg_sampling_ratio = neg_sampling_ratio self.transform = transform self.filter_per_worker = filter_per_worker - self.neg_sampling_ratio = neg_sampling_ratio - # Get input type, or None for homogeneous graphs: - edge_type, edge_label_index = get_edge_label_index( - data, edge_label_index) - if edge_label is None: - edge_label = torch.zeros(edge_label_index.size(1), - device=edge_label_index.device) - self.input_type = edge_type + iterator = range(edge_label_index.size(1)) + super().__init__(iterator, collate_fn=self.collate_fn, **kwargs) - super().__init__( - Dataset(edge_label_index, edge_label, edge_label_time), - collate_fn=self.collate_fn, - **kwargs, + def collate_fn(self, index: List[int]) -> Any: + r"""Samples a subgraph from a batch of input nodes.""" + input_data: EdgeSamplerInput = self.input_data[index] + out = self.link_sampler.sample_from_edges( + input_data, + negative_sampling_ratio=self.neg_sampling_ratio, ) + if self.filter_per_worker: # Execute `filter_fn` in the worker process + out = self.filter_fn(out) + + return out + def filter_fn( self, out: Union[SamplerOutput, HeteroSamplerOutput], ) -> Union[Data, HeteroData]: r"""Joins the sampled nodes with their corresponding features, - returning the resulting (Data or HeteroData) object to be used - downstream.""" + returning the resulting :class:`~torch_geometric.data.Data` or + :class:`~torch_geometric.data.HeteroData` object to be used downstream. + """ if isinstance(out, SamplerOutput): - edge_label_index, edge_label, edge_label_time = out.metadata data = filter_data(self.data, out.node, out.row, out.col, out.edge, self.link_sampler.edge_permutation) + data.batch = out.batch - data.edge_label_index = edge_label_index - data.edge_label = edge_label - data.edge_label_time = edge_label_time + data.edge_label_index = out.metadata[0] + data.edge_label = out.metadata[1] + data.edge_label_time = out.metadata[2] elif isinstance(out, HeteroSamplerOutput): - edge_label_index, edge_label, edge_label_time = out.metadata if isinstance(self.data, HeteroData): data = filter_hetero_data(self.data, out.node, out.row, out.col, out.edge, @@ -144,13 +142,11 @@ def filter_fn( data = filter_custom_store(*self.data, out.node, out.row, out.col, out.edge) - edge_type = self.input_type for key, batch in (out.batch or {}).items(): data[key].batch = batch - data[edge_type].edge_label_index = edge_label_index - data[edge_type].edge_label = edge_label - if edge_label_time is not None: - data[edge_type].edge_label_time = edge_label_time + data[self.edge_type].edge_label_index = out.metadata[0] + data[self.edge_type].edge_label = out.metadata[1] + data[self.edge_type].edge_label_time = out.metadata[2] else: raise TypeError(f"'{self.__class__.__name__}'' found invalid " @@ -158,61 +154,12 @@ def filter_fn( return data if self.transform is None else self.transform(data) - def collate_fn(self, index: EdgeSamplerInput) -> Any: - r"""Samples a subgraph from a batch of input nodes.""" - out = self.link_sampler.sample_from_edges( - index, - negative_sampling_ratio=self.neg_sampling_ratio, - ) - if self.filter_per_worker: - # We execute `filter_fn` in the worker process. - out = self.filter_fn(out) - return out - def _get_iterator(self) -> Iterator: if self.filter_per_worker: return super()._get_iterator() - # We execute `filter_fn` in the main process. + + # Execute `filter_fn` in the main process: return DataLoaderIterator(super()._get_iterator(), self.filter_fn) def __repr__(self) -> str: return f'{self.__class__.__name__}()' - - -############################################################################### - - -class Dataset(torch.utils.data.Dataset): - def __init__( - self, - edge_label_index: torch.Tensor, - edge_label: torch.Tensor, - edge_label_time: OptTensor = None, - ): - # NOTE see documentation of LinkLoader for details on these three - # input parameters: - self.edge_label_index = edge_label_index - self.edge_label = edge_label - self.edge_label_time = edge_label_time - - def __getitem__( - self, - idx: int, - ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: - if self.edge_label_time is None: - return ( - self.edge_label_index[0, idx], - self.edge_label_index[1, idx], - self.edge_label[idx], - ) - else: - return ( - self.edge_label_index[0, idx], - self.edge_label_index[1, idx], - self.edge_label[idx], - self.edge_label_time[idx], - ) - - def __len__(self) -> int: - return self.edge_label_index.size(1) diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index d37bf5353eeb..bb883ddd0de1 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -166,21 +166,16 @@ def __init__( neighbor_sampler: Optional[NeighborSampler] = None, **kwargs, ): - # Get input type: - # TODO(manan): this computation is required twice, once here and once - # in LinkLoader: + # TODO(manan): Avoid duplicated computation (here and in NodeLoader): edge_type, _ = get_edge_label_index(data, edge_label_index) - has_time_attr = time_attr is not None - has_edge_label_time = edge_label_time is not None - if has_edge_label_time != has_time_attr: + if (edge_label_time is not None) != (time_attr is not None): raise ValueError( - f"Received conflicting 'time_attr' and 'edge_label_time' " - f"arguments: 'time_attr' was " - f"{'set' if has_time_attr else 'not set'} and " - f"'edge_label_time' was " - f"{'set' if has_edge_label_time else 'not set'}. Please " - f"resolve these conflicting arguments.") + f"Received conflicting 'edge_label_time' and 'time_attr' " + f"arguments: 'edge_label_time' is " + f"{'set' if edge_label_time is not None else 'not set'} " + f"while 'input_time' is " + f"{'set' if time_attr is not None else 'not set'}.") if neighbor_sampler is None: neighbor_sampler = NeighborSampler( diff --git a/torch_geometric/loader/neighbor_loader.py b/torch_geometric/loader/neighbor_loader.py index 5d3b05457cda..94cd9b8a633c 100644 --- a/torch_geometric/loader/neighbor_loader.py +++ b/torch_geometric/loader/neighbor_loader.py @@ -6,7 +6,7 @@ from torch_geometric.loader.node_loader import NodeLoader from torch_geometric.loader.utils import get_input_nodes from torch_geometric.sampler import NeighborSampler -from torch_geometric.typing import InputNodes, NumNeighbors +from torch_geometric.typing import InputNodes, NumNeighbors, OptTensor class NeighborLoader(NodeLoader): @@ -122,6 +122,11 @@ class NeighborLoader(NodeLoader): If set to :obj:`None`, all nodes will be considered. In heterogeneous graphs, needs to be passed as a tuple that holds the node type and node indices. (default: :obj:`None`) + input_time (torch.Tensor, optional): Optional values to override the + timestamp for the input nodes given in :obj:`input_nodes`. If not + set, will use the timestamps in :obj:`time_attr` as default (if + present). The :obj:`time_attr` needs to be set for this to work. + (default: :obj:`None`) replace (bool, optional): If set to :obj:`True`, will sample with replacement. (default: :obj:`False`) directed (bool, optional): If set to :obj:`False`, will include all @@ -164,6 +169,7 @@ def __init__( data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], num_neighbors: NumNeighbors, input_nodes: InputNodes = None, + input_time: OptTensor = None, replace: bool = False, directed: bool = True, temporal_strategy: str = 'uniform', @@ -174,11 +180,14 @@ def __init__( neighbor_sampler: Optional[NeighborSampler] = None, **kwargs, ): - # Get input type: - # TODO(manan): this computation is repeated twice, once here and once - # in NodeLoader: + # TODO(manan): Avoid duplicated computation (here and in NodeLoader): node_type, _ = get_input_nodes(data, input_nodes) + if input_time is not None and time_attr is None: + raise ValueError("Received conflicting 'input_time' and " + "'time_attr' arguments: 'input_time' is set " + "while 'time_attr' is not set.") + if neighbor_sampler is None: neighbor_sampler = NeighborSampler( data, @@ -192,12 +201,11 @@ def __init__( share_memory=kwargs.get('num_workers', 0) > 0, ) - # A NeighborLoader is simply a NodeLoader that uses the NeighborSampler - # sampling implementation: super().__init__( data=data, node_sampler=neighbor_sampler, input_nodes=input_nodes, + input_time=input_time, transform=transform, filter_per_worker=filter_per_worker, **kwargs, diff --git a/torch_geometric/loader/node_loader.py b/torch_geometric/loader/node_loader.py index a0d94a000f97..580ad5a06d54 100644 --- a/torch_geometric/loader/node_loader.py +++ b/torch_geometric/loader/node_loader.py @@ -7,6 +7,7 @@ from torch_geometric.data.graph_store import GraphStore from torch_geometric.loader.base import DataLoaderIterator from torch_geometric.loader.utils import ( + InputData, filter_custom_store, filter_data, filter_hetero_data, @@ -18,7 +19,7 @@ NodeSamplerInput, SamplerOutput, ) -from torch_geometric.typing import InputNodes +from torch_geometric.typing import InputNodes, OptTensor class NodeLoader(torch.utils.data.DataLoader): @@ -43,6 +44,11 @@ class NodeLoader(torch.utils.data.DataLoader): If set to :obj:`None`, all nodes will be considered. In heterogeneous graphs, needs to be passed as a tuple that holds the node type and node indices. (default: :obj:`None`) + input_time (torch.Tensor, optional): Optional values to override the + timestamp for the input nodes given in :obj:`input_nodes`. If not + set, will use the timestamps in :obj:`time_attr` as default (if + present). The :obj:`time_attr` needs to be set for this to work. + (default: :obj:`None`) transform (Callable, optional): A function/transform that takes in a sampled mini-batch and returns a transformed version. (default: :obj:`None`) @@ -63,40 +69,43 @@ def __init__( data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], node_sampler: BaseSampler, input_nodes: InputNodes = None, + input_time: OptTensor = None, transform: Callable = None, filter_per_worker: bool = False, **kwargs, ): - # Remove for PyTorch Lightning: - if 'dataset' in kwargs: - del kwargs['dataset'] - if 'collate_fn' in kwargs: - del kwargs['collate_fn'] + # Get node type (or `None` for homogeneous graphs): + node_type, input_nodes = get_input_nodes(data, input_nodes) self.data = data - - # NOTE sampler is an attribute of 'DataLoader', so we use node_sampler - # here: + self.node_type = node_type self.node_sampler = node_sampler - - # Store additional arguments: - self.input_nodes = input_nodes + self.input_data = InputData(input_nodes, input_time) self.transform = transform self.filter_per_worker = filter_per_worker - # Get input type, or None for homogeneous graphs: - node_type, input_nodes = get_input_nodes(self.data, input_nodes) - self.input_type = node_type + iterator = range(input_nodes.size(0)) + super().__init__(iterator, collate_fn=self.collate_fn, **kwargs) + + def collate_fn(self, index: NodeSamplerInput) -> Any: + r"""Samples a subgraph from a batch of input nodes.""" + input_data: NodeSamplerInput = self.input_data[index] + + out = self.node_sampler.sample_from_nodes(input_data) + + if self.filter_per_worker: # Execute `filter_fn` in the worker process + out = self.filter_fn(out) - super().__init__(input_nodes, collate_fn=self.collate_fn, **kwargs) + return out def filter_fn( self, out: Union[SamplerOutput, HeteroSamplerOutput], ) -> Union[Data, HeteroData]: r"""Joins the sampled nodes with their corresponding features, - returning the resulting (Data or HeteroData) object to be used - downstream.""" + returning the resulting :class:`~torch_geometric.data.Data` or + :class:`~torch_geometric.data.HeteroData` object to be used downstream. + """ if isinstance(out, SamplerOutput): data = filter_data(self.data, out.node, out.row, out.col, out.edge, self.node_sampler.edge_permutation) @@ -122,21 +131,11 @@ def filter_fn( return data if self.transform is None else self.transform(data) - def collate_fn(self, index: NodeSamplerInput) -> Any: - r"""Samples a subgraph from a batch of input nodes.""" - if isinstance(index, (list, tuple)): - index = torch.tensor(index) - - out = self.node_sampler.sample_from_nodes(index) - if self.filter_per_worker: - # We execute `filter_fn` in the worker process. - out = self.filter_fn(out) - return out - def _get_iterator(self) -> Iterator: if self.filter_per_worker: return super()._get_iterator() - # We execute `filter_fn` in the main process. + + # Execute `filter_fn` in the main process: return DataLoaderIterator(super()._get_iterator(), self.filter_fn) def __repr__(self) -> str: diff --git a/torch_geometric/loader/utils.py b/torch_geometric/loader/utils.py index 7306ec7d507a..6c7662e28614 100644 --- a/torch_geometric/loader/utils.py +++ b/torch_geometric/loader/utils.py @@ -1,7 +1,7 @@ import copy import math from collections.abc import Sequence -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -20,6 +20,20 @@ ) +class InputData: + def __init__(self, *args): + self.args = args + + def __getitem__(self, index: Union[Tensor, List[int]]) -> Any: + if isinstance(index, (list, tuple)): + index = torch.tensor(index) + + outs = [] + for arg in self.args: + outs.append(arg[index] if arg is not None else None) + return tuple(outs) + + def index_select(value: FeatureTensorType, index: Tensor, dim: int = 0) -> Tensor: if isinstance(value, Tensor): @@ -196,14 +210,14 @@ def to_index(tensor): if isinstance(data, Data): if input_nodes is None: - return None, range(data.num_nodes) + return None, torch.arange(data.num_nodes) return None, to_index(input_nodes) elif isinstance(data, HeteroData): assert input_nodes is not None if isinstance(input_nodes, str): - return input_nodes, range(data[input_nodes].num_nodes) + return input_nodes, torch.arange(data[input_nodes].num_nodes) assert isinstance(input_nodes, (list, tuple)) assert len(input_nodes) == 2 @@ -211,7 +225,7 @@ def to_index(tensor): node_type, input_nodes = input_nodes if input_nodes is None: - return node_type, range(data[node_type].num_nodes) + return node_type, torch.arange(data[node_type].num_nodes) return node_type, to_index(input_nodes) else: # Tuple[FeatureStore, GraphStore] @@ -222,7 +236,7 @@ def to_index(tensor): return None, to_index(input_nodes) if isinstance(input_nodes, str): - return input_nodes, range( + return input_nodes, torch.arange( remote_backend_utils.num_nodes(feature_store, graph_store, input_nodes)) @@ -232,7 +246,7 @@ def to_index(tensor): node_type, input_nodes = input_nodes if input_nodes is None: - return node_type, range( + return node_type, torch.arange( remote_backend_utils.num_nodes(feature_store, graph_store, input_nodes)) return node_type, to_index(input_nodes) diff --git a/torch_geometric/sampler/base.py b/torch_geometric/sampler/base.py index a2081423ea32..08c4928bdf27 100644 --- a/torch_geometric/sampler/base.py +++ b/torch_geometric/sampler/base.py @@ -2,19 +2,21 @@ from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union -import torch +from torch import Tensor from torch_geometric.typing import EdgeType, NodeType, OptTensor -# An input to a node-based sampler is a tensor of node indices: -NodeSamplerInput = torch.Tensor +# An input to a node-based sampler consists of two tensors: +# * The node indices +# * The timestamps of the given node indices (optional) +NodeSamplerInput = Tuple[Tensor, OptTensor] # An input to an edge-based sampler consists of four tensors: # * The row of the edge index in COO format # * The column of the edge index in COO format # * The labels of the edges -# * (Optionally) the time attribute corresponding to the edge label -EdgeSamplerInput = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, OptTensor] +# * The time attribute corresponding to the edge label (optional) +EdgeSamplerInput = Tuple[Tensor, Tensor, Tensor, OptTensor] # A sampler output contains the following information. @@ -40,11 +42,11 @@ # There exist both homogeneous and heterogeneous versions. @dataclass class SamplerOutput: - node: torch.Tensor - row: torch.Tensor - col: torch.Tensor - edge: torch.Tensor - batch: Optional[torch.Tensor] = None + node: Tensor + row: Tensor + col: Tensor + edge: Tensor + batch: OptTensor = None # TODO(manan): refine this further; it does not currently define a proper # API for the expected output of a sampler. metadata: Optional[Any] = None @@ -52,11 +54,11 @@ class SamplerOutput: @dataclass class HeteroSamplerOutput: - node: Dict[NodeType, torch.Tensor] - row: Dict[EdgeType, torch.Tensor] - col: Dict[EdgeType, torch.Tensor] - edge: Dict[EdgeType, torch.Tensor] - batch: Optional[Dict[NodeType, torch.Tensor]] = None + node: Dict[NodeType, Tensor] + row: Dict[EdgeType, Tensor] + col: Dict[EdgeType, Tensor] + edge: Dict[EdgeType, Tensor] + batch: Optional[Dict[NodeType, Tensor]] = None # TODO(manan): refine this further; it does not currently define a proper # API for the expected output of a sampler. metadata: Optional[Any] = None diff --git a/torch_geometric/sampler/neighbor_sampler.py b/torch_geometric/sampler/neighbor_sampler.py index 8056800bf3fe..138ae1dbaca9 100644 --- a/torch_geometric/sampler/neighbor_sampler.py +++ b/torch_geometric/sampler/neighbor_sampler.py @@ -328,18 +328,16 @@ def sample_from_nodes( index: NodeSamplerInput, **kwargs, ) -> Union[SamplerOutput, HeteroSamplerOutput]: - if isinstance(index, (list, tuple)): - index = torch.tensor(index) + input_nodes, input_time = index - # Tuple[FeatureStore, GraphStore] currently only supports heterogeneous - # sampling: if self.data_cls == 'custom' or issubclass(self.data_cls, HeteroData): - output = self._sample(seed={self.input_type: index}) - output.metadata = index.numel() + output = self._sample(seed={self.input_type: input_nodes}, + seed_time_dict={self.input_type: input_time}) + output.metadata = input_nodes.numel() elif issubclass(self.data_cls, Data): - output = self._sample(seed=index) - output.metadata = index.numel() + output = self._sample(seed=input_nodes, seed_time=input_time) + output.metadata = input_nodes.numel() else: raise TypeError(f"'{self.__class__.__name__}'' found invalid " @@ -354,11 +352,9 @@ def sample_from_edges( index: EdgeSamplerInput, **kwargs, ) -> Union[SamplerOutput, HeteroSamplerOutput]: + row, col, edge_label, edge_label_time = index + edge_label_index = torch.stack([row, col], dim=0) negative_sampling_ratio = kwargs.get('negative_sampling_ratio', 0.0) - query = [torch.stack(s, dim=0) for s in zip(*index)] - edge_label_index = torch.stack(query[:2], dim=0) - edge_label = query[2] - edge_label_time = query[3] if len(query) == 4 else None out = add_negative_samples(edge_label_index, edge_label, edge_label_time, self.num_src_nodes, From a476a14cd95b6e6293d25bf3abbf91dafd83684c Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 18 Oct 2022 11:39:28 +0000 Subject: [PATCH 2/6] update --- test/loader/test_neighbor_loader.py | 2 +- torch_geometric/data/lightning_datamodule.py | 9 ++------- torch_geometric/loader/node_loader.py | 2 +- torch_geometric/loader/utils.py | 2 ++ torch_geometric/sampler/neighbor_sampler.py | 9 +++++---- 5 files changed, 11 insertions(+), 13 deletions(-) diff --git a/test/loader/test_neighbor_loader.py b/test/loader/test_neighbor_loader.py index 3668a4ac9e1b..af9421c1e050 100644 --- a/test/loader/test_neighbor_loader.py +++ b/test/loader/test_neighbor_loader.py @@ -498,7 +498,7 @@ def test_pyg_lib_heterogeneous_neighbor_loader(): 'author__to__paper': [-1, -1], } - sample = torch.ops.pyg.hetero_neighbor_sample_cpu + sample = torch.ops.pyg.hetero_neighbor_sample out1 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict, num_neighbors_dict, None, None, True, False, True, False, "uniform", True) diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index 225c17577fa0..ffdf369e948e 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -8,13 +8,9 @@ from torch_geometric.data import Data, Dataset, HeteroData from torch_geometric.data.feature_store import FeatureStore from torch_geometric.data.graph_store import GraphStore -from torch_geometric.loader import NeighborLoader +from torch_geometric.loader import LinkNeighborLoader, NeighborLoader from torch_geometric.loader.dataloader import DataLoader -from torch_geometric.loader.link_neighbor_loader import ( - LinkNeighborLoader, - get_edge_label_index, -) -from torch_geometric.loader.utils import get_input_nodes +from torch_geometric.loader.utils import get_edge_label_index, get_input_nodes from torch_geometric.sampler import NeighborSampler from torch_geometric.typing import InputEdges, InputNodes @@ -311,7 +307,6 @@ def __init__( directed=kwargs.get('directed', True), input_type=get_input_nodes(data, input_train_nodes)[0], time_attr=kwargs.get('time_attr', None), - seed_time=kwargs.get('seed_time', None), is_sorted=kwargs.get('is_sorted', False), share_memory=num_workers > 0, ) diff --git a/torch_geometric/loader/node_loader.py b/torch_geometric/loader/node_loader.py index 580ad5a06d54..8e6c42d8ac10 100644 --- a/torch_geometric/loader/node_loader.py +++ b/torch_geometric/loader/node_loader.py @@ -123,7 +123,7 @@ def filter_fn( for key, batch in (out.batch or {}).items(): data[key].batch = batch - data[self.input_type].batch_size = out.metadata + data[self.node_type].batch_size = out.metadata else: raise TypeError(f"'{self.__class__.__name__}'' found invalid " diff --git a/torch_geometric/loader/utils.py b/torch_geometric/loader/utils.py index 6c7662e28614..1cac953c57e8 100644 --- a/torch_geometric/loader/utils.py +++ b/torch_geometric/loader/utils.py @@ -206,6 +206,8 @@ def get_input_nodes( def to_index(tensor): if isinstance(tensor, Tensor) and tensor.dtype == torch.bool: return tensor.nonzero(as_tuple=False).view(-1) + if not isinstance(tensor, Tensor): + return torch.tensor(tensor, dtype=torch.long) return tensor if isinstance(data, Data): diff --git a/torch_geometric/sampler/neighbor_sampler.py b/torch_geometric/sampler/neighbor_sampler.py index 138ae1dbaca9..556eedf4689b 100644 --- a/torch_geometric/sampler/neighbor_sampler.py +++ b/torch_geometric/sampler/neighbor_sampler.py @@ -220,9 +220,7 @@ def _sample( Note that the 'metadata' field of the output is not filled; it is the job of the caller to appropriately fill out this field for downstream loaders.""" - - # TODO(manan): remote backends only support heterogeneous graphs for - # now: + # TODO(manan): remote backends only support heterogeneous graphs: if self.data_cls == 'custom' or issubclass(self.data_cls, HeteroData): if _WITH_PYG_LIB: # TODO (matthias) Add `disjoint` option to `NeighborSampler` @@ -331,8 +329,11 @@ def sample_from_nodes( input_nodes, input_time = index if self.data_cls == 'custom' or issubclass(self.data_cls, HeteroData): + seed_time_dict = None + if input_time is not None: + seed_time_dict = {self.input_type: input_time} output = self._sample(seed={self.input_type: input_nodes}, - seed_time_dict={self.input_type: input_time}) + seed_time_dict=seed_time_dict) output.metadata = input_nodes.numel() elif issubclass(self.data_cls, Data): From b625daa1d775a2cc529bf173bddc7b86f3968110 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 18 Oct 2022 11:42:04 +0000 Subject: [PATCH 3/6] update --- torch_geometric/loader/link_loader.py | 6 ++++++ torch_geometric/loader/node_loader.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/torch_geometric/loader/link_loader.py b/torch_geometric/loader/link_loader.py index 1399d9486928..f91d21bf9ed9 100644 --- a/torch_geometric/loader/link_loader.py +++ b/torch_geometric/loader/link_loader.py @@ -84,6 +84,12 @@ def __init__( filter_per_worker: bool = False, **kwargs, ): + # Remove for PyTorch Lightning: + if 'dataset' in kwargs: + del kwargs['dataset'] + if 'collate_fn' in kwargs: + del kwargs['collate_fn'] + # Get edge type (or `None` for homogeneous graphs): edge_type, edge_label_index = get_edge_label_index( data, edge_label_index) diff --git a/torch_geometric/loader/node_loader.py b/torch_geometric/loader/node_loader.py index 8e6c42d8ac10..275580cf1bef 100644 --- a/torch_geometric/loader/node_loader.py +++ b/torch_geometric/loader/node_loader.py @@ -74,6 +74,12 @@ def __init__( filter_per_worker: bool = False, **kwargs, ): + # Remove for PyTorch Lightning: + if 'dataset' in kwargs: + del kwargs['dataset'] + if 'collate_fn' in kwargs: + del kwargs['collate_fn'] + # Get node type (or `None` for homogeneous graphs): node_type, input_nodes = get_input_nodes(data, input_nodes) From 2a7c34a82a62d7175032352c28fa955d41e0dc05 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 18 Oct 2022 11:44:34 +0000 Subject: [PATCH 4/6] update --- CHANGELOG.md | 1 + torch_geometric/loader/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ab2417fdb89f..0b036bb3819d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.2.0] - 2022-MM-DD ### Added +- Added support for `input_time` in `NeighborLoader` ([#5763](https://github.com/pyg-team/pytorch_geometric/pull/5763)) - Added `disjoint` mode for temporal `LinkNeighborLoader` ([#5717](https://github.com/pyg-team/pytorch_geometric/pull/5717)) - Added `HeteroData` support for `transforms.Constant` ([#5700](https://github.com/pyg-team/pytorch_geometric/pull/5700)) - Added `np.memmap` support in `NeighborLoader` ([#5696](https://github.com/pyg-team/pytorch_geometric/pull/5696)) diff --git a/torch_geometric/loader/utils.py b/torch_geometric/loader/utils.py index 1cac953c57e8..cfa7a2eb1cad 100644 --- a/torch_geometric/loader/utils.py +++ b/torch_geometric/loader/utils.py @@ -25,8 +25,8 @@ def __init__(self, *args): self.args = args def __getitem__(self, index: Union[Tensor, List[int]]) -> Any: - if isinstance(index, (list, tuple)): - index = torch.tensor(index) + if not isinstance(index, Tensor): + index = torch.tensor(index, dtype=torch.long) outs = [] for arg in self.args: From 8874b73e3e5f628c3977aa3a3acf8ec6b7b3f2b4 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 18 Oct 2022 11:47:52 +0000 Subject: [PATCH 5/6] update --- torch_geometric/loader/hgt_loader.py | 6 ++++-- torch_geometric/sampler/hgt_sampler.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/torch_geometric/loader/hgt_loader.py b/torch_geometric/loader/hgt_loader.py index e41429e21b89..b078ae7c935f 100644 --- a/torch_geometric/loader/hgt_loader.py +++ b/torch_geometric/loader/hgt_loader.py @@ -104,16 +104,18 @@ def __init__( **kwargs, ): node_type, _ = get_input_nodes(data, input_nodes) - node_sampler = HGTSampler( + + hgt_sampler = HGTSampler( data, num_samples=num_samples, input_type=node_type, is_sorted=is_sorted, share_memory=kwargs.get('num_workers', 0) > 0, ) + super().__init__( data=data, - node_sampler=node_sampler, + node_sampler=hgt_sampler, input_nodes=input_nodes, transform=transform, filter_per_worker=filter_per_worker, diff --git a/torch_geometric/sampler/hgt_sampler.py b/torch_geometric/sampler/hgt_sampler.py index 661a2d517c3d..210899933b9d 100644 --- a/torch_geometric/sampler/hgt_sampler.py +++ b/torch_geometric/sampler/hgt_sampler.py @@ -60,7 +60,8 @@ def sample_from_nodes( index: NodeSamplerInput, **kwargs, ) -> HeteroSamplerOutput: - input_node_dict = {self.input_type: torch.tensor(index)} + input_nodes, _ = index + input_node_dict = {self.input_type: input_nodes} sample_fn = torch.ops.torch_sparse.hgt_sample out = sample_fn( self.colptr_dict, @@ -76,7 +77,7 @@ def sample_from_nodes( col=remap_keys(col, self.to_edge_type), edge=remap_keys(edge, self.to_edge_type), batch=batch, - metadata=len(index), + metadata=input_nodes.size(0), ) def sample_from_edges( From d22ae19a36ebb727b015849f5abc633820aea37e Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 18 Oct 2022 12:08:24 +0000 Subject: [PATCH 6/6] update --- test/loader/test_hgt_loader.py | 3 ++- test/loader/test_link_neighbor_loader.py | 12 +++++++----- test/loader/test_neighbor_loader.py | 8 ++++---- torch_geometric/loader/link_loader.py | 14 ++++++++------ torch_geometric/loader/node_loader.py | 6 ++++-- torch_geometric/loader/utils.py | 2 +- torch_geometric/sampler/base.py | 6 ++++-- torch_geometric/sampler/hgt_sampler.py | 4 ++-- torch_geometric/sampler/neighbor_sampler.py | 14 ++++++++------ 9 files changed, 40 insertions(+), 29 deletions(-) diff --git a/test/loader/test_hgt_loader.py b/test/loader/test_hgt_loader.py index 50afe304a392..bf477eb81d0d 100644 --- a/test/loader/test_hgt_loader.py +++ b/test/loader/test_hgt_loader.py @@ -60,8 +60,9 @@ def test_hgt_loader(): assert set(batch.node_types) == {'paper', 'author'} assert set(batch.edge_types) == set(data.edge_types) - assert len(batch['paper']) == 2 + assert len(batch['paper']) == 3 assert batch['paper'].x.size() == (40, ) # 20 + 4 * 5 + assert batch['paper'].input_nodes.numel() == batch_size assert batch['paper'].batch_size == batch_size assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100 diff --git a/test/loader/test_link_neighbor_loader.py b/test/loader/test_link_neighbor_loader.py index e09649c09847..60751e6a24d5 100644 --- a/test/loader/test_link_neighbor_loader.py +++ b/test/loader/test_link_neighbor_loader.py @@ -51,9 +51,10 @@ def test_homogeneous_link_neighbor_loader(directed, neg_sampling_ratio): for batch in loader: assert isinstance(batch, Data) - assert len(batch) == 5 + assert len(batch) == 6 assert batch.x.size(0) <= 100 assert batch.x.min() >= 0 and batch.x.max() < 100 + assert batch.input_links.numel() == 20 assert batch.edge_index.min() >= 0 assert batch.edge_index.max() < batch.num_nodes assert batch.edge_attr.min() >= 0 @@ -110,7 +111,7 @@ def test_heterogeneous_link_neighbor_loader(directed, neg_sampling_ratio): for batch in loader: assert isinstance(batch, HeteroData) - assert len(batch) == 5 + assert len(batch) == 6 if neg_sampling_ratio == 0.0: # Assert only positive samples are present in the original graph: assert batch['paper', 'author'].edge_label.sum() == 0 @@ -120,7 +121,6 @@ def test_heterogeneous_link_neighbor_loader(directed, neg_sampling_ratio): assert len(edge_index | edge_label_index) == len(edge_index) else: - assert batch['paper', 'author'].edge_label_index.size(1) == 40 assert torch.all(batch['paper', 'author'].edge_label[:20] == 1) assert torch.all(batch['paper', 'author'].edge_label[20:] == 0) @@ -312,7 +312,8 @@ def test_homogeneous_link_neighbor_loader_no_edges(): for batch in loader: assert isinstance(batch, Data) - assert len(batch) == 3 + assert len(batch) == 4 + assert batch.input_links.numel() == 20 assert batch.num_nodes <= 40 assert batch.edge_label_index.size(1) == 20 assert batch.num_nodes == batch.edge_label_index.unique().numel() @@ -328,8 +329,9 @@ def test_heterogeneous_link_neighbor_loader_no_edges(): for batch in loader: assert isinstance(batch, HeteroData) - assert len(batch) == 3 + assert len(batch) == 4 assert batch['paper'].num_nodes <= 40 + assert batch['paper', 'paper'].input_links.numel() == 20 assert batch['paper', 'paper'].edge_label_index.size(1) == 20 assert batch['paper'].num_nodes == batch[ 'paper', 'paper'].edge_label_index.unique().numel() diff --git a/test/loader/test_neighbor_loader.py b/test/loader/test_neighbor_loader.py index af9421c1e050..aff809e6188f 100644 --- a/test/loader/test_neighbor_loader.py +++ b/test/loader/test_neighbor_loader.py @@ -48,10 +48,9 @@ def test_homogeneous_neighbor_loader(directed): for batch in loader: assert isinstance(batch, Data) - - assert len(batch) == 4 + assert len(batch) == 5 assert batch.x.size(0) <= 100 - assert batch.batch_size == 20 + assert batch.input_nodes.numel() == batch.batch_size == 20 assert batch.x.min() >= 0 and batch.x.max() < 100 assert batch.edge_index.min() >= 0 assert batch.edge_index.max() < batch.num_nodes @@ -118,8 +117,9 @@ def test_heterogeneous_neighbor_loader(directed): # Test node type selection: assert set(batch.node_types) == {'paper', 'author'} - assert len(batch['paper']) == 2 + assert len(batch['paper']) == 3 assert batch['paper'].x.size(0) <= 100 + assert batch['paper'].input_nodes.numel() == batch_size assert batch['paper'].batch_size == batch_size assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100 diff --git a/torch_geometric/loader/link_loader.py b/torch_geometric/loader/link_loader.py index f91d21bf9ed9..cf5088870627 100644 --- a/torch_geometric/loader/link_loader.py +++ b/torch_geometric/loader/link_loader.py @@ -135,9 +135,10 @@ def filter_fn( self.link_sampler.edge_permutation) data.batch = out.batch - data.edge_label_index = out.metadata[0] - data.edge_label = out.metadata[1] - data.edge_label_time = out.metadata[2] + data.input_links = out.metadata[0] + data.edge_label_index = out.metadata[1] + data.edge_label = out.metadata[2] + data.edge_label_time = out.metadata[3] elif isinstance(out, HeteroSamplerOutput): if isinstance(self.data, HeteroData): @@ -150,9 +151,10 @@ def filter_fn( for key, batch in (out.batch or {}).items(): data[key].batch = batch - data[self.edge_type].edge_label_index = out.metadata[0] - data[self.edge_type].edge_label = out.metadata[1] - data[self.edge_type].edge_label_time = out.metadata[2] + data[self.edge_type].input_links = out.metadata[0] + data[self.edge_type].edge_label_index = out.metadata[1] + data[self.edge_type].edge_label = out.metadata[2] + data[self.edge_type].edge_label_time = out.metadata[3] else: raise TypeError(f"'{self.__class__.__name__}'' found invalid " diff --git a/torch_geometric/loader/node_loader.py b/torch_geometric/loader/node_loader.py index 275580cf1bef..e0b2916afa06 100644 --- a/torch_geometric/loader/node_loader.py +++ b/torch_geometric/loader/node_loader.py @@ -116,7 +116,8 @@ def filter_fn( data = filter_data(self.data, out.node, out.row, out.col, out.edge, self.node_sampler.edge_permutation) data.batch = out.batch - data.batch_size = out.metadata + data.input_nodes = out.metadata + data.batch_size = out.metadata.size(0) elif isinstance(out, HeteroSamplerOutput): if isinstance(self.data, HeteroData): @@ -129,7 +130,8 @@ def filter_fn( for key, batch in (out.batch or {}).items(): data[key].batch = batch - data[self.node_type].batch_size = out.metadata + data[self.node_type].input_nodes = out.metadata + data[self.node_type].batch_size = out.metadata.size(0) else: raise TypeError(f"'{self.__class__.__name__}'' found invalid " diff --git a/torch_geometric/loader/utils.py b/torch_geometric/loader/utils.py index cfa7a2eb1cad..bc21424f4756 100644 --- a/torch_geometric/loader/utils.py +++ b/torch_geometric/loader/utils.py @@ -28,7 +28,7 @@ def __getitem__(self, index: Union[Tensor, List[int]]) -> Any: if not isinstance(index, Tensor): index = torch.tensor(index, dtype=torch.long) - outs = [] + outs = [index] for arg in self.args: outs.append(arg[index] if arg is not None else None) return tuple(outs) diff --git a/torch_geometric/sampler/base.py b/torch_geometric/sampler/base.py index 08c4928bdf27..e5f1d02e9cc7 100644 --- a/torch_geometric/sampler/base.py +++ b/torch_geometric/sampler/base.py @@ -7,16 +7,18 @@ from torch_geometric.typing import EdgeType, NodeType, OptTensor # An input to a node-based sampler consists of two tensors: +# * The example indices # * The node indices # * The timestamps of the given node indices (optional) -NodeSamplerInput = Tuple[Tensor, OptTensor] +NodeSamplerInput = Tuple[Tensor, Tensor, OptTensor] # An input to an edge-based sampler consists of four tensors: +# * The example indices # * The row of the edge index in COO format # * The column of the edge index in COO format # * The labels of the edges # * The time attribute corresponding to the edge label (optional) -EdgeSamplerInput = Tuple[Tensor, Tensor, Tensor, OptTensor] +EdgeSamplerInput = Tuple[Tensor, Tensor, Tensor, Tensor, OptTensor] # A sampler output contains the following information. diff --git a/torch_geometric/sampler/hgt_sampler.py b/torch_geometric/sampler/hgt_sampler.py index 210899933b9d..e72b669bc345 100644 --- a/torch_geometric/sampler/hgt_sampler.py +++ b/torch_geometric/sampler/hgt_sampler.py @@ -60,7 +60,7 @@ def sample_from_nodes( index: NodeSamplerInput, **kwargs, ) -> HeteroSamplerOutput: - input_nodes, _ = index + index, input_nodes, _ = index input_node_dict = {self.input_type: input_nodes} sample_fn = torch.ops.torch_sparse.hgt_sample out = sample_fn( @@ -77,7 +77,7 @@ def sample_from_nodes( col=remap_keys(col, self.to_edge_type), edge=remap_keys(edge, self.to_edge_type), batch=batch, - metadata=input_nodes.size(0), + metadata=index, ) def sample_from_edges( diff --git a/torch_geometric/sampler/neighbor_sampler.py b/torch_geometric/sampler/neighbor_sampler.py index 556eedf4689b..9b904a8af336 100644 --- a/torch_geometric/sampler/neighbor_sampler.py +++ b/torch_geometric/sampler/neighbor_sampler.py @@ -326,7 +326,7 @@ def sample_from_nodes( index: NodeSamplerInput, **kwargs, ) -> Union[SamplerOutput, HeteroSamplerOutput]: - input_nodes, input_time = index + index, input_nodes, input_time = index if self.data_cls == 'custom' or issubclass(self.data_cls, HeteroData): seed_time_dict = None @@ -334,11 +334,11 @@ def sample_from_nodes( seed_time_dict = {self.input_type: input_time} output = self._sample(seed={self.input_type: input_nodes}, seed_time_dict=seed_time_dict) - output.metadata = input_nodes.numel() + output.metadata = index elif issubclass(self.data_cls, Data): output = self._sample(seed=input_nodes, seed_time=input_time) - output.metadata = input_nodes.numel() + output.metadata = index else: raise TypeError(f"'{self.__class__.__name__}'' found invalid " @@ -353,7 +353,7 @@ def sample_from_edges( index: EdgeSamplerInput, **kwargs, ) -> Union[SamplerOutput, HeteroSamplerOutput]: - row, col, edge_label, edge_label_time = index + index, row, col, edge_label, edge_label_time = index edge_label_index = torch.stack([row, col], dim=0) negative_sampling_ratio = kwargs.get('negative_sampling_ratio', 0.0) @@ -421,7 +421,8 @@ def sample_from_edges( for key, batch in output.batch.items(): output.batch[key] = batch % num_seed_edges - output.metadata = (edge_label_index, edge_label, edge_label_time) + output.metadata = (index, edge_label_index, edge_label, + edge_label_time) elif issubclass(self.data_cls, Data): if self.disjoint_sampling: @@ -441,7 +442,8 @@ def sample_from_edges( if self.disjoint_sampling: output.batch = output.batch % num_seed_edges - output.metadata = (edge_label_index, edge_label, edge_label_time) + output.metadata = (index, edge_label_index, edge_label, + edge_label_time) else: raise TypeError(f"'{self.__class__.__name__}'' found invalid "