From 04575e3f5cd30575789de8cbe9ce81d4e4a0880a Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 14 Sep 2022 16:06:05 +0000 Subject: [PATCH 1/2] update --- test/loader/test_neighbor_loader.py | 24 +++++++++++++++++++-- torch_geometric/sampler/neighbor_sampler.py | 16 ++++++++------ 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/test/loader/test_neighbor_loader.py b/test/loader/test_neighbor_loader.py index 932f3fb7796b..1dfc8182caec 100644 --- a/test/loader/test_neighbor_loader.py +++ b/test/loader/test_neighbor_loader.py @@ -83,8 +83,28 @@ def test_heterogeneous_neighbor_loader(directed): ) batch_size = 20 - loader = NeighborLoader(data, num_neighbors=[10] * 2, input_nodes='paper', - batch_size=batch_size, directed=directed) + + with pytest.raises(ValueError, match="to have 2 entries"): + loader = NeighborLoader( + data, + num_neighbors={ + ('paper', 'paper'): [-1], + ('paper', 'author'): [-1, -1], + ('author', 'paper'): [-1, -1], + }, + input_nodes='paper', + batch_size=batch_size, + directed=directed, + ) + + loader = NeighborLoader( + data, + num_neighbors=[10] * 2, + input_nodes='paper', + batch_size=batch_size, + directed=directed, + ) + assert str(loader) == 'NeighborLoader()' assert len(loader) == (100 + batch_size - 1) // batch_size diff --git a/torch_geometric/sampler/neighbor_sampler.py b/torch_geometric/sampler/neighbor_sampler.py index ff9869f0b4b0..706a0b59129f 100644 --- a/torch_geometric/sampler/neighbor_sampler.py +++ b/torch_geometric/sampler/neighbor_sampler.py @@ -177,15 +177,17 @@ def __init__( def _set_num_neighbors_and_num_hops(self, num_neighbors): if isinstance(num_neighbors, (list, tuple)): - self.num_neighbors = { - key: num_neighbors - for key in self.edge_types - } - assert isinstance(self.num_neighbors, dict) + num_neighbors = {key: num_neighbors for key in self.edge_types} + assert isinstance(num_neighbors, dict) + self.num_neighbors = num_neighbors # Add at least one element to the list to ensure `max` is well-defined - self.num_hops = max([0] + - [len(v) for v in self.num_neighbors.values()]) + self.num_hops = max([0] + [len(v) for v in num_neighbors.values()]) + + for key, value in self.num_neighbors.items(): + if len(value) != self.num_hops: + raise ValueError(f"Expected the edge type {key} to have " + f"{self.num_hops} entries (got {len(value)})") def _sample( self, From c29a2219b46ad61ee214f619ac88566aecc0e34f Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 14 Sep 2022 16:07:05 +0000 Subject: [PATCH 2/2] changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 70174781bb5c..31a1b0e95066 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240)) - Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222)) ### Changed -- Fixed a bug in `TUDataset` in which node features were wrongly constructed whenever `node_attributes` only hold a single feature (*e.g.*, in `PROTEINS`) ([#5441](https://github.com/pyg-team/pytorch_geometric/pull/5411)) +- Ensure equal lenghts of `num_neighbors` across edge types in `NeighborLoader` ([#5444](https://github.com/pyg-team/pytorch_geometric/pull/5444)) +- Fixed a bug in `TUDataset` in which node features were wrongly constructed whenever `node_attributes` only hold a single feature (*e.g.*, in `PROTEINS`) ([#5441](https://github.com/pyg-team/pytorch_geometric/pull/5441)) - Breaking change: removed `num_neighbors` as an attribute of loader ([#5404](https://github.com/pyg-team/pytorch_geometric/pull/5404)) - `ASAPooling` is now jittable ([#5395](https://github.com/pyg-team/pytorch_geometric/pull/5395)) - Updated unsupervised `GraphSAGE` example to leverage `LinkNeighborLoader` ([#5317](https://github.com/pyg-team/pytorch_geometric/pull/5317))