HeteroConv aggregation cannot handle tuple outputs #9677

Open
jongyaoY opened this issue Sep 23, 2024 · 1 comment
🐛 Describe the bug

I would like to use HeteroConv to wrap my custom message-passing layer, which updates not only the node features but also the edge attributes, returning a tuple (x_updated, edge_attr_updated):

import torch

from torch_geometric.nn import HeteroConv, MessagePassing
from torch_scatter import scatter

class InteractionNetwork(MessagePassing):
    def __init__(self, ...):  # constructor details omitted

    def forward(self, x, edge_index, edge_attr):
        r_dim = 1
        if isinstance(edge_attr, tuple):
            edge_attr = edge_attr[r_dim]
        edge_attr_updated, aggr = self.propagate(
            x=x, edge_index=edge_index, edge_attr=edge_attr
        )
        x_updated = self.node_fn(torch.cat((x[r_dim], aggr), dim=-1))
        return x[r_dim] + x_updated, edge_attr + edge_attr_updated

    def message(self, x_i, x_j, edge_attr):  # x_i: receiver, x_j: sender
        e_latent = torch.cat((x_i, x_j, edge_attr), dim=-1)
        return self.edge_fn(e_latent)

    def aggregate(
        self, inputs: torch.Tensor, index: torch.Tensor, dim_size=None
    ):
        out = scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="sum"
        )
        return inputs, out

layer = HeteroConv(
    {
        EdgeType1: InteractionNetwork(**params),
        EdgeType2: InteractionNetwork(**params),
    },
    aggr=None,
)

...

out = layer(data.x_dict, data.edge_index_dict, data.edge_attr_dict)
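The failure can be reproduced in isolation. Per the traceback below, with `aggr=None` HeteroConv's `group` helper collects each edge type's output into a list and calls `torch.stack` on it, which only accepts tensors, not tuples (tensor shapes here are arbitrary, for illustration only):

```python
import torch

# What group() effectively does with aggr=None: stack the per-edge-type
# outputs for each destination node type along a new dimension.
xs = [torch.randn(4, 8), torch.randn(4, 8)]
stacked = torch.stack(xs, dim=1)  # works: shape (4, 2, 8)

# With tuple outputs from a conv, the same call fails:
tuple_xs = [(torch.randn(4, 8), torch.randn(10, 8))]
try:
    torch.stack(tuple_xs, dim=1)
except TypeError as e:
    print(e)  # "expected Tensor as element 0 in argument 0, but got tuple"
```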

But the group function raises the following error:

    def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
        if len(xs) == 0:
            return None
        elif aggr is None:
>           return torch.stack(xs, dim=1)
E           TypeError: expected Tensor as element 0 in argument 0, but got tuple

torch_geometric/nn/conv/hetero_conv.py:18: TypeError

Is it possible to make this group function, or even the whole aggregation process, customizable?
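One possible workaround, not part of PyG's API, is a tuple-aware variant of `group` that recurses into tuple outputs and groups each component separately before stacking. A minimal sketch (the function name and the `"sum"` branch are hypothetical; PyG's real `group` supports more reductions):

```python
from typing import Any, List, Optional

import torch


def group_tuples(xs: List[Any], aggr: Optional[str]) -> Optional[Any]:
    """Hypothetical tuple-aware variant of hetero_conv.group()."""
    if len(xs) == 0:
        return None
    if isinstance(xs[0], tuple):
        # Regroup component-wise: [(a1, b1), (a2, b2)] -> group([a1, a2]), group([b1, b2])
        return tuple(group_tuples(list(parts), aggr) for parts in zip(*xs))
    if aggr is None:
        return torch.stack(xs, dim=1)
    if aggr == "sum":
        return torch.stack(xs, dim=0).sum(dim=0)
    raise NotImplementedError(f"aggr={aggr!r}")
```

With `aggr=None`, two tuple outputs `[(x1, e1), (x2, e2)]` would then yield a tuple `(stack([x1, x2]), stack([e1, e2]))` instead of raising a TypeError.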

Versions

PyTorch version: 1.12.1+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1 
CMake version: version 3.21.1
Libc version: glibc-2.31

Python version: 3.8.10 (default, Mar 15 2022, 12:22:08)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-71-generic-x86_64-with-glibc2.29
Is CUDA available: True

[pip3] torch_geometric==2.4.0
jongyaoY added the bug label Sep 23, 2024
zechengz (Member) commented:

Will take a look
