Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NeighborSampler: Sort local neighborhoods according to time #5516

Merged
merged 9 commits into from
Sep 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
### Changed
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516))
- Fixed a bug when applying several scalers with `PNAConv` ([#5514](https://github.com/pyg-team/pytorch_geometric/issues/5514))
- Allow `.` in `ParameterDict` key names ([#5494](https://github.com/pyg-team/pytorch_geometric/pull/5494))
- Renamed `drop_unconnected_nodes` to `drop_unconnected_node_types` and `drop_orig_edges` to `drop_orig_edge_types` in `AddMetapaths` ([#5490](https://github.com/pyg-team/pytorch_geometric/pull/5490))
Expand Down
2 changes: 0 additions & 2 deletions test/loader/test_link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.testing import withPackage
from torch_geometric.testing.feature_store import MyFeatureStore
from torch_geometric.testing.graph_store import MyGraphStore

Expand Down Expand Up @@ -182,7 +181,6 @@ def test_link_neighbor_loader_edge_label():
assert torch.all(batch.edge_label[10:] == 0)


@withPackage('torch_sparse>=0.6.14')
def test_temporal_heterogeneous_link_neighbor_loader():
data = HeteroData()

Expand Down
51 changes: 35 additions & 16 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,14 +280,13 @@ def forward(self, x, edge_index, edge_weight):
assert torch.allclose(out1, out2, atol=1e-6)


@withPackage('torch_sparse>=0.6.14')
def test_temporal_heterogeneous_neighbor_loader_on_cora(get_dataset):
dataset = get_dataset(name='Cora')
data = dataset[0]

hetero_data = HeteroData()
hetero_data['paper'].x = data.x
hetero_data['paper'].time = torch.arange(data.num_nodes)
hetero_data['paper'].time = torch.arange(data.num_nodes, 0, -1)
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
hetero_data['paper', 'paper'].edge_index = data.edge_index

loader = NeighborLoader(hetero_data, num_neighbors=[-1, -1],
Expand Down Expand Up @@ -381,37 +380,57 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore):
'author', 'to', 'paper'].edge_index.size())


@withPackage('torch_sparse>=0.6.14')
@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData])
@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData])
def test_temporal_custom_neighbor_loader_on_cora(get_dataset, FeatureStore,
GraphStore):
# Initialize dataset (once):
dataset = get_dataset(name='Cora')
data = dataset[0]
data.time = torch.arange(data.num_nodes, 0, -1)

# Initialize feature store, graph store, and reference:
feature_store = FeatureStore()
graph_store = GraphStore()
hetero_data = HeteroData()

feature_store.put_tensor(data.x, group_name='paper', attr_name='x',
index=None)
feature_store.put_tensor(
data.x,
group_name='paper',
attr_name='x',
index=None,
)
hetero_data['paper'].x = data.x

feature_store.put_tensor(torch.arange(data.num_nodes), group_name='paper',
attr_name='time', index=None)
hetero_data['paper'].time = torch.arange(data.num_nodes)

num_nodes = data.x.size(dim=0)
graph_store.put_edge_index(edge_index=data.edge_index,
edge_type=('paper', 'to', 'paper'),
layout='coo', size=(num_nodes, num_nodes))
feature_store.put_tensor(
data.time,
group_name='paper',
attr_name='time',
index=None,
)
hetero_data['paper'].time = data.time

# Sort according to time in local neighborhoods:
row, col = data.edge_index
perm = ((col * (data.num_nodes + 1)) + data.time[row]).argsort()
edge_index = data.edge_index[:, perm]
rusty1s marked this conversation as resolved.
Show resolved Hide resolved

graph_store.put_edge_index(
edge_index,
edge_type=('paper', 'to', 'paper'),
layout='coo',
is_sorted=True,
size=(data.num_nodes, data.num_nodes),
)
hetero_data['paper', 'to', 'paper'].edge_index = data.edge_index

loader1 = NeighborLoader(hetero_data, num_neighbors=[-1, -1],
input_nodes='paper', time_attr='time',
batch_size=128)
loader1 = NeighborLoader(
hetero_data,
num_neighbors=[-1, -1],
input_nodes='paper',
time_attr='time',
batch_size=128,
)

loader2 = NeighborLoader(
(feature_store, graph_store),
Expand Down
8 changes: 5 additions & 3 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,11 @@ class LinkNeighborLoader(LinkLoader):
a sampled mini-batch and returns a transformed version.
(default: :obj:`None`)
is_sorted (bool, optional): If set to :obj:`True`, assumes that
:obj:`edge_index` is sorted by column. This avoids internal
re-sorting of the data and can improve runtime and memory
efficiency. (default: :obj:`False`)
:obj:`edge_index` is sorted by column.
If :obj:`time_attr` is set, additionally requires that rows are
sorted according to time within individual neighborhoods.
This avoids internal re-sorting of the data and can improve
runtime and memory efficiency. (default: :obj:`False`)
filter_per_worker (bool, optional): If set to :obj:`True`, will filter
the returning data in each worker's subprocess rather than in the
main process.
Expand Down
8 changes: 5 additions & 3 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,11 @@ class NeighborLoader(NodeLoader):
a sampled mini-batch and returns a transformed version.
(default: :obj:`None`)
is_sorted (bool, optional): If set to :obj:`True`, assumes that
:obj:`edge_index` is sorted by column. This avoids internal
re-sorting of the data and can improve runtime and memory
efficiency. (default: :obj:`False`)
:obj:`edge_index` is sorted by column.
If :obj:`time_attr` is set, additionally requires that rows are
sorted according to time within individual neighborhoods.
This avoids internal re-sorting of the data and can improve
runtime and memory efficiency. (default: :obj:`False`)
filter_per_worker (bool, optional): If set to :obj:`True`, will filter
the returning data in each worker's subprocess rather than in the
main process.
Expand Down
33 changes: 24 additions & 9 deletions torch_geometric/sampler/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from torch_geometric.data import Data, HeteroData, remote_backend_utils
from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.data.graph_store import EdgeLayout, GraphStore
from torch_geometric.sampler.base import (
BaseSampler,
EdgeSamplerInput,
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(

# Convert the graph data into a suitable format for sampling.
out = to_csc(data, device='cpu', share_memory=share_memory,
is_sorted=is_sorted)
is_sorted=is_sorted, src_node_time=self.node_time)
self.colptr, self.row, self.perm = out
assert isinstance(num_neighbors, (list, tuple))

Expand All @@ -99,7 +99,8 @@ def __init__(

# Obtain CSC representations for in-memory sampling:
out = to_hetero_csc(data, device='cpu', share_memory=share_memory,
is_sorted=is_sorted)
is_sorted=is_sorted,
node_time_dict=self.node_time_dict)
colptr_dict, row_dict, perm_dict = out

# Conversions to/from C++ string type:
Expand All @@ -125,16 +126,34 @@ def __init__(
# TODO support `FeatureStore` with no edge types (e.g. `Data`)
feature_store, graph_store = data

# Obtain all node and edge metadata:
node_attrs = feature_store.get_all_tensor_attrs()
edge_attrs = graph_store.get_all_edge_attrs()

# TODO support `collect` on `FeatureStore`:
self.node_time_dict = None
if time_attr is not None:
# If the `time_attr` is present, we expect that `GraphStore`
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
# holds all edges sorted by destination, and within local
# neighborhoods, node indices should be sorted by time.
# TODO (matthias, manan) Find an alternative way to ensure
for edge_attr in edge_attrs:
if edge_attr.layout == EdgeLayout.CSR:
raise ValueError(
"Temporal sampling requires that edges are stored "
"in either COO or CSC layout")
if not edge_attr.is_sorted:
raise ValueError(
"Temporal sampling requires that edges are "
"sorted by destination, and by source time "
"within local neighborhoods")

# We need to obtain all features with 'attr_name=time_attr'
# from the feature store and store them in node_time_dict. To
# do so, we make an explicit feature store GET call here with
# the relevant 'TensorAttr's
time_attrs = [
attr for attr in feature_store.get_all_tensor_attrs()
if attr.attr_name == time_attr
attr for attr in node_attrs if attr.attr_name == time_attr
]
for attr in time_attrs:
attr.index = None
Expand All @@ -144,10 +163,6 @@ def __init__(
for time_attr, time_tensor in zip(time_attrs, time_tensors)
}

# Obtain all node and edge metadata:
node_attrs = feature_store.get_all_tensor_attrs()
edge_attrs = graph_store.get_all_edge_attrs()

self.node_types = list(
set(node_attr.group_name for node_attr in node_attrs))
self.edge_types = list(
Expand Down
45 changes: 36 additions & 9 deletions torch_geometric/sampler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,37 @@

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.storage import EdgeStorage
from torch_geometric.typing import EdgeType, OptTensor
from torch_geometric.typing import EdgeType, NodeType, OptTensor

# Edge Layout Conversion ######################################################


def sort_csc(
row: Tensor,
col: Tensor,
src_node_time: OptTensor = None,
) -> Tuple[Tensor, Tensor, Tensor]:
if src_node_time is None:
col, perm = col.sort()
return row[perm], col, perm
else:
# Multiplying by raw `datetime[64ns]` values may cause overflows.
# As such, we normalize time into range [0, 1) before sorting:
src_node_time = src_node_time.to(torch.double, copy=True)
min_time, max_time = src_node_time.min(), src_node_time.max() + 1
src_node_time.sub_(min_time).div_(max_time)

perm = src_node_time[row].add_(col.to(torch.double)).argsort()
return row[perm], col[perm], perm


# TODO(manan) deprecate when FeatureStore / GraphStore unification is complete
def to_csc(
data: Union[Data, EdgeStorage],
device: Optional[torch.device] = None,
share_memory: bool = False,
is_sorted: bool = False,
src_node_time: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, OptTensor]:
# Convert the graph data into a suitable format for sampling (CSC format).
# Returns the `colptr` and `row` indices of the graph, as well as an
Expand All @@ -27,17 +47,23 @@ def to_csc(
perm: Optional[Tensor] = None

if hasattr(data, 'adj'):
if src_node_time is not None:
raise NotImplementedError("Temporal sampling via 'SparseTensor' "
"format not yet supported")
colptr, row, _ = data.adj.csc()

elif hasattr(data, 'adj_t'):
if src_node_time is not None:
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError("Temporal sampling via 'SparseTensor' "
"format not yet supported")
colptr, row, _ = data.adj_t.csr()

elif data.edge_index is not None:
(row, col) = data.edge_index
row, col = data.edge_index
if not is_sorted:
perm = (col * data.size(0)).add_(row).argsort()
row = row[perm]
colptr = torch.ops.torch_sparse.ind2ptr(col[perm], data.size(1))
row, col, perm = sort_csc(row, col, src_node_time)
colptr = torch.ops.torch_sparse.ind2ptr(col, data.size(1))

else:
row = torch.empty(0, dtype=torch.long, device=device)
colptr = torch.zeros(data.num_nodes + 1, dtype=torch.long,
Expand All @@ -61,17 +87,18 @@ def to_hetero_csc(
device: Optional[torch.device] = None,
share_memory: bool = False,
is_sorted: bool = False,
node_time_dict: Optional[Dict[NodeType, Tensor]] = None,
) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, OptTensor]]:
# Convert the heterogeneous graph data into a suitable format for sampling
# (CSC format).
# Returns dictionaries holding `colptr` and `row` indices as well as edge
# permutations for each edge type, respectively.
colptr_dict, row_dict, perm_dict = {}, {}, {}

for store in data.edge_stores:
key = store._key
out = to_csc(store, device, share_memory, is_sorted)
colptr_dict[key], row_dict[key], perm_dict[key] = out
for edge_type, store in data.edge_items():
src_node_time = (node_time_dict or {}).get(edge_type[0], None)
out = to_csc(store, device, share_memory, is_sorted, src_node_time)
colptr_dict[edge_type], row_dict[edge_type], perm_dict[edge_type] = out

return colptr_dict, row_dict, perm_dict

Expand Down