Skip to content

Commit

Permalink
Subgraph Visualization of GNN Explanations (#6235)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored Dec 19, 2022
1 parent 9fa50ef commit 7a8ea61
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 15 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ coverage.xml
*.out
*.pt
*.onnx
examples/feature_importance.png
examples/*.png
examples/*.pdf

!torch_geometric/data/
!test/data/
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.3.0] - 2023-MM-DD
### Added
- Added subgraph visualization of GNN explanations ([#6235](https://github.com/pyg-team/pytorch_geometric/pull/6235))
- Added weighted negative sampling option in `LinkNeighborLoader` ([#6264](https://github.com/pyg-team/pytorch_geometric/pull/6264))
- Added the `BA2MotifDataset` explainer dataset ([#6257](https://github.com/pyg-team/pytorch_geometric/pull/6257))
- Added `CycleMotif` motif generator to generate `n`-node cycle shaped motifs ([#6256](https://github.com/pyg-team/pytorch_geometric/pull/6256))
Expand Down
8 changes: 5 additions & 3 deletions examples/gnn_explainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os.path as osp

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -57,6 +56,9 @@ def forward(self, x, edge_index):
print(f'Generated explanations in {explanation.available_explanations}')

path = 'feature_importance.png'
ax = explanation.visualize_feature_importance()
plt.savefig(path)
explanation.visualize_feature_importance(path, top_k=10)
print(f"Feature importance plot has been saved to '{path}'")

path = 'subgraph.pdf'
explanation.visualize_graph(path)
print(f"Subgraph visualization plot has been saved to '{path}'")
14 changes: 10 additions & 4 deletions test/explain/test_explanation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import os.path
import random
import sys

import pytest
import torch

Expand Down Expand Up @@ -155,10 +159,12 @@ def test_visualize_feature_importance(data, top_k, node_feat_mask):
node_feat_mask=node_feat_mask,
)

path = os.path.join('/', 'tmp', f'{random.randrange(sys.maxsize)}.png')

if not node_feat_mask:
with pytest.raises(ValueError, match="node_feat_mask' is not"):
explanation.visualize_feature_importance(top_k=top_k)
explanation.visualize_feature_importance(path, top_k=top_k)
else:
ax = explanation.visualize_feature_importance(top_k=top_k)
num_feats_plotted = top_k if top_k is not None else data.num_features
assert len(ax.yaxis.get_ticklabels()) == num_feats_plotted
explanation.visualize_feature_importance(path, top_k=top_k)
assert os.path.exists(path)
os.remove(path)
43 changes: 37 additions & 6 deletions torch_geometric/explain/explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch_geometric.data.hetero_data import HeteroData
from torch_geometric.explain.config import ThresholdConfig, ThresholdType
from torch_geometric.typing import EdgeType, NodeType
from torch_geometric.visualization import visualize_graph


class ExplanationMixin:
Expand Down Expand Up @@ -195,18 +196,21 @@ def _apply_masks(

def visualize_feature_importance(
self,
path: Optional[str] = None,
feat_labels: Optional[List[str]] = None,
top_k: Optional[int] = None,
):
r"""Creates a bar plot of the node features importance by summing up
:attr:`self.node_feat_mask` across all nodes.
Args:
path (str, optional): The path to where the plot is saved.
If set to :obj:`None`, will visualize the plot on-the-fly.
(default: :obj:`None`)
feat_labels (List[str], optional): Optional labels for features.
(default :obj:`None`)
top_k (int, optional): Top k features to plot. If :obj:`None`
plots all features. (default: :obj:`None`)
:rtype: :class:`matplotlib.axes.Axes`
"""
import matplotlib.pyplot as plt
import pandas as pd
Expand All @@ -229,21 +233,48 @@ def visualize_feature_importance(
df = pd.DataFrame({'feat_importance': feat_importance},
index=feat_labels)
df = df.sort_values("feat_importance", ascending=False)
df = df.head(top_k) if top_k is not None else df
df = df.round(decimals=3)

if top_k is not None:
df = df.head(top_k)
title = f"Feature importance for top {len(df)} features"
else:
title = f"Feature importance for {len(df)} features"

ax = df.plot(
kind='barh',
figsize=(10, 7),
title=f"Feature importance for top {len(df)} features",
xlabel='Feature importance',
ylabel='Feature label',
title=title,
xlabel='Feature label',
xlim=[0, float(feat_importance.max()) + 0.3],
legend=False,
)
plt.gca().invert_yaxis()
ax.bar_label(container=ax.containers[0], label_type='edge')

return ax
if path is not None:
plt.savefig(path)
else:
plt.show()

plt.close()

def visualize_graph(self, path: Optional[str] = None,
backend: Optional[str] = None):
r"""Visualizes the explanation graph with edge opacity corresponding to
edge importance.
Args:
path (str, optional): The path to where the plot is saved.
If set to :obj:`None`, will visualize the plot on-the-fly.
(default: :obj:`None`)
backend (str, optional): The graph drawing backend to use for
visualization (:obj:`"graphviz"`, :obj:`"networkx"`).
If set to :obj:`None`, will use the most appropriate
visualization backend based on available system packages.
(default: :obj:`None`)
"""
visualize_graph(self.edge_index, self.edge_mask, path, backend)


class HeteroExplanation(HeteroData, ExplanationMixin):
Expand Down
6 changes: 5 additions & 1 deletion torch_geometric/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from .graph import visualize_graph
from .influence import influence

__all__ = ['influence']
__all__ = [
'visualize_graph',
'influence',
]
140 changes: 140 additions & 0 deletions torch_geometric/visualization/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from math import sqrt
from typing import Any, Optional

import torch
from torch import Tensor

BACKENDS = {'graphviz', 'networkx'}


def has_graphviz() -> bool:
try:
import graphviz
except ImportError:
return False

try:
graphviz.Digraph().pipe()
except graphviz.backend.ExecutableNotFound:
return False

return True


def visualize_graph(
edge_index: Tensor,
edge_weight: Optional[Tensor] = None,
path: Optional[str] = None,
backend: Optional[str] = None,
) -> Any:
r"""Visualizes the graph given via :obj:`edge_index` and (optional)
:obj:`edge_weight`.
Args:
edge_index (torch.Tensor): The edge indices.
edge_weight (torch.Tensor, optional): The edge weights.
path (str, optional): The path to where the plot is saved.
If set to :obj:`None`, will visualize the plot on-the-fly.
(default: :obj:`None`)
backend (str, optional): The graph drawing backend to use for
visualization (:obj:`"graphviz"`, :obj:`"networkx"`).
If set to :obj:`None`, will use the most appropriate
visualization backend based on available system packages.
(default: :obj:`None`)
"""
if edge_weight is not None: # Normalize edge weights.
edge_weight = edge_weight - edge_weight.min()
edge_weight = edge_weight / edge_weight.max()

if edge_weight is not None: # Discard any edges with zero edge weight:
mask = edge_weight > 1e-7
edge_index = edge_index[:, mask]
edge_weight = edge_weight[mask]

if edge_weight is None:
edge_weight = torch.ones(edge_index.size(1))

if backend is None:
backend = 'graphviz' if has_graphviz() else 'networkx'

if backend.lower() == 'networkx':
return _visualize_graph_via_networkx(edge_index, edge_weight, path)
elif backend.lower() == 'graphviz':
return _visualize_graph_via_graphviz(edge_index, edge_weight, path)

raise ValueError(f"Expected graph drawing backend to be in "
f"{BACKENDS} (got '{backend}')")


def _visualize_graph_via_graphviz(
edge_index: Tensor,
edge_weight: Tensor,
path: Optional[str] = None,
) -> Any:
import graphviz

suffix = path.split('.')[-1] if path is not None else None
g = graphviz.Digraph('graph', format=suffix)
g.attr('node', shape='circle', fontsize='11pt')

for node in edge_index.view(-1).unique().tolist():
g.node(str(node))

for (src, dst), w in zip(edge_index.t().tolist(), edge_weight.tolist()):
hex_color = hex(255 - round(255 * w))[2:]
hex_color = f'{hex_color}0' if len(hex_color) == 1 else hex_color
g.edge(str(src), str(dst), color=f'#{hex_color}{hex_color}{hex_color}')

if path is not None:
path = '.'.join(path.split('.')[:-1])
g.render(path, cleanup=True)
else:
g.view()

return g


def _visualize_graph_via_networkx(
edge_index: Tensor,
edge_weight: Tensor,
path: Optional[str] = None,
) -> Any:
import matplotlib.pyplot as plt
import networkx as nx

g = nx.DiGraph()
node_size = 800

for node in edge_index.view(-1).unique().tolist():
g.add_node(node)

for (src, dst), w in zip(edge_index.t().tolist(), edge_weight.tolist()):
g.add_edge(src, dst, alpha=w)

ax = plt.gca()
pos = nx.spring_layout(g)
for src, dst, data in g.edges(data=True):
ax.annotate(
'',
xy=pos[src],
xytext=pos[dst],
arrowprops=dict(
arrowstyle="->",
alpha=data['alpha'],
shrinkA=sqrt(node_size) / 2.0,
shrinkB=sqrt(node_size) / 2.0,
connectionstyle="arc3,rad=0.1",
),
)

nodes = nx.draw_networkx_nodes(g, pos, node_size=node_size,
node_color='white', margins=0.1)
nodes.set_edgecolor('black')
nx.draw_networkx_labels(g, pos, font_size=10)

if path is not None:
plt.savefig(path)
else:
plt.show()

plt.close()

0 comments on commit 7a8ea61

Please sign in to comment.