Skip to content

Commit

Permalink
Expose disjoint sampling option to NeighborLoader and `LinkNeighb…
Browse files Browse the repository at this point in the history
…orLoader` (pyg-team#5775)

see discussion:
pyg-team#5660

Co-authored-by: Be%CC%81ni Balazs Egressy Beni.Balazs.Egressy@ibm.com <begressy@cccxl011.pok.ibm.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored and JakubPietrakIntel committed Nov 25, 2022
1 parent ea4716e commit 6a98f56
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Added `disjoint` argument to `NeighborLoader` and `LinkNeighborLoader` ([#5775](https://github.com/pyg-team/pytorch_geometric/pull/5775))
- Added support for `input_time` in `NeighborLoader` ([#5763](https://github.com/pyg-team/pytorch_geometric/pull/5763))
- Added `disjoint` mode for temporal `LinkNeighborLoader` ([#5717](https://github.com/pyg-team/pytorch_geometric/pull/5717))
- Added `HeteroData` support for `transforms.Constant` ([#5700](https://github.com/pyg-team/pytorch_geometric/pull/5700))
Expand Down
8 changes: 8 additions & 0 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ class LinkNeighborLoader(LinkLoader):
replacement. (default: :obj:`False`)
directed (bool, optional): If set to :obj:`False`, will include all
edges between all sampled nodes. (default: :obj:`True`)
disjoint (bool, optional): If set to :obj: `True`, each seed node will
create its own disjoint subgraph.
If set to :obj:`True`, mini-batch outputs will have a :obj:`batch`
vector holding the mapping of nodes to their respective subgraph.
Will get automatically set to :obj:`True` in case of temporal
sampling. (default: :obj:`False`)
temporal_strategy (string, optional): The sampling strategy when using
temporal sampling (:obj:`"uniform"`, :obj:`"last"`).
If set to :obj:`"uniform"`, will sample uniformly across neighbors
Expand Down Expand Up @@ -157,6 +163,7 @@ def __init__(
edge_label_time: OptTensor = None,
replace: bool = False,
directed: bool = True,
disjoint: bool = False,
temporal_strategy: str = 'uniform',
neg_sampling_ratio: float = 0.0,
time_attr: Optional[str] = None,
Expand All @@ -183,6 +190,7 @@ def __init__(
num_neighbors=num_neighbors,
replace=replace,
directed=directed,
disjoint=disjoint,
temporal_strategy=temporal_strategy,
input_type=edge_type,
time_attr=time_attr,
Expand Down
8 changes: 8 additions & 0 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ class NeighborLoader(NodeLoader):
replacement. (default: :obj:`False`)
directed (bool, optional): If set to :obj:`False`, will include all
edges between all sampled nodes. (default: :obj:`True`)
disjoint (bool, optional): If set to :obj: `True`, each seed node will
create its own disjoint subgraph.
If set to :obj:`True`, mini-batch outputs will have a :obj:`batch`
vector holding the mapping of nodes to their respective subgraph.
Will get automatically set to :obj:`True` in case of temporal
sampling. (default: :obj:`False`)
temporal_strategy (string, optional): The sampling strategy when using
temporal sampling (:obj:`"uniform"`, :obj:`"last"`).
If set to :obj:`"uniform"`, will sample uniformly across neighbors
Expand Down Expand Up @@ -172,6 +178,7 @@ def __init__(
input_time: OptTensor = None,
replace: bool = False,
directed: bool = True,
disjoint: bool = False,
temporal_strategy: str = 'uniform',
time_attr: Optional[str] = None,
transform: Callable = None,
Expand All @@ -194,6 +201,7 @@ def __init__(
num_neighbors=num_neighbors,
replace=replace,
directed=directed,
disjoint=disjoint,
temporal_strategy=temporal_strategy,
input_type=node_type,
time_attr=time_attr,
Expand Down
46 changes: 26 additions & 20 deletions torch_geometric/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
num_neighbors: NumNeighbors,
replace: bool = False,
directed: bool = True,
disjoint: bool = False,
temporal_strategy: str = 'uniform',
input_type: Optional[Any] = None,
time_attr: Optional[str] = None,
Expand All @@ -47,6 +48,7 @@ def __init__(
self.num_neighbors = num_neighbors
self.replace = replace
self.directed = directed
self._disjoint = disjoint
self.temporal_strategy = temporal_strategy
self.node_time = self.node_time_dict = None
self.input_type = input_type
Expand Down Expand Up @@ -202,12 +204,14 @@ def _set_num_neighbors_and_num_hops(self, num_neighbors):
f"{self.num_hops} entries (got {len(value)})")

@property
def disjoint_sampling(self) -> bool:
"""Returns :obj:`True` if nodes have a time attribute. If :obj:`True`,
each seed node will create its own disjoint subgraph."""
def is_temporal(self) -> bool:
return (getattr(self, 'node_time') is not None
or getattr(self, 'node_time_dict') is not None)

@property
def disjoint(self) -> bool:
return self._disjoint or self.is_temporal

def _sample(
self,
seed: Union[torch.Tensor, Dict[NodeType, torch.Tensor]],
Expand All @@ -223,7 +227,6 @@ def _sample(
# TODO(manan): remote backends only support heterogeneous graphs:
if self.data_cls == 'custom' or issubclass(self.data_cls, HeteroData):
if _WITH_PYG_LIB:
# TODO (matthias) Add `disjoint` option to `NeighborSampler`
# TODO (matthias) `return_edge_id` if edge features present
out = torch.ops.pyg.hetero_neighbor_sample(
self.node_types,
Expand All @@ -237,20 +240,22 @@ def _sample(
True, # csc
self.replace,
self.directed,
self.disjoint_sampling,
self.disjoint,
self.temporal_strategy,
True, # return_edge_id
)
row, col, node, edge, batch = out + (None, )
if self.disjoint_sampling:
if self.disjoint:
node = {k: v.t().contiguous() for k, v in node.items()}
batch = {k: v[0] for k, v in node.items()}
node = {k: v[1] for k, v in node.items()}

else:
if self.node_time_dict is not None:
raise ValueError("'time_attr' not supported for "
"neighbor sampling via 'torch-sparse'")
if self.disjoint:
raise ValueError("'disjoint' sampling not supported for "
"neighbor sampling via 'torch-sparse'. "
"Please install 'pyg-lib' for improved "
"and optimized sampling routines.")
out = torch.ops.torch_sparse.hetero_neighbor_sample(
self.node_types,
self.edge_types,
Expand All @@ -274,7 +279,6 @@ def _sample(

if issubclass(self.data_cls, Data):
if _WITH_PYG_LIB:
# TODO (matthias) Add `disjoint` option to `NeighborSampler`
# TODO (matthias) `return_edge_id` if edge features present
out = torch.ops.pyg.neighbor_sample(
self.colptr,
Expand All @@ -286,18 +290,20 @@ def _sample(
True, # csc
self.replace,
self.directed,
self.disjoint_sampling,
self.disjoint,
self.temporal_strategy,
True, # return_edge_id
)
row, col, node, edge, batch = out + (None, )
if self.disjoint_sampling:
if self.disjoint:
batch, node = node.t().contiguous()

else:
if self.node_time is not None:
raise ValueError("'time_attr' not supported for "
"neighbor sampling via 'torch-sparse'")
if self.disjoint:
raise ValueError("'disjoint' sampling not supported for "
"neighbor sampling via 'torch-sparse'. "
"Please install 'pyg-lib' for improved "
"and optimized sampling routines.")
out = torch.ops.torch_sparse.neighbor_sample(
self.colptr,
self.row,
Expand Down Expand Up @@ -366,7 +372,7 @@ def sample_from_edges(
if (self.data_cls == 'custom'
or issubclass(self.data_cls, HeteroData)):
if self.input_type[0] != self.input_type[-1]:
if self.disjoint_sampling:
if self.disjoint:
seed_src = edge_label_index[0]
seed_dst = edge_label_index[1]
edge_label_index = torch.arange(
Expand Down Expand Up @@ -397,7 +403,7 @@ def sample_from_edges(
}

else: # Merge both source and destination node indices:
if self.disjoint_sampling:
if self.disjoint:
seed_nodes = edge_label_index.view(-1)
edge_label_index = torch.arange(0, 2 * num_seed_edges)
edge_label_index = edge_label_index.view(2, -1)
Expand All @@ -417,15 +423,15 @@ def sample_from_edges(
seed_time_dict=seed_time_dict,
)

if self.disjoint_sampling:
if self.disjoint:
for key, batch in output.batch.items():
output.batch[key] = batch % num_seed_edges

output.metadata = (index, edge_label_index, edge_label,
edge_label_time)

elif issubclass(self.data_cls, Data):
if self.disjoint_sampling:
if self.disjoint:
seed_nodes = edge_label_index.view(-1)
edge_label_index = torch.arange(0, 2 * num_seed_edges)
edge_label_index = edge_label_index.view(2, -1)
Expand All @@ -439,7 +445,7 @@ def sample_from_edges(

output = self._sample(seed=seed_nodes, seed_time=seed_time)

if self.disjoint_sampling:
if self.disjoint:
output.batch = output.batch % num_seed_edges

output.metadata = (index, edge_label_index, edge_label,
Expand Down

0 comments on commit 6a98f56

Please sign in to comment.