Skip to content

Commit

Permalink
Merge branch 'master' into models-lightgcn
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jul 30, 2022
2 parents a8fdb99 + 25ff6d9 commit f87a57f
Show file tree
Hide file tree
Showing 38 changed files with 941 additions and 229 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/full_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ jobs:
include:
- torch-version: 1.12.0
torchvision-version: 0.13.0
- os: windows-latest
exclude:
- os: windows-latest # Complains about CUDA mismatch.
python-version: '3.7'
- os: windows-latest # Complains about missing numpy package.
python-version: '3.10'

steps:
- uses: actions/checkout@v3
Expand Down
12 changes: 10 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [2.0.5] - 2022-MM-DD
### Added
- Support `SparseTensor` as edge label in `LightGCN` (#[5046](https://github.com/pyg-team/pytorch_geometric/issues/5046))
- Added support for `BasicGNN` models within `to_hetero` ([#5091](https://github.com/pyg-team/pytorch_geometric/pull/5091))
- Added support for computing weighted metapaths in `AddMetapaths` ([#5049](https://github.com/pyg-team/pytorch_geometric/pull/5049))
- Added inference benchmark suite ([#4915](https://github.com/pyg-team/pytorch_geometric/pull/4915))
- Added a dynamically sized batch sampler for filling a mini-batch with a variable number of samples up to a maximum size ([#4972](https://github.com/pyg-team/pytorch_geometric/pull/4972))
- Added fine grained options for setting `bias` and `dropout` per layer in the `MLP` model ([#4981](https://github.com/pyg-team/pytorch_geometric/pull/4981))
Expand All @@ -24,7 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908))
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815), [#4862](https://github.com/pyg-team/pytorch_geometric/pull/4862/files))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922), [#4962](https://github.com/pyg-team/pytorch_geometric/pull/4962), [#4968](https://github.com/pyg-team/pytorch_geometric/pull/4968))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922), [#4962](https://github.com/pyg-team/pytorch_geometric/pull/4962), [#4968](https://github.com/pyg-team/pytorch_geometric/pull/4968), [#5037](https://github.com/pyg-team/pytorch_geometric/pull/5037))
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
Expand All @@ -38,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522), [#5033](https://github.com/pyg-team/pytorch_geometric/pull/5033]), [#5085](https://github.com/pyg-team/pytorch_geometric/pull/5085))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
Expand All @@ -57,9 +59,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `HeteroData.is_undirected()` support ([#4604](https://github.com/pyg-team/pytorch_geometric/pull/4604))
- Added the `Genius` and `Wiki` datasets to `nn.datasets.LINKXDataset` ([#4570](https://github.com/pyg-team/pytorch_geometric/pull/4570), [#4600](https://github.com/pyg-team/pytorch_geometric/pull/4600))
- Added `nn.glob.GlobalPooling` module with support for multiple aggregations ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `nn.glob.EquilibrumAggregation` implicit global layer ([#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522))
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
### Changed
- Fixed `HGTLoader` bug which produced outputs with missing edge types ([#5067](https://github.com/pyg-team/pytorch_geometric/pull/5067))
- Fixed dynamic inheritance issue in data batching ([#5051](https://github.com/pyg-team/pytorch_geometric/pull/5051))
- Fixed `load_state_dict` in `Linear` with `strict=False` mode ([5094](https://github.com/pyg-team/pytorch_geometric/pull/5094))
- Fixed typo in `MaskLabel.ratio_mask` ([5093](https://github.com/pyg-team/pytorch_geometric/pull/5093))
- Fixed `data.num_node_features` computation for sparse matrices ([5089](https://github.com/pyg-team/pytorch_geometric/pull/5089))
- Fixed `GenConv` test ([4993](https://github.com/pyg-team/pytorch_geometric/pull/4993))
- Fixed packaging tests for Python 3.10 ([4982](https://github.com/pyg-team/pytorch_geometric/pull/4982))
- Changed `act_dict` (part of `graphgym`) to create individual instances instead of reusing the same ones everywhere ([4978](https://github.com/pyg-team/pytorch_geometric/pull/4978))
Expand Down
15 changes: 0 additions & 15 deletions docs/source/modules/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,6 @@ Normalization Layers
:undoc-members:
:exclude-members: training

Global Pooling Layers
---------------------

.. currentmodule:: torch_geometric.nn.glob
.. autosummary::
:nosignatures:
{% for cls in torch_geometric.nn.glob.classes %}
{{ cls }}
{% endfor %}

.. automodule:: torch_geometric.nn.glob
:members:
:undoc-members:
:exclude-members: training

Pooling Layers
--------------

Expand Down
41 changes: 41 additions & 0 deletions examples/equilibrium_median.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
r"""
Replicates the experiment from `"Deep Graph Infomax"
<https://arxiv.org/abs/1809.10341>`_ to try and teach
`EquilibriumAggregation` to learn to take the median of
a set of numbers
This example converges slowly to being able to predict the
median similar to what is observed in the paper.
"""

import numpy as np
import torch

from torch_geometric.nn.aggr import EquilibriumAggregation

input_size = 100
steps = 10000000
embedding_size = 10
eval_each = 1000

model = EquilibriumAggregation(1, 10, [256, 256], 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

norm = torch.distributions.normal.Normal(0.5, 0.4)
gamma = torch.distributions.gamma.Gamma(0.2, 0.5)
uniform = torch.distributions.uniform.Uniform(0, 1)
total_loss = 0
n_loss = 0

for i in range(steps):
optimizer.zero_grad()
dist = np.random.choice([norm, gamma, uniform])
x = dist.sample((input_size, 1))
y = model(x)
loss = (y - x.median()).norm(2) / input_size
loss.backward()
optimizer.step()
total_loss += loss
n_loss += 1
if i % eval_each == (eval_each - 1):
print(f"Average loss at epoc {i} is {total_loss / n_loss}")
2 changes: 1 addition & 1 deletion examples/link_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test(data):
val_auc = test(val_data)
test_auc = test(test_data)
if val_auc > best_val_auc:
best_val = val_auc
best_val_auc = val_auc
final_test_auc = test_auc
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
f'Test: {test_auc:.4f}')
Expand Down
2 changes: 1 addition & 1 deletion examples/unimp_arxiv.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def forward(self, x, y, edge_index, label_mask):
def train(label_rate=0.65): # How many labels to use for propagation.
model.train()

propagation_mask = MaskLabel.ratio_mask(train_mask, ratio=1 - label_rate)
propagation_mask = MaskLabel.ratio_mask(train_mask, ratio=label_rate)
supervision_mask = train_mask ^ propagation_mask

optimizer.zero_grad()
Expand Down
21 changes: 21 additions & 0 deletions test/data/test_graph_store.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import List

import pytest
import torch
from torch_sparse import SparseTensor

from torch_geometric.data.graph_store import EdgeLayout
from torch_geometric.testing.graph_store import MyGraphStore
from torch_geometric.typing import OptTensor
from torch_geometric.utils.sort_edge_index import sort_edge_index


Expand Down Expand Up @@ -94,3 +97,21 @@ def assert_edge_index_equal(expected: torch.Tensor, actual: torch.Tensor):
assert torch.equal(row_dict[key], csc[0])
assert torch.equal(colptr_dict[key], csc[1])
assert perm_dict[key] is None

# Ensure that 'edge_types' parameters work as intended:
def _tensor_eq(expected: List[OptTensor], actual: List[OptTensor]):
for tensor_expected, tensor_actual in zip(expected, actual):
if tensor_expected is None or tensor_actual is None:
return tensor_actual == tensor_expected
return torch.equal(tensor_expected, tensor_actual)

edge_types = [('v', '1', 'v'), ('v', '2', 'v')]
assert _tensor_eq(
list(graph_store.coo()[0].values())[:-1],
graph_store.coo(edge_types=edge_types)[0].values())
assert _tensor_eq(
list(graph_store.csr()[0].values())[:-1],
graph_store.csr(edge_types=edge_types)[0].values())
assert _tensor_eq(
list(graph_store.csc()[0].values())[:-1],
graph_store.csc(edge_types=edge_types)[0].values())
11 changes: 11 additions & 0 deletions test/loader/test_hgt_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch_geometric.data import HeteroData
from torch_geometric.loader import HGTLoader
from torch_geometric.nn import GraphConv, to_hetero
from torch_geometric.testing import withPackage
from torch_geometric.utils import k_hop_subgraph


Expand Down Expand Up @@ -174,3 +175,13 @@ def forward(self, x, edge_index, edge_weight):
out2 = hetero_model(hetero_batch.x_dict, hetero_batch.edge_index_dict,
hetero_batch.edge_weight_dict)['paper'][:batch_size]
assert torch.allclose(out1, out2, atol=1e-6)


@withPackage('torch_sparse>=0.6.15')
def test_hgt_loader_on_dblp(get_dataset):
data = get_dataset(name='dblp')[0]
loader = HGTLoader(data, num_samples=[10, 10],
input_nodes=('author', data['author'].train_mask))

for batch in loader:
assert set(batch.edge_types) == set(data.edge_types)
74 changes: 74 additions & 0 deletions test/loader/test_link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.testing import withRegisteredOp
from torch_geometric.testing.feature_store import MyFeatureStore
from torch_geometric.testing.graph_store import MyGraphStore


def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
Expand Down Expand Up @@ -213,3 +215,75 @@ def test_temporal_heterogeneous_link_neighbor_loader():
seed_nodes = batch['paper', 'paper'].edge_label_index.view(-1)
seed_max_time = batch['paper'].time[seed_nodes].max()
assert seed_max_time >= max_time


@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData])
@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData])
def test_custom_heterogeneous_link_neighbor_loader(FeatureStore, GraphStore):
data = HeteroData()
feature_store = FeatureStore()
graph_store = GraphStore()

# Set up node features:
x = torch.arange(100)
data['paper'].x = x
feature_store.put_tensor(x, group_name='paper', attr_name='x', index=None)

x = torch.arange(100, 300)
data['author'].x = x
feature_store.put_tensor(x, group_name='author', attr_name='x', index=None)

# Set up edge indices (GraphStore does not support `edge_attr` at the
# moment):
edge_index = get_edge_index(100, 100, 500)
data['paper', 'to', 'paper'].edge_index = edge_index
graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]),
edge_type=('paper', 'to', 'paper'),
layout='coo', size=(100, 100))

edge_index = get_edge_index(100, 200, 1000)
data['paper', 'to', 'author'].edge_index = edge_index
graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]),
edge_type=('paper', 'to', 'author'),
layout='coo', size=(100, 200))

edge_index = get_edge_index(200, 100, 1000)
data['author', 'to', 'paper'].edge_index = edge_index
graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]),
edge_type=('author', 'to', 'paper'),
layout='coo', size=(200, 100))

loader1 = LinkNeighborLoader(
data,
num_neighbors=[-1] * 2,
edge_label_index=('paper', 'to', 'author'),
batch_size=20,
directed=True,
neg_sampling_ratio=0,
)

loader2 = LinkNeighborLoader(
(feature_store, graph_store),
num_neighbors=[-1] * 2,
edge_label_index=('paper', 'to', 'author'),
batch_size=20,
directed=True,
neg_sampling_ratio=0,
)

assert str(loader1) == str(loader2)

for (batch1, batch2) in zip(loader1, loader2):
# Mapped indices of neighbors may be differently sorted:
assert torch.allclose(batch1['paper'].x.sort()[0],
batch2['paper'].x.sort()[0])
assert torch.allclose(batch1['author'].x.sort()[0],
batch2['author'].x.sort()[0])

# Assert that edge indices have the same size:
assert (batch1['paper', 'to', 'paper'].edge_index.size() == batch1[
'paper', 'to', 'paper'].edge_index.size())
assert (batch1['paper', 'to', 'author'].edge_index.size() == batch1[
'paper', 'to', 'author'].edge_index.size())
assert (batch1['author', 'to', 'paper'].edge_index.size() == batch1[
'author', 'to', 'paper'].edge_index.size())
53 changes: 53 additions & 0 deletions test/nn/aggr/test_equilibrium.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest
import torch

from torch_geometric.nn.aggr import EquilibriumAggregation


@pytest.mark.parametrize('iter', [0, 1, 5])
@pytest.mark.parametrize('alpha', [0, .1, 5])
def test_equilibrium(iter, alpha):

batch_size = 10
feature_channels = 3
output_channels = 2
x = torch.randn(batch_size, feature_channels)
model = EquilibriumAggregation(feature_channels, output_channels,
num_layers=[10, 10], grad_iter=iter)

assert model.__repr__() == 'EquilibriumAggregation()'
out = model(x)
assert out.size() == (1, 2)

with pytest.raises(ValueError):
model(x, dim_size=0)

out = model(x, dim_size=3)
assert out.size() == (3, 2)
assert torch.all(out[1:, :] == 0)


@pytest.mark.parametrize('iter', [0, 1, 5])
@pytest.mark.parametrize('alpha', [0, .1, 5])
def test_equilibrium_batch(iter, alpha):

batch_1, batch_2 = 4, 6
feature_channels = 3
output_channels = 2
x = torch.randn(batch_1 + batch_2, feature_channels)
batch = torch.tensor([0 for _ in range(batch_1)] +
[1 for _ in range(batch_2)])

model = EquilibriumAggregation(feature_channels, output_channels,
num_layers=[10, 10], grad_iter=iter)

assert model.__repr__() == 'EquilibriumAggregation()'
out = model(x, batch)
assert out.size() == (2, 2)

with pytest.raises(ValueError):
model(x, dim_size=0)

out = model(x, dim_size=3)
assert out.size() == (3, 2)
assert torch.all(out[1:, :] == 0)
Loading

0 comments on commit f87a57f

Please sign in to comment.