-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into models-lightgcn
- Loading branch information
Showing
38 changed files
with
941 additions
and
229 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
r""" | ||
Replicates the experiment from `"Deep Graph Infomax" | ||
<https://arxiv.org/abs/1809.10341>`_ to try and teach | ||
`EquilibriumAggregation` to learn to take the median of | ||
a set of numbers | ||
This example converges slowly to being able to predict the | ||
median similar to what is observed in the paper. | ||
""" | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from torch_geometric.nn.aggr import EquilibriumAggregation | ||
|
||
input_size = 100 | ||
steps = 10000000 | ||
embedding_size = 10 | ||
eval_each = 1000 | ||
|
||
model = EquilibriumAggregation(1, 10, [256, 256], 1) | ||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) | ||
|
||
norm = torch.distributions.normal.Normal(0.5, 0.4) | ||
gamma = torch.distributions.gamma.Gamma(0.2, 0.5) | ||
uniform = torch.distributions.uniform.Uniform(0, 1) | ||
total_loss = 0 | ||
n_loss = 0 | ||
|
||
for i in range(steps): | ||
optimizer.zero_grad() | ||
dist = np.random.choice([norm, gamma, uniform]) | ||
x = dist.sample((input_size, 1)) | ||
y = model(x) | ||
loss = (y - x.median()).norm(2) / input_size | ||
loss.backward() | ||
optimizer.step() | ||
total_loss += loss | ||
n_loss += 1 | ||
if i % eval_each == (eval_each - 1): | ||
print(f"Average loss at epoc {i} is {total_loss / n_loss}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import pytest | ||
import torch | ||
|
||
from torch_geometric.nn.aggr import EquilibriumAggregation | ||
|
||
|
||
@pytest.mark.parametrize('iter', [0, 1, 5]) | ||
@pytest.mark.parametrize('alpha', [0, .1, 5]) | ||
def test_equilibrium(iter, alpha): | ||
|
||
batch_size = 10 | ||
feature_channels = 3 | ||
output_channels = 2 | ||
x = torch.randn(batch_size, feature_channels) | ||
model = EquilibriumAggregation(feature_channels, output_channels, | ||
num_layers=[10, 10], grad_iter=iter) | ||
|
||
assert model.__repr__() == 'EquilibriumAggregation()' | ||
out = model(x) | ||
assert out.size() == (1, 2) | ||
|
||
with pytest.raises(ValueError): | ||
model(x, dim_size=0) | ||
|
||
out = model(x, dim_size=3) | ||
assert out.size() == (3, 2) | ||
assert torch.all(out[1:, :] == 0) | ||
|
||
|
||
@pytest.mark.parametrize('iter', [0, 1, 5]) | ||
@pytest.mark.parametrize('alpha', [0, .1, 5]) | ||
def test_equilibrium_batch(iter, alpha): | ||
|
||
batch_1, batch_2 = 4, 6 | ||
feature_channels = 3 | ||
output_channels = 2 | ||
x = torch.randn(batch_1 + batch_2, feature_channels) | ||
batch = torch.tensor([0 for _ in range(batch_1)] + | ||
[1 for _ in range(batch_2)]) | ||
|
||
model = EquilibriumAggregation(feature_channels, output_channels, | ||
num_layers=[10, 10], grad_iter=iter) | ||
|
||
assert model.__repr__() == 'EquilibriumAggregation()' | ||
out = model(x, batch) | ||
assert out.size() == (2, 2) | ||
|
||
with pytest.raises(ValueError): | ||
model(x, dim_size=0) | ||
|
||
out = model(x, dim_size=3) | ||
assert out.size() == (3, 2) | ||
assert torch.all(out[1:, :] == 0) |
Oops, something went wrong.