Skip to content

Commit

Permalink
Fix: to_hetero with GCN on single node types (#4279)
Browse files Browse the repository at this point in the history
* fix

* typo
  • Loading branch information
rusty1s authored Mar 16, 2022
1 parent 7526c8b commit 57c88c0
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
32 changes: 31 additions & 1 deletion test/nn/test_to_hetero_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.nn import Linear, ReLU, Sequential
from torch_sparse import SparseTensor

from torch_geometric.nn import BatchNorm, GINEConv
from torch_geometric.nn import BatchNorm, GCNConv, GINEConv
from torch_geometric.nn import Linear as LazyLinear
from torch_geometric.nn import MessagePassing, RGCNConv, SAGEConv, to_hetero

Expand Down Expand Up @@ -212,6 +212,34 @@ def test_to_hetero():
assert out['author'].size() == (8, 16)


class GCN(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GCNConv(16, 32)
self.conv2 = GCNConv(32, 64)

def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index).relu()
return x


def test_to_hetero_with_gcn():
metadata = (['paper'], [('paper', '0', 'paper'), ('paper', '1', 'paper')])
x_dict = {'paper': torch.randn(100, 16)}
edge_index_dict = {
('paper', '0', 'paper'): torch.randint(100, (2, 200)),
('paper', '1', 'paper'): torch.randint(100, (2, 200)),
}

model = GCN()
model = to_hetero(model, metadata, debug=False)
print(model)
out = model(x_dict, edge_index_dict)
assert isinstance(out, dict) and len(out) == 1
assert out['paper'].size() == (100, 64)


class GraphConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='mean')
Expand All @@ -221,6 +249,8 @@ def reset_parameters(self):
self.lin.reset_parameters()

def forward(self, x, edge_index):
if isinstance(x, Tensor):
x = (x, x)
return self.propagate(edge_index, x=(self.lin(x[0]), x[1]))


Expand Down
12 changes: 9 additions & 3 deletions torch_geometric/nn/to_hetero_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,18 @@ def map_args_kwargs(self, node: Node,
def _recurse(value: Any) -> Any:
if isinstance(value, Node):
out = self.find_by_name(f'{value.name}__{key2str(key)}')
if out is None and isinstance(key, tuple):
out = (
if out is not None:
return out
elif isinstance(key, tuple) and key[0] == key[-1]:
name = f'{value.name}__{key2str(key[0])}'
return self.find_by_name(name)
elif isinstance(key, tuple) and key[0] != key[-1]:
return (
self.find_by_name(f'{value.name}__{key2str(key[0])}'),
self.find_by_name(f'{value.name}__{key2str(key[-1])}'),
)
return out
else:
raise NotImplementedError
elif isinstance(value, dict):
return {k: _recurse(v) for k, v in value.items()}
elif isinstance(value, list):
Expand Down

0 comments on commit 57c88c0

Please sign in to comment.