
Commit f8de132

Merge branch 'master' into master
Authored Aug 8, 2022
2 parents: ba6cdf2 + 7692969

File tree: 18 files changed (+75, -39 lines)

‎CHANGELOG.md

Lines changed: 3 additions & 1 deletion

@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ## [2.0.5] - 2022-MM-DD
 ### Added
 - Let `ImbalancedSampler` accept `torch.Tensor` as input ([#5138](https://github.com/pyg-team/pytorch_geometric/pull/5138))
+- Added `flow` argument to `gcn_norm` to correctly normalize the adjacency matrix in `GCNConv` ([#5149](https://github.com/pyg-team/pytorch_geometric/pull/5149))
 - `NeighborSampler` supports graphs without edges ([#5072](https://github.com/pyg-team/pytorch_geometric/pull/5072))
 - Added the `MeanSubtractionNorm` layer ([#5068](https://github.com/pyg-team/pytorch_geometric/pull/5068))
 - Added `pyg_lib.segment_matmul` integration within `RGCNConv` ([#5052](https://github.com/pyg-team/pytorch_geometric/pull/5052), [#5096](https://github.com/pyg-team/pytorch_geometric/pull/5096))

@@ -66,6 +67,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
 - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
 ### Changed
+- Fixed `GINEConv` bug with non-sequential input ([#5154](https://github.com/pyg-team/pytorch_geometric/pull/5154))
 - Improved error message ([#5095](https://github.com/pyg-team/pytorch_geometric/pull/5095))
 - Fixed `HGTLoader` bug which produced outputs with missing edge types ([#5067](https://github.com/pyg-team/pytorch_geometric/pull/5067))
 - Fixed dynamic inheritance issue in data batching ([#5051](https://github.com/pyg-team/pytorch_geometric/pull/5051))

@@ -101,7 +103,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Refactored reading molecular positions from sdf file for qm9 datasets ([4654](https://github.com/pyg-team/pytorch_geometric/pull/4654))
 - Fixed `MLP.jittable()` bug in case `return_emb=True` ([#4645](https://github.com/pyg-team/pytorch_geometric/pull/4645), [#4648](https://github.com/pyg-team/pytorch_geometric/pull/4648))
 - The generated node features of `StochasticBlockModelDataset` are now ordered with respect to their labels ([#4617](https://github.com/pyg-team/pytorch_geometric/pull/4617))
-- Fixed typos in the documentation ([#4616](https://github.com/pyg-team/pytorch_geometric/pull/4616), [#4824](https://github.com/pyg-team/pytorch_geometric/pull/4824), [#4895](https://github.com/pyg-team/pytorch_geometric/pull/4895))
+- Fixed typos in the documentation ([#4616](https://github.com/pyg-team/pytorch_geometric/pull/4616), [#4824](https://github.com/pyg-team/pytorch_geometric/pull/4824), [#4895](https://github.com/pyg-team/pytorch_geometric/pull/4895), [#5161](https://github.com/pyg-team/pytorch_geometric/pull/5161))
 - The `bias` argument in `TAGConv` is now actually applied ([#4597](https://github.com/pyg-team/pytorch_geometric/pull/4597))
 - Fixed subclass behaviour of `process` and `download` in `Dataset` ([#4586](https://github.com/pyg-team/pytorch_geometric/pull/4586))
 - Fixed filtering of attributes for loaders in case `__cat_dim__ != 0` ([#4629](https://github.com/pyg-team/pytorch_geometric/pull/4629))

‎test/nn/conv/test_gcn_conv.py

Lines changed: 12 additions & 0 deletions

@@ -78,3 +78,15 @@ def test_static_gcn_conv():
     conv = GCNConv(16, 32)
     out = conv(x, edge_index)
     assert out.size() == (3, 4, 32)
+
+
+def test_gcn_conv_norm():
+    x = torch.randn(4, 16)
+    edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]])
+    row, col = edge_index
+
+    conv = GCNConv(16, 32, flow="source_to_target")
+    out1 = conv(x, edge_index)
+    conv.flow = "target_to_source"
+    out2 = conv(x, edge_index.flip(0))
+    assert torch.allclose(out1, out2, atol=1e-6)
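The new test exercises the headline fix of this commit: with `flow="target_to_source"` and the edge direction flipped, `GCNConv` should produce the same output as the default flow, because `gcn_norm` now accumulates node degrees over the correct endpoint. A minimal sanity check in the same spirit, calling `gcn_norm` directly (it is defined in torch_geometric/nn/conv/gcn_conv.py, changed further down); the toy graph mirrors the one used in the test:

    import torch
    from torch_geometric.nn.conv.gcn_conv import gcn_norm

    edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]])

    # The same graph, described once per direction, yields identical edge weights.
    _, w_st = gcn_norm(edge_index, num_nodes=4, flow="source_to_target")
    _, w_ts = gcn_norm(edge_index.flip(0), num_nodes=4, flow="target_to_source")
    assert torch.allclose(w_st, w_ts)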

‎test/nn/conv/test_gin_conv.py

Lines changed: 5 additions & 0 deletions

@@ -126,6 +126,11 @@ def test_gine_conv_edge_dim():
     out = conv(x, edge_index, edge_attr)
     assert out.size() == (4, 32)

+    nn = Lin(16, 32)
+    conv = GINEConv(nn, train_eps=True, edge_dim=8)
+    out = conv(x, edge_index, edge_attr)
+    assert out.size() == (4, 32)
+

 def test_static_gin_conv():
     x = torch.randn(3, 4, 16)

‎torch_geometric/graphgym/utils/agg_runs.py

Lines changed: 6 additions & 3 deletions

@@ -59,15 +59,18 @@ def agg_dict_list(dict_list):


 def name_to_dict(run):
-    cols = run.split('-')[1:]
+    run = run.split('-', 1)[-1]
+    cols = run.split('=')
     keys, vals = [], []
-    for col in cols:
+    keys.append(cols[0])
+    for col in cols[1:-1]:
         try:
-            key, val = col.split('=')
+            val, key = col.rsplit('-', 1)
         except Exception:
             print(col)
         keys.append(key)
         vals.append(string_to_python(val))
+    vals.append(cols[-1])
     return dict(zip(keys, vals))
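To make the new parsing concrete, here is a worked example on a hypothetical run-directory name (the real function additionally passes each value through `string_to_python`). Splitting on `=` first, and only then peeling the trailing `-<key>` off each middle token, means a value may itself contain dashes without breaking key/value recovery:

    run = "results-agg=mean-max-l_mp=2"   # hypothetical run name
    run = run.split('-', 1)[-1]           # "agg=mean-max-l_mp=2"
    cols = run.split('=')                 # ['agg', 'mean-max-l_mp', '2']
    keys, vals = [], []
    keys.append(cols[0])                  # first token is the first key
    for col in cols[1:-1]:                # middle tokens are "<value>-<next key>"
        val, key = col.rsplit('-', 1)
        keys.append(key)
        vals.append(val)
    vals.append(cols[-1])                 # last token is the last value
    print(dict(zip(keys, vals)))          # {'agg': 'mean-max', 'l_mp': '2'}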
‎torch_geometric/loader/hgt_loader.py

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ class HGTLoader(torch.utils.data.DataLoader):
             num_samples={key: [512] * 4 for key in hetero_data.node_types},
             # Use a batch size of 128 for sampling training nodes of type paper
             batch_size=128,
-            input_nodes=('paper': hetero_data['paper'].train_mask),
+            input_nodes=('paper', hetero_data['paper'].train_mask),
         )

         sampled_hetero_data = next(iter(loader))

‎torch_geometric/nn/aggr/equilibrium.py

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ def forward(self, x: Tensor, y: Tensor, index: Optional[Tensor]) -> Tensor:
             return h.mean()

         size = int(index.max().item() + 1)
-        return scatter(x, index, dim=0, dim_size=size, reduce='mean').sum()
+        return scatter(h, index, dim=0, dim_size=size, reduce='mean').sum()


 class MomentumOptimizer(torch.nn.Module):
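The fix above aggregates the computed values `h` rather than the raw input `x`. As a quick reminder of what the corrected line computes (made-up numbers, using `torch_scatter` directly): a per-graph mean followed by a sum over graphs.

    import torch
    from torch_scatter import scatter

    h = torch.tensor([1., 2., 3., 4.])   # stand-in for the per-node values `h`
    index = torch.tensor([0, 0, 1, 1])   # graph assignment of each node
    out = scatter(h, index, dim=0, dim_size=2, reduce='mean')  # tensor([1.5000, 3.5000])
    loss = out.sum()                     # tensor(5.)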

‎torch_geometric/nn/conv/appnp.py

Lines changed: 2 additions & 2 deletions

@@ -85,7 +85,7 @@ def forward(self, x: Tensor, edge_index: Adj,
                 if cache is None:
                     edge_index, edge_weight = gcn_norm(  # yapf: disable
                         edge_index, edge_weight, x.size(self.node_dim), False,
-                        self.add_self_loops, dtype=x.dtype)
+                        self.add_self_loops, self.flow, dtype=x.dtype)
                     if self.cached:
                         self._cached_edge_index = (edge_index, edge_weight)
                 else:

@@ -96,7 +96,7 @@ def forward(self, x: Tensor, edge_index: Adj,
                 if cache is None:
                     edge_index = gcn_norm(  # yapf: disable
                         edge_index, edge_weight, x.size(self.node_dim), False,
-                        self.add_self_loops, dtype=x.dtype)
+                        self.add_self_loops, self.flow, dtype=x.dtype)
                     if self.cached:
                         self._cached_adj_t = edge_index
                 else:
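The same one-argument plumbing recurs in every layer below (`ARMAConv`, `DNAConv`, `EGConv`, `FAConv`, `GCN2Conv`, `LGConv`, `PDNConv`, `SGConv`, `TAGConv`): the new `flow` parameter sits before `dtype` in `gcn_norm`'s signature (see gcn_conv.py further down), so call sites that pass `add_self_loops` positionally now pass `self.flow` positionally as well. A minimal, self-contained sketch of the updated positional order on a toy graph (all values here are placeholders):

    import torch
    from torch_geometric.nn.conv.gcn_conv import gcn_norm

    edge_index = torch.tensor([[0, 1], [1, 2]])
    # gcn_norm(edge_index, edge_weight, num_nodes, improved, add_self_loops, flow, dtype=...)
    edge_index, edge_weight = gcn_norm(edge_index, None, 3, False, True,
                                       "source_to_target", dtype=torch.float)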

‎torch_geometric/nn/conv/arma_conv.py

Lines changed: 2 additions & 2 deletions

@@ -108,12 +108,12 @@ def forward(self, x: Tensor, edge_index: Adj,
         if isinstance(edge_index, Tensor):
             edge_index, edge_weight = gcn_norm(  # yapf: disable
                 edge_index, edge_weight, x.size(self.node_dim),
-                add_self_loops=False, dtype=x.dtype)
+                add_self_loops=False, flow=self.flow, dtype=x.dtype)

         elif isinstance(edge_index, SparseTensor):
             edge_index = gcn_norm(  # yapf: disable
                 edge_index, edge_weight, x.size(self.node_dim),
-                add_self_loops=False, dtype=x.dtype)
+                add_self_loops=False, flow=self.flow, dtype=x.dtype)

         x = x.unsqueeze(-3)
         out = x

‎torch_geometric/nn/conv/dna_conv.py

Lines changed: 2 additions & 2 deletions

@@ -275,7 +275,7 @@ def forward(self, x: Tensor, edge_index: Adj,
                 if cache is None:
                     edge_index, edge_weight = gcn_norm(  # yapf: disable
                         edge_index, edge_weight, x.size(self.node_dim), False,
-                        self.add_self_loops, dtype=x.dtype)
+                        self.add_self_loops, self.flow, dtype=x.dtype)
                     if self.cached:
                         self._cached_edge_index = (edge_index, edge_weight)
                 else:

@@ -286,7 +286,7 @@ def forward(self, x: Tensor, edge_index: Adj,
                 if cache is None:
                     edge_index = gcn_norm(  # yapf: disable
                         edge_index, edge_weight, x.size(self.node_dim), False,
-                        self.add_self_loops, dtype=x.dtype)
+                        self.add_self_loops, self.flow, dtype=x.dtype)
                     if self.cached:
                         self._cached_adj_t = edge_index
                 else:

‎torch_geometric/nn/conv/eg_conv.py

Lines changed: 4 additions & 2 deletions

@@ -130,7 +130,8 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
                 if cache is None:
                     edge_index, symnorm_weight = gcn_norm(  # yapf: disable
                         edge_index, None, num_nodes=x.size(self.node_dim),
-                        improved=False, add_self_loops=self.add_self_loops)
+                        improved=False, add_self_loops=self.add_self_loops,
+                        flow=self.flow)
                     if self.cached:
                         self._cached_edge_index = (edge_index, symnorm_weight)
                 else:

@@ -141,7 +142,8 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
                 if cache is None:
                     edge_index = gcn_norm(  # yapf: disable
                         edge_index, None, num_nodes=x.size(self.node_dim),
-                        improved=False, add_self_loops=self.add_self_loops)
+                        improved=False, add_self_loops=self.add_self_loops,
+                        flow=self.flow)
                     if self.cached:
                         self._cached_adj_t = edge_index
                 else:

‎torch_geometric/nn/conv/fa_conv.py

Lines changed: 2 additions & 2 deletions

@@ -115,7 +115,7 @@ def forward(self, x: Tensor, x_0: Tensor, edge_index: Adj,
                 if cache is None:
                     edge_index, edge_weight = gcn_norm(  # yapf: disable
                         edge_index, None, x.size(self.node_dim), False,
-                        self.add_self_loops, dtype=x.dtype)
+                        self.add_self_loops, self.flow, dtype=x.dtype)
                     if self.cached:
                         self._cached_edge_index = (edge_index, edge_weight)
                 else:

@@ -127,7 +127,7 @@ def forward(self, x: Tensor, x_0: Tensor, edge_index: Adj,
                 if cache is None:
                     edge_index = gcn_norm(  # yapf: disable
                         edge_index, None, x.size(self.node_dim), False,
-                        self.add_self_loops, dtype=x.dtype)
+                        self.add_self_loops, self.flow, dtype=x.dtype)
                     if self.cached:
                         self._cached_adj_t = edge_index
                 else:

‎torch_geometric/nn/conv/gcn2_conv.py

Lines changed: 2 additions & 2 deletions

@@ -119,7 +119,7 @@ def forward(self, x: Tensor, x_0: Tensor, edge_index: Adj,
                 if cache is None:
                     edge_index, edge_weight = gcn_norm(  # yapf: disable
                         edge_index, edge_weight, x.size(self.node_dim), False,
-                        self.add_self_loops, dtype=x.dtype)
+                        self.add_self_loops, self.flow, dtype=x.dtype)
                     if self.cached:
                         self._cached_edge_index = (edge_index, edge_weight)
                 else:

@@ -130,7 +130,7 @@ def forward(self, x: Tensor, x_0: Tensor, edge_index: Adj,
                 if cache is None:
                     edge_index = gcn_norm(  # yapf: disable
                         edge_index, edge_weight, x.size(self.node_dim), False,
-                        self.add_self_loops, dtype=x.dtype)
+                        self.add_self_loops, self.flow, dtype=x.dtype)
                     if self.cached:
                         self._cached_adj_t = edge_index
                 else:

‎torch_geometric/nn/conv/gcn_conv.py

Lines changed: 11 additions & 8 deletions

@@ -17,24 +17,25 @@

 @torch.jit._overload
 def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
-             add_self_loops=True, dtype=None):
-    # type: (Tensor, OptTensor, Optional[int], bool, bool, Optional[int]) -> PairTensor  # noqa
+             add_self_loops=True, flow="source_to_target", dtype=None):
+    # type: (Tensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> PairTensor  # noqa
     pass


 @torch.jit._overload
 def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
-             add_self_loops=True, dtype=None):
-    # type: (SparseTensor, OptTensor, Optional[int], bool, bool, Optional[int]) -> SparseTensor  # noqa
+             add_self_loops=True, flow="source_to_target", dtype=None):
+    # type: (SparseTensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> SparseTensor  # noqa
     pass


 def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
-             add_self_loops=True, dtype=None):
+             add_self_loops=True, flow="source_to_target", dtype=None):

     fill_value = 2. if improved else 1.

     if isinstance(edge_index, SparseTensor):
+        assert flow in ["source_to_target"]
         adj_t = edge_index
         if not adj_t.has_value():
             adj_t = adj_t.fill_value(1., dtype=dtype)

@@ -48,6 +49,7 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
         return adj_t

     else:
+        assert flow in ["source_to_target", "target_to_source"]
         num_nodes = maybe_num_nodes(edge_index, num_nodes)

         if edge_weight is None:

@@ -61,7 +63,8 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             edge_weight = tmp_edge_weight

         row, col = edge_index[0], edge_index[1]
-        deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
+        idx = col if flow == "source_to_target" else row
+        deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes)
         deg_inv_sqrt = deg.pow_(-0.5)
         deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
         return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

@@ -171,7 +174,7 @@ def forward(self, x: Tensor, edge_index: Adj,
                 if cache is None:
                     edge_index, edge_weight = gcn_norm(  # yapf: disable
                         edge_index, edge_weight, x.size(self.node_dim),
-                        self.improved, self.add_self_loops)
+                        self.improved, self.add_self_loops, self.flow)
                     if self.cached:
                         self._cached_edge_index = (edge_index, edge_weight)
                 else:

@@ -182,7 +185,7 @@ def forward(self, x: Tensor, edge_index: Adj,
                 if cache is None:
                     edge_index = gcn_norm(  # yapf: disable
                         edge_index, edge_weight, x.size(self.node_dim),
-                        self.improved, self.add_self_loops)
+                        self.improved, self.add_self_loops, self.flow)
                     if self.cached:
                         self._cached_adj_t = edge_index
                 else:
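The substantive change sits in the third hunk of `gcn_norm` above: the degree used for the symmetric normalization is now accumulated over `edge_index[1]` (`col`) for `flow="source_to_target"` but over `edge_index[0]` (`row`) for `flow="target_to_source"`. A toy re-computation of just that step (self-loops and the final `deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]` scaling omitted for brevity):

    import torch
    from torch_scatter import scatter_add

    edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]])
    edge_weight = torch.ones(edge_index.size(1))
    row, col = edge_index[0], edge_index[1]

    deg_st = scatter_add(edge_weight, col, dim=0, dim_size=4)  # flow="source_to_target"
    deg_ts = scatter_add(edge_weight, row, dim=0, dim_size=4)  # flow="target_to_source"
    print(deg_st)  # tensor([0., 1., 1., 1.])
    print(deg_ts)  # tensor([3., 0., 0., 0.])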

‎torch_geometric/nn/conv/gin_conv.py

Lines changed: 11 additions & 5 deletions

@@ -130,8 +130,9 @@ class GINEConv(MessagePassing):
         - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
           :math:`(|\mathcal{V}_t|, F_{out})` if bipartite
     """
-    def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
-                 edge_dim: Optional[int] = None, **kwargs):
+    def __init__(self, nn: torch.nn.Module, eps: float = 0.,
+                 train_eps: bool = False, edge_dim: Optional[int] = None,
+                 **kwargs):
         kwargs.setdefault('aggr', 'add')
         super().__init__(**kwargs)
         self.nn = nn

@@ -141,11 +142,16 @@ def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
         else:
             self.register_buffer('eps', torch.Tensor([eps]))
         if edge_dim is not None:
-            if hasattr(self.nn[0], 'in_features'):
-                in_channels = self.nn[0].in_features
+            if isinstance(self.nn, torch.nn.Sequential):
+                nn = self.nn[0]
+            if hasattr(nn, 'in_features'):
+                in_channels = nn.in_features
+            elif hasattr(nn, 'in_channels'):
+                in_channels = nn.in_channels
             else:
-                in_channels = self.nn[0].in_channels
+                raise ValueError("Could not infer input channels from `nn`.")
             self.lin = Linear(edge_dim, in_channels)
+
         else:
             self.lin = None
         self.reset_parameters()
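With this change, `GINEConv` no longer assumes `nn` is an `nn.Sequential`: for a plain module the input size is read from `in_features` (or `in_channels`) directly, and a module exposing neither now raises a clear error instead of failing on indexing. A small usage sketch mirroring the new test above (shapes chosen arbitrarily):

    import torch
    from torch.nn import Linear
    from torch_geometric.nn import GINEConv

    conv = GINEConv(Linear(16, 32), train_eps=True, edge_dim=8)

    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
    edge_attr = torch.randn(edge_index.size(1), 8)
    out = conv(x, edge_index, edge_attr)  # edge_attr is first mapped from 8 to 16 features
    assert out.size() == (4, 32)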
‎torch_geometric/nn/conv/lg_conv.py

Lines changed: 3 additions & 2 deletions

@@ -42,11 +42,12 @@ def forward(self, x: Tensor, edge_index: Adj,
         """"""
         if self.normalize and isinstance(edge_index, Tensor):
             out = gcn_norm(edge_index, edge_weight, x.size(self.node_dim),
-                           add_self_loops=False, dtype=x.dtype)
+                           add_self_loops=False, flow=self.flow, dtype=x.dtype)
             edge_index, edge_weight = out
         elif self.normalize and isinstance(edge_index, SparseTensor):
             edge_index = gcn_norm(edge_index, None, x.size(self.node_dim),
-                                  add_self_loops=False, dtype=x.dtype)
+                                  add_self_loops=False, flow=self.flow,
+                                  dtype=x.dtype)

         # propagate_type: (x: Tensor, edge_weight: OptTensor)
         return self.propagate(edge_index, x=x, edge_weight=edge_weight,

‎torch_geometric/nn/conv/pdn_conv.py

Lines changed: 3 additions & 2 deletions

@@ -100,10 +100,11 @@ def forward(self, x: Tensor, edge_index: Adj,
         if isinstance(edge_index, Tensor):
             edge_index, edge_attr = gcn_norm(edge_index, edge_attr,
                                              x.size(self.node_dim), False,
-                                             self.add_self_loops)
+                                             self.add_self_loops,
+                                             self.flow)
         elif isinstance(edge_index, SparseTensor):
             edge_index = gcn_norm(edge_index, None, x.size(self.node_dim),
-                                  False, self.add_self_loops)
+                                  False, self.add_self_loops, self.flow)

         x = self.lin(x)

‎torch_geometric/nn/conv/sg_conv.py

Lines changed: 2 additions & 2 deletions

@@ -83,11 +83,11 @@ def forward(self, x: Tensor, edge_index: Adj,
             if isinstance(edge_index, Tensor):
                 edge_index, edge_weight = gcn_norm(  # yapf: disable
                     edge_index, edge_weight, x.size(self.node_dim), False,
-                    self.add_self_loops, dtype=x.dtype)
+                    self.add_self_loops, self.flow, dtype=x.dtype)
             elif isinstance(edge_index, SparseTensor):
                 edge_index = gcn_norm(  # yapf: disable
                     edge_index, edge_weight, x.size(self.node_dim), False,
-                    self.add_self_loops, dtype=x.dtype)
+                    self.add_self_loops, self.flow, dtype=x.dtype)

             for k in range(self.K):
                 # propagate_type: (x: Tensor, edge_weight: OptTensor)

‎torch_geometric/nn/conv/tag_conv.py

Lines changed: 3 additions & 2 deletions

@@ -75,12 +75,13 @@ def forward(self, x: Tensor, edge_index: Adj,
             if isinstance(edge_index, Tensor):
                 edge_index, edge_weight = gcn_norm(  # yapf: disable
                     edge_index, edge_weight, x.size(self.node_dim),
-                    improved=False, add_self_loops=False, dtype=x.dtype)
+                    improved=False, add_self_loops=False, flow=self.flow,
+                    dtype=x.dtype)

             elif isinstance(edge_index, SparseTensor):
                 edge_index = gcn_norm(  # yapf: disable
                     edge_index, edge_weight, x.size(self.node_dim),
-                    add_self_loops=False, dtype=x.dtype)
+                    add_self_loops=False, flow=self.flow, dtype=x.dtype)

         out = self.lins[0](x)
         for lin in self.lins[1:]:
