Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some usability updates to Explanation #6054

Merged
merged 10 commits into from
Nov 28, 2022
Merged
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added missing test labels in `HGBDataset` ([#5233](https://github.com/pyg-team/pytorch_geometric/pull/5233))
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
- Added `torch_geometric.explain` module with base functionality for explainability methods ([#5804](https://github.com/pyg-team/pytorch_geometric/pull/5804))
- Added `torch_geometric.explain` module with base functionality for explainability methods ([#5804](https://github.com/pyg-team/pytorch_geometric/pull/5804), [#6054](https://github.com/pyg-team/pytorch_geometric/pull/6054))
### Changed
- Moved and adapted `GNNExplainer` from `torch_geometric.nn` to `torch_geometric.explain.algorithm` ([#5967](https://github.com/pyg-team/pytorch_geometric/pull/5967))
- Optimized scatter implementations for CPU/GPU, both with and without backward computation ([#6051](https://github.com/pyg-team/pytorch_geometric/pull/6051), [#6052](https://github.com/pyg-team/pytorch_geometric/pull/6052))
Expand Down
145 changes: 145 additions & 0 deletions test/explain/test_explanation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import pytest
import torch

from torch_geometric.data import Data
from torch_geometric.explain import Explanation


@pytest.fixture
def data():
return Data(
x=torch.randn(4, 3),
edge_index=torch.tensor([
[0, 0, 0, 1, 1, 2],
[1, 2, 3, 2, 3, 3],
]),
edge_attr=torch.randn(6, 3),
)


def create_random_explanation(
data: Data,
node_mask: bool = True,
edge_mask: bool = True,
node_feat_mask: bool = True,
edge_feat_mask: bool = True,
):
node_mask = torch.rand(data.x.size(0)) if node_mask else None
edge_mask = torch.rand(data.edge_index.size(1)) if edge_mask else None
node_feat_mask = torch.rand_like(data.x) if node_feat_mask else None
edge_feat_mask = (torch.rand_like(data.edge_attr)
if edge_feat_mask else None)

return Explanation( # Create explanation.
node_mask=node_mask,
edge_mask=edge_mask,
node_feat_mask=node_feat_mask,
edge_feat_mask=edge_feat_mask,
x=data.x,
edge_index=data.edge_index,
edge_attr=data.edge_attr,
)


@pytest.mark.parametrize('node_mask', [True, False])
@pytest.mark.parametrize('edge_mask', [True, False])
@pytest.mark.parametrize('node_feat_mask', [True, False])
@pytest.mark.parametrize('edge_feat_mask', [True, False])
def test_available_explanations(data, node_mask, edge_mask, node_feat_mask,
edge_feat_mask):
expected = []
if node_mask:
expected.append('node_mask')
if edge_mask:
expected.append('edge_mask')
if node_feat_mask:
expected.append('node_feat_mask')
if edge_feat_mask:
expected.append('edge_feat_mask')

explanation = create_random_explanation(
data,
node_mask=node_mask,
edge_mask=edge_mask,
node_feat_mask=node_feat_mask,
edge_feat_mask=edge_feat_mask,
)

assert set(explanation.available_explanations) == set(expected)


def test_validate_explanation(data):
explanation = create_random_explanation(data)
explanation.validate(raise_on_error=True)

with pytest.raises(ValueError, match="with 5 nodes"):
explanation = create_random_explanation(data)
explanation.x = torch.randn(5, 5)
explanation.validate(raise_on_error=True)

with pytest.raises(ValueError, match=r"of shape \[4, 4\]"):
explanation = create_random_explanation(data)
explanation.x = torch.randn(4, 4)
explanation.validate(raise_on_error=True)

with pytest.raises(ValueError, match="with 7 edges"):
explanation = create_random_explanation(data)
explanation.edge_index = torch.randint(0, 4, (2, 7))
explanation.validate(raise_on_error=True)

with pytest.raises(ValueError, match=r"of shape \[6, 4\]"):
explanation = create_random_explanation(data)
explanation.edge_attr = torch.randn(6, 4)
explanation.validate(raise_on_error=True)


def test_node_mask(data):
node_mask = torch.tensor([1.0, 0.0, 1.0, 0.0])

explanation = Explanation(
node_mask=node_mask,
x=data.x,
edge_index=data.edge_index,
edge_attr=data.edge_attr,
)
explanation.validate(raise_on_error=True)

out = explanation.get_explanation_subgraph()
assert out.node_mask.size() == (2, )
assert (out.node_mask > 0.0).sum() == 2
assert out.x.size() == (2, 3)
assert out.edge_index.size() == (2, 1)
assert out.edge_attr.size() == (1, 3)

out = explanation.get_complement_subgraph()
assert out.node_mask.size() == (2, )
assert (out.node_mask == 0.0).sum() == 2
assert out.x.size() == (2, 3)
assert out.edge_index.size() == (2, 1)
assert out.edge_attr.size() == (1, 3)


def test_edge_mask(data):
edge_mask = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.0, 1.0])

explanation = Explanation(
edge_mask=edge_mask,
x=data.x,
edge_index=data.edge_index,
edge_attr=data.edge_attr,
)
explanation.validate(raise_on_error=True)

out = explanation.get_explanation_subgraph()
assert out.x.size() == (4, 3)
assert out.edge_mask.size() == (3, )
assert (out.edge_mask > 0.0).sum() == 3
assert out.edge_index.size() == (2, 3)
assert out.edge_attr.size() == (3, 3)

out = explanation.get_complement_subgraph()
assert out.x.size() == (4, 3)
assert out.edge_mask.size() == (3, )
assert (out.edge_mask == 0.0).sum() == 3
assert out.edge_index.size() == (2, 3)
assert out.edge_attr.size() == (3, 3)
62 changes: 0 additions & 62 deletions test/explain/test_explanations.py

This file was deleted.

2 changes: 1 addition & 1 deletion torch_geometric/explain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .config import ExplainerConfig, ModelConfig, ThresholdConfig
from .explanations import Explanation
from .explanation import Explanation
from .algorithm import * # noqa
from .explainer import Explainer

Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/explain/algorithm/gnn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch import Tensor
from torch.nn.parameter import Parameter

from torch_geometric.explain import Explanation
from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
from torch_geometric.explain.config import (
ExplainerConfig,
Expand All @@ -15,7 +16,6 @@
ModelReturnType,
ModelTaskLevel,
)
from torch_geometric.explain.explanations import Explanation

from .base import ExplainerAlgorithm

Expand Down
131 changes: 131 additions & 0 deletions torch_geometric/explain/explanation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import copy
from typing import List, Optional

from torch import Tensor

from torch_geometric.data.data import Data, warn_or_raise


class Explanation(Data):
r"""Holds all the obtained explanations of a homogenous graph.

The explanation object is a :obj:`~torch_geometric.data.Data` object and
can hold node-attribution, edge-attribution, feature-attribution. It can
also hold the original graph if needed.

Args:
node_mask (Tensor, optional): Node-level mask with shape
:obj:`[num_nodes]`. (default: :obj:`None`)
edge_mask (Tensor, optional): Edge-level mask with shape
:obj:`[num_edges]`. (default: :obj:`None`)
node_feat_mask (Tensor, optional): Node-level feature mask with shape
:obj:`[num_nodes, num_node_features]`. (default: :obj:`None`)
edge_feat_mask (Tensor, optional): Edge-level feature mask with shape
:obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
**kwargs (optional): Additional attributes.
"""
def __init__(
self,
node_mask: Optional[Tensor] = None,
edge_mask: Optional[Tensor] = None,
node_feat_mask: Optional[Tensor] = None,
edge_feat_mask: Optional[Tensor] = None,
**kwargs,
):
super().__init__(
node_mask=node_mask,
edge_mask=edge_mask,
node_feat_mask=node_feat_mask,
edge_feat_mask=edge_feat_mask,
**kwargs,
)

@property
def available_explanations(self) -> List[str]:
"""Returns the available explanation masks."""
return [
key for key in self.keys
if key.endswith('_mask') and self[key] is not None
]

def validate(self, raise_on_error: bool = True) -> bool:
r"""Validates the correctness of the explanation"""
status = super().validate()

if 'node_mask' in self and self.num_nodes != self.node_mask.size(0):
status = False
warn_or_raise(
f"Expected a 'node_mask' with {self.num_nodes} nodes "
f"(got {self.node_mask.size(0)} nodes)", raise_on_error)

if 'edge_mask' in self and self.num_edges != self.edge_mask.size(0):
status = False
warn_or_raise(
f"Expected an 'edge_mask' with {self.num_edges} edges "
f"(got {self.edge_mask.size(0)} edges)", raise_on_error)

if 'node_feat_mask' in self:
if 'x' in self and self.x.size() != self.node_feat_mask.size():
status = False
warn_or_raise(
f"Expected a 'node_feat_mask' of shape "
f"{list(self.x.size())} (got shape "
f"{list(self.node_feat_mask.size())})", raise_on_error)
elif self.num_nodes != self.node_feat_mask.size(0):
status = False
warn_or_raise(
f"Expected a 'node_feat_mask' with {self.num_nodes} nodes "
f"(got {self.node_feat_mask.size(0)} nodes)",
raise_on_error)

if 'edge_feat_mask' in self:
if ('edge_attr' in self
and self.edge_attr.size() != self.edge_feat_mask.size()):
status = False
warn_or_raise(
f"Expected an 'edge_feat_mask' of shape "
f"{list(self.edge_attr.size())} (got shape "
f"{list(self.edge_feat_mask.size())})", raise_on_error)
elif self.num_edges != self.edge_feat_mask.size(0):
status = False
warn_or_raise(
f"Expected an 'edge_feat_mask' with {self.num_edges} "
f"edges (got {self.edge_feat_mask.size(0)} edges)",
raise_on_error)

return status

def get_explanation_subgraph(self) -> 'Explanation':
r"""Returns the induced subgraph, in which all nodes and edges with
zero attribution are masked out."""
return self._apply_masks(
node_mask=self.node_mask > 0 if 'node_mask' in self else None,
edge_mask=self.edge_mask > 0 if 'edge_mask' in self else None,
)

def get_complement_subgraph(self) -> 'Explanation':
r"""Returns the induced subgraph, in which all nodes and edges with any
attribution are masked out."""
return self._apply_masks(
node_mask=self.node_mask == 0 if 'node_mask' in self else None,
edge_mask=self.edge_mask == 0 if 'edge_mask' in self else None,
)

def _apply_masks(
self,
node_mask: Optional[Tensor] = None,
edge_mask: Optional[Tensor] = None,
) -> 'Explanation':
out = copy.copy(self)

if edge_mask is not None:
for key, value in self.items():
if key == 'edge_index':
out.edge_index = value[:, edge_mask]
elif self.is_edge_attr(key):
out[key] = value[edge_mask]

if node_mask is not None:
out = out.subgraph(node_mask)

return out
Loading