diff --git a/test/nn/test_to_hetero_transformer.py b/test/nn/test_to_hetero_transformer.py index 31f111821a68..51831addc441 100644 --- a/test/nn/test_to_hetero_transformer.py +++ b/test/nn/test_to_hetero_transformer.py @@ -5,7 +5,7 @@ from torch.nn import Linear, ReLU, Sequential from torch_sparse import SparseTensor -from torch_geometric.nn import BatchNorm, GINEConv +from torch_geometric.nn import BatchNorm, GCNConv, GINEConv from torch_geometric.nn import Linear as LazyLinear from torch_geometric.nn import MessagePassing, RGCNConv, SAGEConv, to_hetero @@ -212,6 +212,34 @@ def test_to_hetero(): assert out['author'].size() == (8, 16) +class GCN(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = GCNConv(16, 32) + self.conv2 = GCNConv(32, 64) + + def forward(self, x: Tensor, edge_index: Tensor) -> Tensor: + x = self.conv1(x, edge_index).relu() + x = self.conv2(x, edge_index).relu() + return x + + +def test_to_hetero_with_gcn(): + metadata = (['paper'], [('paper', '0', 'paper'), ('paper', '1', 'paper')]) + x_dict = {'paper': torch.randn(100, 16)} + edge_index_dict = { + ('paper', '0', 'paper'): torch.randint(100, (2, 200)), + ('paper', '1', 'paper'): torch.randint(100, (2, 200)), + } + + model = GCN() + model = to_hetero(model, metadata, debug=False) + print(model) + out = model(x_dict, edge_index_dict) + assert isinstance(out, dict) and len(out) == 1 + assert out['paper'].size() == (100, 64) + + class GraphConv(MessagePassing): def __init__(self, in_channels, out_channels): super().__init__(aggr='mean') @@ -221,6 +249,8 @@ def reset_parameters(self): self.lin.reset_parameters() def forward(self, x, edge_index): + if isinstance(x, Tensor): + x = (x, x) return self.propagate(edge_index, x=(self.lin(x[0]), x[1])) diff --git a/torch_geometric/nn/to_hetero_transformer.py b/torch_geometric/nn/to_hetero_transformer.py index 742f30cb40f3..cedd26f4d8e2 100644 --- a/torch_geometric/nn/to_hetero_transformer.py +++ b/torch_geometric/nn/to_hetero_transformer.py @@ -305,12 +305,18 @@ def map_args_kwargs(self, node: Node, def _recurse(value: Any) -> Any: if isinstance(value, Node): out = self.find_by_name(f'{value.name}__{key2str(key)}') - if out is None and isinstance(key, tuple): - out = ( + if out is not None: + return out + elif isinstance(key, tuple) and key[0] == key[-1]: + name = f'{value.name}__{key2str(key[0])}' + return self.find_by_name(name) + elif isinstance(key, tuple) and key[0] != key[-1]: + return ( self.find_by_name(f'{value.name}__{key2str(key[0])}'), self.find_by_name(f'{value.name}__{key2str(key[-1])}'), ) - return out + else: + raise NotImplementedError elif isinstance(value, dict): return {k: _recurse(v) for k, v in value.items()} elif isinstance(value, list):