Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: to_hetero with GCN on single node types #4279

Merged
merged 2 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion test/nn/test_to_hetero_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.nn import Linear, ReLU, Sequential
from torch_sparse import SparseTensor

from torch_geometric.nn import BatchNorm, GINEConv
from torch_geometric.nn import BatchNorm, GCNConv, GINEConv
from torch_geometric.nn import Linear as LazyLinear
from torch_geometric.nn import MessagePassing, RGCNConv, SAGEConv, to_hetero

Expand Down Expand Up @@ -212,6 +212,34 @@ def test_to_hetero():
assert out['author'].size() == (8, 16)


class GCN(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GCNConv(16, 32)
self.conv2 = GCNConv(32, 64)

def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index).relu()
return x


def test_to_hetero_with_gcn():
metadata = (['paper'], [('paper', '0', 'paper'), ('paper', '1', 'paper')])
x_dict = {'paper': torch.randn(100, 16)}
edge_index_dict = {
('paper', '0', 'paper'): torch.randint(100, (2, 200)),
('paper', '1', 'paper'): torch.randint(100, (2, 200)),
}

model = GCN()
model = to_hetero(model, metadata, debug=False)
print(model)
out = model(x_dict, edge_index_dict)
assert isinstance(out, dict) and len(out) == 1
assert out['paper'].size() == (100, 64)


class GraphConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='mean')
Expand All @@ -221,6 +249,8 @@ def reset_parameters(self):
self.lin.reset_parameters()

def forward(self, x, edge_index):
if isinstance(x, Tensor):
x = (x, x)
return self.propagate(edge_index, x=(self.lin(x[0]), x[1]))


Expand Down
12 changes: 9 additions & 3 deletions torch_geometric/nn/to_hetero_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,18 @@ def map_args_kwargs(self, node: Node,
def _recurse(value: Any) -> Any:
if isinstance(value, Node):
out = self.find_by_name(f'{value.name}__{key2str(key)}')
if out is None and isinstance(key, tuple):
out = (
if out is not None:
return out
elif isinstance(key, tuple) and key[0] == key[-1]:
name = f'{value.name}__{key2str(key[0])}'
return self.find_by_name(name)
elif isinstance(key, tuple) and key[0] != key[-1]:
return (
self.find_by_name(f'{value.name}__{key2str(key[0])}'),
self.find_by_name(f'{value.name}__{key2str(key[-1])}'),
)
return out
else:
raise NotImplementedError
elif isinstance(value, dict):
return {k: _recurse(v) for k, v in value.items()}
elif isinstance(value, list):
Expand Down