Bug: GCNConv in to_hetero with single node type #4271

Closed
fratajcz opened this issue Mar 16, 2022 · 10 comments · Fixed by #4279

@fratajcz

🐛 Describe the bug

I just started using heterogeneous graphs; so far I have roughly one year of experience with homogeneous graphs using your library. I have a graph with just one node type ("gene") but multiple edge types between the nodes. My HeteroData object prints as follows:

HeteroData(
  gene={
    x=[16433, 93],
    y=[16433],
    train_mask=[16433],
    test_mask=[16433],
    val_mask=False
  },
  (gene, BioPlex30HCT116, gene)={ edge_index=[2, 95370] },
  (gene, BioPlex30293T, gene)={ edge_index=[2, 156370] },
  (gene, HuRI, gene)={ edge_index=[2, 74488] }
)

I convert my GCN-based model (that runs fine with a homogeneous Data object) using the to_hetero() function.

I then pass my data into the model as follows:

out = model(data.x_dict, data.edge_index_dict)

where data is the HeteroData Object described above.

However, I get an error and I think it is because it expects edge_weight somewhere:

Traceback (most recent call last):
    out = model(data.x_dict, data.edge_index_dict)
  File "/home/fratajcz/anaconda3/envs/compat/lib/python3.7/site-packages/torch/fx/graph_module.py", line 308, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/home/fratajcz/anaconda3/envs/compat/lib/python3.7/site-packages/torch/fx/graph_module.py", line 308, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/home/fratajcz/anaconda3/envs/compat/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "<eval_with_key_1>", line 7, in forward
    edge_weight__gene__BioPlex30HCT116__gene = edge_weight.get(('gene', 'BioPlex30HCT116', 'gene'))
AttributeError: 'NoneType' object has no attribute 'get'

Do I have to specify edge weights in heterogeneous graphs? The documentation doesn't mention it.
It is also odd that the traceback mentions only torch packages and not torch_geometric.

Cheers and thanks,
Florin

Environment

  • PyG version: 2.0.1
  • PyTorch version: 1.8.0
  • OS: CentOS Linux 7
  • Python version: 3.7
  • CUDA/cuDNN version: irrelevant
  • How you installed PyTorch and PyG (conda, pip, source): conda
  • Any other relevant information (e.g., version of torch-scatter): torch_scatter: 2.0.8
@fratajcz fratajcz added the bug label Mar 16, 2022
@rusty1s
Member

rusty1s commented Mar 16, 2022

In general, edge_weight shouldn't be required, so there might be something weird going on here. What does your homogeneous model look like?

@fratajcz
Author

Thanks, I just found the problem. The forward pass of my model had an if clause that checked if edge_weight is None:

# apply graph convolutions
for i in range(len(self.gcnconv)):
    x = self.norm[i](x)
    if edge_weight is None:
        x = self.gcnconv[i](x, edge_index)
    else:
        x = self.gcnconv[i](x, edge_index, edge_weight)
    x = F.elu(x)

which obviously tripped to_hetero into thinking it would always have an edge_weight, which was always None.
When I print this model after calling to_hetero, the forward pass starts with:

def forward(self, x, edge_index, edge_weight = None):
    x__gene = x.get('gene');  x = None
    edge_index__gene__BioPlex30HCT116__gene = edge_index.get(('gene', 'BioPlex30HCT116', 'gene'))
    edge_index__gene__BioPlex30293T__gene = edge_index.get(('gene', 'BioPlex30293T', 'gene'))
    edge_index__gene__HuRI__gene = edge_index.get(('gene', 'HuRI', 'gene'));  edge_index = None
    edge_weight__gene__BioPlex30HCT116__gene = edge_weight.get(('gene', 'BioPlex30HCT116', 'gene'))
    edge_weight__gene__BioPlex30293T__gene = edge_weight.get(('gene', 'BioPlex30293T', 'gene'))
    edge_weight__gene__HuRI__gene = edge_weight.get(('gene', 'HuRI', 'gene'));  edge_weight = None

However, if I change the forward pass of my model to

# apply graph convolutions
for i in range(len(self.gcnconv)):
    x = self.norm[i](x)
    x = self.gcnconv[i](x, edge_index)
    x = F.elu(x)

it works and the three lines regarding the edge_weight are gone.
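
A toy illustration of why the branch matters, assuming that tracing replaces every forward argument (including edge_weight) with a placeholder object, which is how torch.fx symbolic tracing behaves by default. The Placeholder class and the string return values below are hypothetical stand-ins for illustration only, not part of torch.fx or PyG:

```python
class Placeholder:
    """Hypothetical stand-in for a traced argument (not a real torch.fx class)."""

    def __init__(self, name):
        self.name = name


def forward(x, edge_index, edge_weight=None):
    # Same branching as in the original model, but returning a description
    # of the call that would be recorded instead of running a convolution.
    if edge_weight is None:
        return "gcnconv(x, edge_index)"
    return "gcnconv(x, edge_index, edge_weight)"


# Eager call without weights: the first branch runs.
eager = forward("x", "edge_index")

# "Tracing": every argument is a placeholder object, so
# `edge_weight is None` is False and the weighted call is recorded.
traced = forward(
    Placeholder("x"), Placeholder("edge_index"), Placeholder("edge_weight")
)

print(eager)
print(traced)
```

Under that assumption, the traced graph always contains the edge_weight path, so the transformed forward tries to unpack edge_weight per edge type even when the caller never passes it.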

However, now I run into the following error:

File "/home/ifratajcz/anaconda3/envs/compat/lib/python3.7/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 163, in forward
    edge_index, edge_weight, x.size(self.node_dim),
AttributeError: 'tuple' object has no attribute 'size'

The problem arises because I use GCN layers in a heterogeneous setting. I thought it should work since I have only one node type, but apparently it doesn't.

@rusty1s
Member

rusty1s commented Mar 16, 2022

This is definitely a bug. I will try to fix it.

@rusty1s rusty1s changed the title Problem with (missing) edge_weight in Heterogeneous GCN Bug: GCNConv in to_hetero with single node type Mar 16, 2022
@rusty1s rusty1s self-assigned this Mar 16, 2022
@rusty1s rusty1s linked a pull request Mar 16, 2022 that will close this issue
@rusty1s
Member

rusty1s commented Mar 16, 2022

I just fixed this in #4279 :)

@fratajcz
Author

Thanks :)

@aayyad89

aayyad89 commented Nov 26, 2022

I am still getting the same error, "AttributeError: 'tuple' object has no attribute 'size'", with GCNConv:

import torch
from torch_geometric.nn import GCNConv, HeteroConv


class LightGCN(torch.nn.Module):
    def __init__(self, num_users, num_movies, n_layers=2, embedding_dim=20):
        super().__init__()

        self.num_users, self.num_movies = num_users, num_movies
        self.emb_dim = embedding_dim
        self.n_layers = n_layers
        self.user_emb = torch.nn.Embedding(num_users, embedding_dim, max_norm=1.0)
        self.movie_emb = torch.nn.Embedding(num_movies, embedding_dim, max_norm=1.0)

        # A single GCNConv instance is shared by both (bipartite) edge types.
        graph_layer = GCNConv(embedding_dim, embedding_dim,
                              bias=False, add_self_loops=False, normalize=True)

        self.conv = HeteroConv({('user', 'rates', 'movie'): graph_layer,
                                ('movie', 'rev_rates', 'user'): graph_layer})

    def forward(self, data):
        node_id_dict = data.node_id_dict
        edge_index_dict = data.edge_index_dict

        # Look up the initial embeddings for each node type.
        user_emb = self.user_emb(node_id_dict['user'])
        movie_emb = self.movie_emb(node_id_dict['movie'])
        emb_dict_init = {'user': user_emb, 'movie': movie_emb}

        emb_dict = emb_dict_init
        embs = []

        # Apply the same heterogeneous convolution n_layers times.
        for i in range(self.n_layers):
            emb_dict = self.conv(emb_dict, edge_index_dict)
            embs.append(emb_dict)

        # Average the per-layer embeddings for each node type.
        emb_final_user = torch.stack([emb['user'] for emb in embs], -1).mean(dim=-1)
        emb_final_movie = torch.stack([emb['movie'] for emb in embs], -1).mean(dim=-1)

        return emb_final_user, emb_final_movie, emb_dict_init

This doesn't happen when I switch to GATConv.

@rusty1s
Member

rusty1s commented Nov 26, 2022

Yes, you can only use GCNConv for passing messages to the same node type, e.g.:

HeteroConv({
    ...
    ('movie', 'is_similar', 'movie'): GCNConv(...),
    ...
})

This is a limitation of GCNConv as it does not support bipartite message passing.

@aayyad89

Thanks, I did not get that.

@kajocina

@rusty1s is this just a matter of implementation, or can it theoretically not be applied in such scenarios?

@rusty1s
Member

rusty1s commented Jan 17, 2024

I would say it is just a limitation of the operator, in particular since

  • symmetric normalization is not well-defined on heterogeneous graphs
  • it shares a single weight matrix for neighbor features and node features
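
To make the first point concrete, here is a small pure-Python sketch of GCN-style symmetric normalization, assuming the usual form 1 / sqrt(deg(src) * deg(dst)) with degrees counted over incoming edges. The function name and the toy edge lists are illustrative assumptions, not PyG code:

```python
import math


def gcn_norm(edges, num_nodes, add_self_loops=True):
    """Symmetric normalization where one degree vector serves both edge ends."""
    edges = list(edges)
    if add_self_loops:
        edges += [(i, i) for i in range(num_nodes)]
    deg = [0] * num_nodes
    for _, dst in edges:
        deg[dst] += 1
    # deg is indexed by BOTH src and dst, which presumes a single node set.
    return {(src, dst): 1.0 / math.sqrt(deg[src] * deg[dst]) for src, dst in edges}


# Homogeneous case: two nodes connected in both directions.
homo = gcn_norm([(0, 1), (1, 0)], num_nodes=2)

# Bipartite case: three "users" point to one "movie". The degree vector only
# covers the destination set, so looking up a source index is meaningless.
try:
    gcn_norm([(0, 0), (1, 0), (2, 0)], num_nodes=1, add_self_loops=False)
    bipartite_failed = False
except IndexError:
    bipartite_failed = True

print(homo[(0, 1)], bipartite_failed)
```

In the bipartite case, source and destination indices refer to different node sets, so a single degree vector (and a self-loop per node) has no natural meaning without further design choices.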
