Skip to content

Commit

Permalink
Some usability updates to Explanation (#6054)
Browse files Browse the repository at this point in the history
### Motivation
When evaluating and visualizing explanations it will be very convenient
to obtain the explanation and explanation complement subgraphs directly
from the `Explanation` objects. Given that the class holds information
of the original graph being explained this shouldn’t be difficult to
achieve. I propose we add two methods to the `Explanation` class:

1. `get_explanation_subgraph` returns the explanation subgraph $G_S$,
2. `get_complement_subgraph` returns the complement of the explanation
subgraph $G_{C \backslash S}$

I’ve submitted this draft PR to get opinions on this, both
implementation-wise and if we should do this in the first place.

### Alternatives
1. Delegate this to other parts of code, I.e. whenewer this is needed
(e.g. when evaluation explanations or visualizing them)
2. Implement the logic in `util.py` for general Data object, e.g. a
`mask_graph` method, which can then be reused here or elsewhere.

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
BlazStojanovic and rusty1s authored Nov 28, 2022
1 parent 656fed9 commit 0fdf935
Show file tree
Hide file tree
Showing 7 changed files with 279 additions and 113 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added missing test labels in `HGBDataset` ([#5233](https://github.com/pyg-team/pytorch_geometric/pull/5233))
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
- Added `torch_geometric.explain` module with base functionality for explainability methods ([#5804](https://github.com/pyg-team/pytorch_geometric/pull/5804))
- Added `torch_geometric.explain` module with base functionality for explainability methods ([#5804](https://github.com/pyg-team/pytorch_geometric/pull/5804), [#6054](https://github.com/pyg-team/pytorch_geometric/pull/6054))
### Changed
- Moved and adapted `GNNExplainer` from `torch_geometric.nn` to `torch_geometric.explain.algorithm` ([#5967](https://github.com/pyg-team/pytorch_geometric/pull/5967))
- Optimized scatter implementations for CPU/GPU, both with and without backward computation ([#6051](https://github.com/pyg-team/pytorch_geometric/pull/6051), [#6052](https://github.com/pyg-team/pytorch_geometric/pull/6052))
Expand Down
145 changes: 145 additions & 0 deletions test/explain/test_explanation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import pytest
import torch

from torch_geometric.data import Data
from torch_geometric.explain import Explanation


@pytest.fixture
def data():
return Data(
x=torch.randn(4, 3),
edge_index=torch.tensor([
[0, 0, 0, 1, 1, 2],
[1, 2, 3, 2, 3, 3],
]),
edge_attr=torch.randn(6, 3),
)


def create_random_explanation(
data: Data,
node_mask: bool = True,
edge_mask: bool = True,
node_feat_mask: bool = True,
edge_feat_mask: bool = True,
):
node_mask = torch.rand(data.x.size(0)) if node_mask else None
edge_mask = torch.rand(data.edge_index.size(1)) if edge_mask else None
node_feat_mask = torch.rand_like(data.x) if node_feat_mask else None
edge_feat_mask = (torch.rand_like(data.edge_attr)
if edge_feat_mask else None)

return Explanation( # Create explanation.
node_mask=node_mask,
edge_mask=edge_mask,
node_feat_mask=node_feat_mask,
edge_feat_mask=edge_feat_mask,
x=data.x,
edge_index=data.edge_index,
edge_attr=data.edge_attr,
)


@pytest.mark.parametrize('node_mask', [True, False])
@pytest.mark.parametrize('edge_mask', [True, False])
@pytest.mark.parametrize('node_feat_mask', [True, False])
@pytest.mark.parametrize('edge_feat_mask', [True, False])
def test_available_explanations(data, node_mask, edge_mask, node_feat_mask,
edge_feat_mask):
expected = []
if node_mask:
expected.append('node_mask')
if edge_mask:
expected.append('edge_mask')
if node_feat_mask:
expected.append('node_feat_mask')
if edge_feat_mask:
expected.append('edge_feat_mask')

explanation = create_random_explanation(
data,
node_mask=node_mask,
edge_mask=edge_mask,
node_feat_mask=node_feat_mask,
edge_feat_mask=edge_feat_mask,
)

assert set(explanation.available_explanations) == set(expected)


def test_validate_explanation(data):
explanation = create_random_explanation(data)
explanation.validate(raise_on_error=True)

with pytest.raises(ValueError, match="with 5 nodes"):
explanation = create_random_explanation(data)
explanation.x = torch.randn(5, 5)
explanation.validate(raise_on_error=True)

with pytest.raises(ValueError, match=r"of shape \[4, 4\]"):
explanation = create_random_explanation(data)
explanation.x = torch.randn(4, 4)
explanation.validate(raise_on_error=True)

with pytest.raises(ValueError, match="with 7 edges"):
explanation = create_random_explanation(data)
explanation.edge_index = torch.randint(0, 4, (2, 7))
explanation.validate(raise_on_error=True)

with pytest.raises(ValueError, match=r"of shape \[6, 4\]"):
explanation = create_random_explanation(data)
explanation.edge_attr = torch.randn(6, 4)
explanation.validate(raise_on_error=True)


def test_node_mask(data):
node_mask = torch.tensor([1.0, 0.0, 1.0, 0.0])

explanation = Explanation(
node_mask=node_mask,
x=data.x,
edge_index=data.edge_index,
edge_attr=data.edge_attr,
)
explanation.validate(raise_on_error=True)

out = explanation.get_explanation_subgraph()
assert out.node_mask.size() == (2, )
assert (out.node_mask > 0.0).sum() == 2
assert out.x.size() == (2, 3)
assert out.edge_index.size() == (2, 1)
assert out.edge_attr.size() == (1, 3)

out = explanation.get_complement_subgraph()
assert out.node_mask.size() == (2, )
assert (out.node_mask == 0.0).sum() == 2
assert out.x.size() == (2, 3)
assert out.edge_index.size() == (2, 1)
assert out.edge_attr.size() == (1, 3)


def test_edge_mask(data):
edge_mask = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.0, 1.0])

explanation = Explanation(
edge_mask=edge_mask,
x=data.x,
edge_index=data.edge_index,
edge_attr=data.edge_attr,
)
explanation.validate(raise_on_error=True)

out = explanation.get_explanation_subgraph()
assert out.x.size() == (4, 3)
assert out.edge_mask.size() == (3, )
assert (out.edge_mask > 0.0).sum() == 3
assert out.edge_index.size() == (2, 3)
assert out.edge_attr.size() == (3, 3)

out = explanation.get_complement_subgraph()
assert out.x.size() == (4, 3)
assert out.edge_mask.size() == (3, )
assert (out.edge_mask == 0.0).sum() == 3
assert out.edge_index.size() == (2, 3)
assert out.edge_attr.size() == (3, 3)
62 changes: 0 additions & 62 deletions test/explain/test_explanations.py

This file was deleted.

2 changes: 1 addition & 1 deletion torch_geometric/explain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .config import ExplainerConfig, ModelConfig, ThresholdConfig
from .explanations import Explanation
from .explanation import Explanation
from .algorithm import * # noqa
from .explainer import Explainer

Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/explain/algorithm/gnn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch import Tensor
from torch.nn.parameter import Parameter

from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import (
ExplainerConfig,
Expand All @@ -15,7 +16,6 @@
ModelReturnType,
ModelTaskLevel,
)
from torch_geometric.explain.explanations import Explanation

from .base import ExplainerAlgorithm

Expand Down
131 changes: 131 additions & 0 deletions torch_geometric/explain/explanation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import copy
from typing import List, Optional

from torch import Tensor

from torch_geometric.data.data import Data, warn_or_raise


class Explanation(Data):
r"""Holds all the obtained explanations of a homogenous graph.
The explanation object is a :obj:`~torch_geometric.data.Data` object and
can hold node-attribution, edge-attribution, feature-attribution. It can
also hold the original graph if needed.
Args:
node_mask (Tensor, optional): Node-level mask with shape
:obj:`[num_nodes]`. (default: :obj:`None`)
edge_mask (Tensor, optional): Edge-level mask with shape
:obj:`[num_edges]`. (default: :obj:`None`)
node_feat_mask (Tensor, optional): Node-level feature mask with shape
:obj:`[num_nodes, num_node_features]`. (default: :obj:`None`)
edge_feat_mask (Tensor, optional): Edge-level feature mask with shape
:obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
**kwargs (optional): Additional attributes.
"""
def __init__(
self,
node_mask: Optional[Tensor] = None,
edge_mask: Optional[Tensor] = None,
node_feat_mask: Optional[Tensor] = None,
edge_feat_mask: Optional[Tensor] = None,
**kwargs,
):
super().__init__(
node_mask=node_mask,
edge_mask=edge_mask,
node_feat_mask=node_feat_mask,
edge_feat_mask=edge_feat_mask,
**kwargs,
)

@property
def available_explanations(self) -> List[str]:
"""Returns the available explanation masks."""
return [
key for key in self.keys
if key.endswith('_mask') and self[key] is not None
]

def validate(self, raise_on_error: bool = True) -> bool:
r"""Validates the correctness of the explanation"""
status = super().validate()

if 'node_mask' in self and self.num_nodes != self.node_mask.size(0):
status = False
warn_or_raise(
f"Expected a 'node_mask' with {self.num_nodes} nodes "
f"(got {self.node_mask.size(0)} nodes)", raise_on_error)

if 'edge_mask' in self and self.num_edges != self.edge_mask.size(0):
status = False
warn_or_raise(
f"Expected an 'edge_mask' with {self.num_edges} edges "
f"(got {self.edge_mask.size(0)} edges)", raise_on_error)

if 'node_feat_mask' in self:
if 'x' in self and self.x.size() != self.node_feat_mask.size():
status = False
warn_or_raise(
f"Expected a 'node_feat_mask' of shape "
f"{list(self.x.size())} (got shape "
f"{list(self.node_feat_mask.size())})", raise_on_error)
elif self.num_nodes != self.node_feat_mask.size(0):
status = False
warn_or_raise(
f"Expected a 'node_feat_mask' with {self.num_nodes} nodes "
f"(got {self.node_feat_mask.size(0)} nodes)",
raise_on_error)

if 'edge_feat_mask' in self:
if ('edge_attr' in self
and self.edge_attr.size() != self.edge_feat_mask.size()):
status = False
warn_or_raise(
f"Expected an 'edge_feat_mask' of shape "
f"{list(self.edge_attr.size())} (got shape "
f"{list(self.edge_feat_mask.size())})", raise_on_error)
elif self.num_edges != self.edge_feat_mask.size(0):
status = False
warn_or_raise(
f"Expected an 'edge_feat_mask' with {self.num_edges} "
f"edges (got {self.edge_feat_mask.size(0)} edges)",
raise_on_error)

return status

def get_explanation_subgraph(self) -> 'Explanation':
r"""Returns the induced subgraph, in which all nodes and edges with
zero attribution are masked out."""
return self._apply_masks(
node_mask=self.node_mask > 0 if 'node_mask' in self else None,
edge_mask=self.edge_mask > 0 if 'edge_mask' in self else None,
)

def get_complement_subgraph(self) -> 'Explanation':
r"""Returns the induced subgraph, in which all nodes and edges with any
attribution are masked out."""
return self._apply_masks(
node_mask=self.node_mask == 0 if 'node_mask' in self else None,
edge_mask=self.edge_mask == 0 if 'edge_mask' in self else None,
)

def _apply_masks(
self,
node_mask: Optional[Tensor] = None,
edge_mask: Optional[Tensor] = None,
) -> 'Explanation':
out = copy.copy(self)

if edge_mask is not None:
for key, value in self.items():
if key == 'edge_index':
out.edge_index = value[:, edge_mask]
elif self.is_edge_attr(key):
out[key] = value[edge_mask]

if node_mask is not None:
out = out.subgraph(node_mask)

return out
Loading

0 comments on commit 0fdf935

Please sign in to comment.