Skip to content

Commit

Permalink
NumNeighbors integration (#6505)
Browse files: browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jan 24, 2023
1 parent b10b465 commit a7a5fd2
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 74 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added `PGMExplainer` to `torch_geometric.contrib` ([#6149](https://github.com/pyg-team/pytorch_geometric/pull/6149))
- Added a `NumNeighbors` helper class for specifying the number of neighbors when sampling ([#6501](https://github.com/pyg-team/pytorch_geometric/pull/6501))
- Added a `NumNeighbors` helper class for specifying the number of neighbors when sampling ([#6501](https://github.com/pyg-team/pytorch_geometric/pull/6501), [#6505](https://github.com/pyg-team/pytorch_geometric/pull/6505))
- Added caching to `is_node_attr()` and `is_edge_attr()` calls ([#6492](https://github.com/pyg-team/pytorch_geometric/pull/6492))
- Added `ToHeteroLinear` and `ToHeteroMessagePassing` modules to accelerate `to_hetero` functionality ([#5992](https://github.com/pyg-team/pytorch_geometric/pull/5992), [#6456](https://github.com/pyg-team/pytorch_geometric/pull/6456))
- Added `GraphMaskExplainer` ([#6284](https://github.com/pyg-team/pytorch_geometric/pull/6284))
Expand Down
4 changes: 2 additions & 2 deletions test/data/lightning/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,9 @@ def test_eval_loader_kwargs(get_dataset):
)

assert datamodule.loader_kwargs['batch_size'] == 32
assert datamodule.graph_sampler.num_neighbors == [5]
assert datamodule.graph_sampler.num_neighbors.values == [5]
assert datamodule.eval_loader_kwargs['batch_size'] == 64
assert datamodule.eval_graph_sampler.num_neighbors == [-1]
assert datamodule.eval_graph_sampler.num_neighbors.values == [-1]

train_loader = datamodule.train_dataloader()
assert math.ceil(int(data.train_mask.sum()) / 32) == len(train_loader)
Expand Down
9 changes: 5 additions & 4 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,19 @@ def test_hetero_neighbor_loader_basic(directed, dtype):

batch_size = 20

with pytest.raises(ValueError, match="to have 2 entries"):
with pytest.raises(ValueError, match="hops must be the same across all"):
loader = NeighborLoader(
data,
num_neighbors={
('paper', 'paper'): [-1],
('paper', 'author'): [-1, -1],
('author', 'paper'): [-1, -1],
('paper', 'to', 'paper'): [-1],
('paper', 'to', 'author'): [-1, -1],
('author', 'to', 'paper'): [-1, -1],
},
input_nodes='paper',
batch_size=batch_size,
directed=directed,
)
next(iter(loader))

loader = NeighborLoader(
data,
Expand Down
12 changes: 0 additions & 12 deletions test/sampler/test_sampler_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,3 @@ def test_heterogeneous_num_neighbors_dict_and_default():
assert values == {'A__B': [25, 10], 'B__A': [-1, -1]}

assert num_neighbors.num_hops == 2


def test_num_neighbors_cast():
num_neighbors = NumNeighbors.cast([25, 10])
assert isinstance(num_neighbors, NumNeighbors)
assert num_neighbors.values == [25, 10]
assert num_neighbors.default is None

num_neighbors = NumNeighbors.cast({('A', 'B'): [25, 10]}, [-1, -1])
assert isinstance(num_neighbors, NumNeighbors)
assert num_neighbors.values == {('A', 'B'): [25, 10]}
assert num_neighbors.default == [-1, -1]
6 changes: 3 additions & 3 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
from torch_geometric.loader.link_loader import LinkLoader
from torch_geometric.sampler import NegativeSampling, NeighborSampler
from torch_geometric.typing import InputEdges, NumNeighbors, OptTensor
from torch_geometric.typing import EdgeType, InputEdges, OptTensor


class LinkNeighborLoader(LinkLoader):
Expand Down Expand Up @@ -173,7 +173,7 @@ class LinkNeighborLoader(LinkLoader):
def __init__(
self,
data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
num_neighbors: NumNeighbors,
num_neighbors: Union[List[int], Dict[EdgeType, List[int]]],
edge_label_index: InputEdges = None,
edge_label: OptTensor = None,
edge_label_time: OptTensor = None,
Expand Down
6 changes: 3 additions & 3 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
from torch_geometric.loader.node_loader import NodeLoader
from torch_geometric.sampler import NeighborSampler
from torch_geometric.typing import InputNodes, NumNeighbors, OptTensor
from torch_geometric.typing import EdgeType, InputNodes, OptTensor


class NeighborLoader(NodeLoader):
Expand Down Expand Up @@ -175,7 +175,7 @@ class NeighborLoader(NodeLoader):
def __init__(
self,
data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
num_neighbors: NumNeighbors,
num_neighbors: Union[List[int], Dict[EdgeType, List[int]]],
input_nodes: InputNodes = None,
input_time: OptTensor = None,
replace: bool = False,
Expand Down
6 changes: 5 additions & 1 deletion torch_geometric/sampler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class HeteroSamplerOutput(CastMixin):


@dataclass(frozen=True)
class NumNeighbors(CastMixin):
class NumNeighbors:
r"""The number of neighbors to sample in a homogeneous or heterogeneous
graph. In heterogeneous graphs, may also take in a dictionary denoting
the number of neighbors to sample for individual edge types.
Expand Down Expand Up @@ -278,6 +278,10 @@ def num_hops(self) -> int:
self.__dict__['_num_hops'] = num_hops
return num_hops

def __len__(self) -> int:
r"""Returns the number of hops."""
return self.num_hops


class NegativeSamplingMode(Enum):
# 'binary': Randomly sample negative edges in the graph.
Expand Down
62 changes: 16 additions & 46 deletions torch_geometric/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
NodeSamplerInput,
SamplerOutput,
)
from torch_geometric.sampler.base import DataType
from torch_geometric.sampler.base import DataType, NumNeighbors
from torch_geometric.sampler.utils import remap_keys, to_csc, to_hetero_csc
from torch_geometric.typing import EdgeType, NodeType, NumNeighbors, OptTensor
from torch_geometric.typing import EdgeType, NodeType, OptTensor

NumNeighborsType = Union[NumNeighbors, List[int], Dict[EdgeType, List[int]]]


class NeighborSampler(BaseSampler):
Expand All @@ -33,7 +35,7 @@ class NeighborSampler(BaseSampler):
def __init__(
self,
data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
num_neighbors: NumNeighbors,
num_neighbors: NumNeighborsType,
replace: bool = False,
directed: bool = True,
disjoint: bool = False,
Expand Down Expand Up @@ -133,47 +135,15 @@ def __init__(
self.temporal_strategy = temporal_strategy

@property
def num_neighbors(self) -> Union[List[int], Dict[EdgeType, List[int]]]:
def num_neighbors(self) -> NumNeighbors:
return self._num_neighbors

@num_neighbors.setter
def num_neighbors(
self,
num_neighbors: Union[List[int], Dict[EdgeType, List[int]]],
):
if self.data_type == DataType.homogeneous:
if not isinstance(num_neighbors, (list, tuple)):
raise ValueError(f"Expected 'num_neighbors' to be a list or a "
f"tuple (got {type(self.num_neighbors)})")
self._num_hops = len(num_neighbors)
self._num_neighbors = list(num_neighbors)
return

assert self.data_type in [DataType.heterogeneous, DataType.remote]

if isinstance(num_neighbors, (list, tuple)):
num_neighbors = {k: list(num_neighbors) for k in self.edge_types}

if not isinstance(num_neighbors, dict):
raise ValueError(f"Expected 'num_neighbors' to be a dictionary "
f"(got '{type(self.num_neighbors)}')")

for edge_type, values in num_neighbors.items():
if not isinstance(values, (list, tuple)):
raise ValueError(f"Expected 'num_neighbors' to be a list or a "
f"tuple (got {type(values)})")
num_neighbors[edge_type] = list(values)

# Add at least one element to the list to ensure `max` is well-defined:
self._num_hops = max([0] + [len(v) for v in num_neighbors.values()])

for edge_type, values in num_neighbors.items():
if len(values) != self._num_hops:
raise ValueError(
f"Expected the edge type {edge_type} to have "
f"{self._num_hops} entries (got {len(values)})")

self._num_neighbors = remap_keys(num_neighbors, self.to_rel_type)
def num_neighbors(self, num_neighbors: NumNeighborsType):
if isinstance(num_neighbors, NumNeighbors):
self._num_neighbors = num_neighbors
else:
self._num_neighbors = NumNeighbors(num_neighbors)

@property
def is_temporal(self) -> bool:
Expand Down Expand Up @@ -234,7 +204,7 @@ def _sample(
self.colptr_dict,
self.row_dict,
seed,
self.num_neighbors,
self.num_neighbors.get_mapped_values(self.edge_types),
self.node_time,
seed_time,
True, # csc
Expand Down Expand Up @@ -263,8 +233,8 @@ def _sample(
self.colptr_dict,
self.row_dict,
seed, # seed_dict
self.num_neighbors,
self._num_hops,
self.num_neighbors.get_mapped_values(self.edge_types),
self.num_neighbors.num_hops,
self.replace,
self.directed,
)
Expand All @@ -286,7 +256,7 @@ def _sample(
self.colptr,
self.row,
seed.to(self.colptr.dtype), # seed
self.num_neighbors,
self.num_neighbors.get_mapped_values(),
self.node_time,
seed_time,
True, # csc
Expand All @@ -311,7 +281,7 @@ def _sample(
self.colptr,
self.row,
seed, # seed
self.num_neighbors,
self.num_neighbors.get_mapped_values(),
self.replace,
self.directed,
)
Expand Down
3 changes: 1 addition & 2 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy as np
from torch import Tensor
Expand Down Expand Up @@ -59,4 +59,3 @@ def __init__(self, *args, **kwargs):

InputNodes = Union[OptTensor, NodeType, Tuple[NodeType, OptTensor]]
InputEdges = Union[OptTensor, EdgeType, Tuple[EdgeType, OptTensor]]
NumNeighbors = Union[List[int], Dict[EdgeType, List[int]]]

0 comments on commit a7a5fd2

Please sign in to comment.