Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GNNExplainer improvements #6065

Merged
merged 10 commits into from
Nov 28, 2022
Merged

GNNExplainer improvements #6065

merged 10 commits into from
Nov 28, 2022

Conversation

rusty1s
Copy link
Member

@rusty1s rusty1s commented Nov 25, 2022

  • Drops subgraph computation. GNNExplainer is now always applied on the full set of edges. This fixes several bugs, e.g., that index now references invalid entries. Instead, we compute the hard mask to get all edges involved in message passing.
  • GNNExplainer now correctly works with a tensor of index
  • target_index clean-up.
  • Moves model.eval() to the base class and resets it state afterwards.

@codecov
Copy link

codecov bot commented Nov 25, 2022

Codecov Report

Merging #6065 (f513b66) into master (0fdf935) will increase coverage by 0.02%.
The diff coverage is 93.40%.

❗ Current head f513b66 differs from pull request most recent head 219d595. Consider uploading reports for the commit 219d595 to get more accurate results

@@            Coverage Diff             @@
##           master    #6065      +/-   ##
==========================================
+ Coverage   84.28%   84.30%   +0.02%     
==========================================
  Files         362      362              
  Lines       20471    20444      -27     
==========================================
- Hits        17254    17236      -18     
+ Misses       3217     3208       -9     
Impacted Files Coverage Δ
...rch_geometric/explain/algorithm/dummy_explainer.py 94.44% <ø> (+3.96%) ⬆️
torch_geometric/explain/algorithm/gnn_explainer.py 95.68% <88.88%> (+1.20%) ⬆️
torch_geometric/explain/algorithm/base.py 96.61% <100.00%> (+9.37%) ⬆️
torch_geometric/explain/explainer.py 100.00% <100.00%> (ø)

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

Copy link
Member

@wsad1 wsad1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall this implementation is very clean, thanks.
We might get incorrect results if the nodes in index have overlapping subgraphs. Because some edges/nodes might get more importance as they are part of multiple subgraphs.

torch_geometric/explain/explainer.py Show resolved Hide resolved
torch_geometric/explain/algorithm/gnn_explainer.py Outdated Show resolved Hide resolved
@rusty1s rusty1s self-assigned this Nov 26, 2022
Copy link
Contributor

@RBendias RBendias left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! :) As Jinu suggested, I also think it's cleaner if the configs (explainer and model) are e.g. class variables and not passed through multiple methods, but we can do that in a future PR.

examples/gnn_explainer_ba_shapes.py Show resolved Hide resolved
subgraph_edge_mask = subgraph[3][non_loop_mask]
targets.append(data.edge_label[subgraph_edge_mask].cpu())
preds.append(expl_edge_mask[subgraph_edge_mask].cpu())
_, _, _, hard_edge_mask = k_hop_subgraph(node_index, num_hops=3,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would assume that the edges between the 3-hop nodes and the 4-hop nodes have some importance, due to the GCN normalization, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4-hop nodes do not have any importance as the layer is only 3 layer deep.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But isn't the degree calculated with the information on the 4 hop nodes? And, therefore, the output would change if we would remove all the 4 hop nodes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, but the 4 hop nodes do not have any importance score as these edge masks are never trained.

@rusty1s rusty1s merged commit a8bfa49 into master Nov 28, 2022
@rusty1s rusty1s deleted the gnn_explainer_refactor branch November 28, 2022 18:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants