From 29b2745dbad76c40bb2fa154a23da1533a2d4030 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 16 Mar 2022 17:21:30 +0000 Subject: [PATCH 1/2] fix --- test/nn/test_to_hetero_transformer.py | 32 ++++++++++++++++++++- torch_geometric/nn/to_hetero_transformer.py | 12 ++++++-- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/test/nn/test_to_hetero_transformer.py b/test/nn/test_to_hetero_transformer.py index 31f111821a68..51831addc441 100644 --- a/test/nn/test_to_hetero_transformer.py +++ b/test/nn/test_to_hetero_transformer.py @@ -5,7 +5,7 @@ from torch.nn import Linear, ReLU, Sequential from torch_sparse import SparseTensor -from torch_geometric.nn import BatchNorm, GINEConv +from torch_geometric.nn import BatchNorm, GCNConv, GINEConv from torch_geometric.nn import Linear as LazyLinear from torch_geometric.nn import MessagePassing, RGCNConv, SAGEConv, to_hetero @@ -212,6 +212,34 @@ def test_to_hetero(): assert out['author'].size() == (8, 16) +class GCN(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = GCNConv(16, 32) + self.conv2 = GCNConv(32, 64) + + def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: + x = self.conv1(x, edge_index).relu() + x = self.conv2(x, edge_index).relu() + return x + + +def test_to_hetero_with_gcn(): + metadata = (['paper'], [('paper', '0', 'paper'), ('paper', '1', 'paper')]) + x_dict = {'paper': torch.randn(100, 16)} + edge_index_dict = { + ('paper', '0', 'paper'): torch.randint(100, (2, 200)), + ('paper', '1', 'paper'): torch.randint(100, (2, 200)), + } + + model = GCN() + model = to_hetero(model, metadata, debug=False) + print(model) + out = model(x_dict, edge_index_dict) + assert isinstance(out, dict) and len(out) == 1 + assert out['paper'].size() == (100, 64) + + class GraphConv(MessagePassing): def __init__(self, in_channels, out_channels): super().__init__(aggr='mean') @@ -221,6 +249,8 @@ def reset_parameters(self): self.lin.reset_parameters() def forward(self, x, edge_index): + if isinstance(x, Tensor): + x = (x, x) return self.propagate(edge_index, x=(self.lin(x[0]), x[1])) diff --git a/torch_geometric/nn/to_hetero_transformer.py b/torch_geometric/nn/to_hetero_transformer.py index 742f30cb40f3..6f4ed3d09e81 100644 --- a/torch_geometric/nn/to_hetero_transformer.py +++ b/torch_geometric/nn/to_hetero_transformer.py @@ -305,12 +305,18 @@ def map_args_kwargs(self, node: Node, def _recurse(value: Any) -> Any: if isinstance(value, Node): out = self.find_by_name(f'{value.name}__{key2str(key)}') - if out is None and isinstance(key, tuple): - out = ( + if out is not None: + return out + elif isinstance(key, tuple) and key[0] == key[-1]: + return self.find_by_name( + f'{value.name}__{key2str(key[0])}') + elif isinstance(key, tuple) and key[0] != key[-1]: + return ( self.find_by_name(f'{value.name}__{key2str(key[0])}'), self.find_by_name(f'{value.name}__{key2str(key[-1])}'), ) - return out + else: + raise NotImplementedError elif isinstance(value, dict): return {k: _recurse(v) for k, v in value.items()} elif isinstance(value, list): From d9356e60373ba4abda8ca9fcec9dc2ac4169482b Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 16 Mar 2022 17:25:10 +0000 Subject: [PATCH 2/2] typo --- torch_geometric/nn/to_hetero_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_geometric/nn/to_hetero_transformer.py b/torch_geometric/nn/to_hetero_transformer.py index 6f4ed3d09e81..cedd26f4d8e2 100644 --- a/torch_geometric/nn/to_hetero_transformer.py +++ b/torch_geometric/nn/to_hetero_transformer.py @@ -308,8 +308,8 @@ def _recurse(value: Any) -> Any: if out is not None: return out elif isinstance(key, tuple) and key[0] == key[-1]: - return self.find_by_name( - f'{value.name}__{key2str(key[0])}') + name = f'{value.name}__{key2str(key[0])}' + return self.find_by_name(name) elif isinstance(key, tuple) and key[0] != key[-1]: return ( self.find_by_name(f'{value.name}__{key2str(key[0])}'),