GCN wrong code for SparseTensor #10047

Open
Qin87 opened this issue Feb 19, 2025 · 0 comments
Qin87 commented Feb 19, 2025

🐛 Describe the bug

The official GCNConv code accepts edge_index as either a dense Tensor or a SparseTensor, but its handling of the SparseTensor case is wrong.
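
For reference, this is how the layer is typically called with each input type (a minimal sketch; it assumes a PyTorch Geometric installation with torch_sparse available):

    import torch
    from torch_geometric.nn import GCNConv
    from torch_sparse import SparseTensor

    conv = GCNConv(in_channels=16, out_channels=32)
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])

    # Dense path: edge_index as a [2, num_edges] Tensor.
    out_dense = conv(x, edge_index)

    # Sparse path: the same graph as a transposed SparseTensor adjacency.
    adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                         sparse_sizes=(4, 4)).t()
    out_sparse = conv(x, adj_t)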

Here is the wrong code:
def forward(self, x: Tensor, edge_index: Adj,
            edge_weight: OptTensor = None) -> Tensor:

    if isinstance(x, (tuple, list)):
        raise ValueError(f"'{self.__class__.__name__}' received a tuple "
                         f"of node features as input while this layer "
                         f"does not support bipartite message passing. "
                         f"Please try other layers such as 'SAGEConv' or "
                         f"'GraphConv' instead")

    if self.normalize:
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if cache is None:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, self.flow, x.dtype)
                if self.cached:
                    self._cached_edge_index = (edge_index, edge_weight)
            else:
                edge_index, edge_weight = cache[0], cache[1]

        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if cache is None:
                edge_index = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, self.flow, x.dtype)
                if self.cached:
                    self._cached_adj_t = edge_index
            else:
                edge_index = cache

    x = self.lin(x)

    # propagate_type: (x: Tensor, edge_weight: OptTensor)
    out = self.propagate(edge_index, x=x, edge_weight=edge_weight)

    if self.bias is not None:
        out = out + self.bias

    return out

In the SparseTensor branch, the normalized SparseTensor is never unwrapped back into a dense edge_index plus edge_weight. propagate() is therefore called as out = self.propagate(edge_index, x=x, edge_weight=edge_weight) with edge_index still a SparseTensor and edge_weight still None, so the experimental results come out wrong.
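
The unwrapping step the fix relies on is straightforward (a standalone sketch; SparseTensor.coo() returns the (row, col, value) triple of the sparse matrix):

    import torch
    from torch_sparse import SparseTensor

    adj = SparseTensor(row=torch.tensor([0, 1, 2]),
                       col=torch.tensor([1, 2, 0]),
                       value=torch.tensor([0.5, 0.5, 1.0]),
                       sparse_sizes=(3, 3))

    # Recover the dense representation that propagate() expects.
    row, col, edge_weight = adj.coo()
    edge_index = torch.stack([row, col], dim=0)  # shape [2, num_edges]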

Here is my revised version for the SparseTensor case:

        elif isinstance(edge_index, SparseTensor):
            cache = self._cached_adj_t
            if cache is None:
                sparse_tensor = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(self.node_dim),
                    self.improved, self.add_self_loops, self.flow, x.dtype)

                # Extract edge_index and edge_weight from SparseTensor
                row, col, edge_weight = sparse_tensor.coo()
                edge_index = torch.stack([row, col], dim=0)
                if self.cached:
                    self._cached_adj_t = (edge_index, edge_weight)
            else:
                edge_index, edge_weight = cache
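
Note that this branch now caches the unwrapped (edge_index, edge_weight) tuple in _cached_adj_t rather than a SparseTensor, which is what the else branch unpacks. As a sanity check (a hedged sketch assuming the patched layer above; assert_close is from torch.testing), the SparseTensor path should produce the same output as the Tensor path on the same graph:

    import torch
    from torch_geometric.nn import GCNConv
    from torch_sparse import SparseTensor

    torch.manual_seed(0)
    conv = GCNConv(8, 8)
    x = torch.randn(5, 8)
    edge_index = torch.tensor([[0, 1, 2, 3, 4],
                               [1, 2, 3, 4, 0]])
    adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                         sparse_sizes=(5, 5)).t()

    # Both code paths should agree; a mismatch here exposes the bug.
    torch.testing.assert_close(conv(x, edge_index), conv(x, adj_t))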

Versions

All versions are affected; this part of the code is the same in every release.

Qin87 added the bug label on Feb 19, 2025