From 0fdf9358b308baa2144765382e53c15e232fdced Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bla=C5=BE=20Stojanovi=C4=8D?= Date: Mon, 28 Nov 2022 14:38:13 +0000 Subject: [PATCH] Some usability updates to `Explanation` (#6054) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Motivation When evaluating and visualizing explanations it will be very convenient to obtain the explanation and explanation complement subgraphs directly from the `Explanation` objects. Given that the class holds information of the original graph being explained this shouldn’t be difficult to achieve. I propose we add two methods to the `Explanation` class: 1. `get_explanation_subgraph` returns the explanation subgraph $G_S$, 2. `get_complement_subgraph` returns the complement of the explanation subgraph $G_{C \backslash S}$ I’ve submitted this draft PR to get opinions on this, both implementation-wise and if we should do this in the first place. ### Alternatives 1. Delegate this to other parts of code, i.e. whenever this is needed (e.g. when evaluating explanations or visualizing them) 2. Implement the logic in `util.py` for a general `Data` object, e.g. a `mask_graph` method, which can then be reused here or elsewhere. 
Co-authored-by: rusty1s --- CHANGELOG.md | 2 +- test/explain/test_explanation.py | 145 ++++++++++++++++++ test/explain/test_explanations.py | 62 -------- torch_geometric/explain/__init__.py | 2 +- .../explain/algorithm/gnn_explainer.py | 2 +- torch_geometric/explain/explanation.py | 131 ++++++++++++++++ torch_geometric/explain/explanations.py | 48 ------ 7 files changed, 279 insertions(+), 113 deletions(-) create mode 100644 test/explain/test_explanation.py delete mode 100644 test/explain/test_explanations.py create mode 100644 torch_geometric/explain/explanation.py delete mode 100644 torch_geometric/explain/explanations.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a28155f991fb..89a2c195e65c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -71,7 +71,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added missing test labels in `HGBDataset` ([#5233](https://github.com/pyg-team/pytorch_geometric/pull/5233)) - Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240)) - Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222)) -- Added `torch_geometric.explain` module with base functionality for explainability methods ([#5804](https://github.com/pyg-team/pytorch_geometric/pull/5804)) +- Added `torch_geometric.explain` module with base functionality for explainability methods ([#5804](https://github.com/pyg-team/pytorch_geometric/pull/5804), [#6054](https://github.com/pyg-team/pytorch_geometric/pull/6054)) ### Changed - Moved and adapted `GNNExplainer` from `torch_geometric.nn` to `torch_geometric.explain.algorithm` ([#5967](https://github.com/pyg-team/pytorch_geometric/pull/5967)) - Optimized scatter implementations for CPU/GPU, both with and without backward computation ([#6051](https://github.com/pyg-team/pytorch_geometric/pull/6051), [#6052](https://github.com/pyg-team/pytorch_geometric/pull/6052)) diff 
--git a/test/explain/test_explanation.py b/test/explain/test_explanation.py new file mode 100644 index 000000000000..a482cbeb2ae4 --- /dev/null +++ b/test/explain/test_explanation.py @@ -0,0 +1,145 @@ +import pytest +import torch + +from torch_geometric.data import Data +from torch_geometric.explain import Explanation + + +@pytest.fixture +def data(): + return Data( + x=torch.randn(4, 3), + edge_index=torch.tensor([ + [0, 0, 0, 1, 1, 2], + [1, 2, 3, 2, 3, 3], + ]), + edge_attr=torch.randn(6, 3), + ) + + +def create_random_explanation( + data: Data, + node_mask: bool = True, + edge_mask: bool = True, + node_feat_mask: bool = True, + edge_feat_mask: bool = True, +): + node_mask = torch.rand(data.x.size(0)) if node_mask else None + edge_mask = torch.rand(data.edge_index.size(1)) if edge_mask else None + node_feat_mask = torch.rand_like(data.x) if node_feat_mask else None + edge_feat_mask = (torch.rand_like(data.edge_attr) + if edge_feat_mask else None) + + return Explanation( # Create explanation. 
+ node_mask=node_mask, + edge_mask=edge_mask, + node_feat_mask=node_feat_mask, + edge_feat_mask=edge_feat_mask, + x=data.x, + edge_index=data.edge_index, + edge_attr=data.edge_attr, + ) + + +@pytest.mark.parametrize('node_mask', [True, False]) +@pytest.mark.parametrize('edge_mask', [True, False]) +@pytest.mark.parametrize('node_feat_mask', [True, False]) +@pytest.mark.parametrize('edge_feat_mask', [True, False]) +def test_available_explanations(data, node_mask, edge_mask, node_feat_mask, + edge_feat_mask): + expected = [] + if node_mask: + expected.append('node_mask') + if edge_mask: + expected.append('edge_mask') + if node_feat_mask: + expected.append('node_feat_mask') + if edge_feat_mask: + expected.append('edge_feat_mask') + + explanation = create_random_explanation( + data, + node_mask=node_mask, + edge_mask=edge_mask, + node_feat_mask=node_feat_mask, + edge_feat_mask=edge_feat_mask, + ) + + assert set(explanation.available_explanations) == set(expected) + + +def test_validate_explanation(data): + explanation = create_random_explanation(data) + explanation.validate(raise_on_error=True) + + with pytest.raises(ValueError, match="with 5 nodes"): + explanation = create_random_explanation(data) + explanation.x = torch.randn(5, 5) + explanation.validate(raise_on_error=True) + + with pytest.raises(ValueError, match=r"of shape \[4, 4\]"): + explanation = create_random_explanation(data) + explanation.x = torch.randn(4, 4) + explanation.validate(raise_on_error=True) + + with pytest.raises(ValueError, match="with 7 edges"): + explanation = create_random_explanation(data) + explanation.edge_index = torch.randint(0, 4, (2, 7)) + explanation.validate(raise_on_error=True) + + with pytest.raises(ValueError, match=r"of shape \[6, 4\]"): + explanation = create_random_explanation(data) + explanation.edge_attr = torch.randn(6, 4) + explanation.validate(raise_on_error=True) + + +def test_node_mask(data): + node_mask = torch.tensor([1.0, 0.0, 1.0, 0.0]) + + explanation = 
Explanation( + node_mask=node_mask, + x=data.x, + edge_index=data.edge_index, + edge_attr=data.edge_attr, + ) + explanation.validate(raise_on_error=True) + + out = explanation.get_explanation_subgraph() + assert out.node_mask.size() == (2, ) + assert (out.node_mask > 0.0).sum() == 2 + assert out.x.size() == (2, 3) + assert out.edge_index.size() == (2, 1) + assert out.edge_attr.size() == (1, 3) + + out = explanation.get_complement_subgraph() + assert out.node_mask.size() == (2, ) + assert (out.node_mask == 0.0).sum() == 2 + assert out.x.size() == (2, 3) + assert out.edge_index.size() == (2, 1) + assert out.edge_attr.size() == (1, 3) + + +def test_edge_mask(data): + edge_mask = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.0, 1.0]) + + explanation = Explanation( + edge_mask=edge_mask, + x=data.x, + edge_index=data.edge_index, + edge_attr=data.edge_attr, + ) + explanation.validate(raise_on_error=True) + + out = explanation.get_explanation_subgraph() + assert out.x.size() == (4, 3) + assert out.edge_mask.size() == (3, ) + assert (out.edge_mask > 0.0).sum() == 3 + assert out.edge_index.size() == (2, 3) + assert out.edge_attr.size() == (3, 3) + + out = explanation.get_complement_subgraph() + assert out.x.size() == (4, 3) + assert out.edge_mask.size() == (3, ) + assert (out.edge_mask == 0.0).sum() == 3 + assert out.edge_index.size() == (2, 3) + assert out.edge_attr.size() == (3, 3) diff --git a/test/explain/test_explanations.py b/test/explain/test_explanations.py deleted file mode 100644 index 49ccdceb67db..000000000000 --- a/test/explain/test_explanations.py +++ /dev/null @@ -1,62 +0,0 @@ -import pytest -import torch - -from torch_geometric.data import Data -from torch_geometric.explain import Explanation - - -@pytest.fixture -def data(): - return Data( - x=torch.randn(10, 5), - edge_index=torch.randint(0, 10, (2, 20)), - edge_attr=torch.randn(20, 3), - ) - - -def create_random_explanation( - data: Data, - node_mask: bool = True, - edge_mask: bool = True, - node_feat_mask: bool = 
True, - edge_feat_mask: bool = True, -): - node_mask = torch.rand(data.x.size(0)) if node_mask else None - edge_mask = torch.rand(data.edge_index.size(1)) if edge_mask else None - node_feat_mask = torch.rand_like(data.x) if node_feat_mask else None - edge_feat_mask = (torch.rand_like(data.edge_attr) - if edge_feat_mask else None) - - return Explanation( # Create explanation. - node_mask=node_mask, - edge_mask=edge_mask, - node_feat_mask=node_feat_mask, - edge_feat_mask=edge_feat_mask, - ) - - -@pytest.mark.parametrize('node_mask', [True, False]) -@pytest.mark.parametrize('edge_mask', [True, False]) -@pytest.mark.parametrize('node_feat_mask', [True, False]) -@pytest.mark.parametrize('edge_feat_mask', [True, False]) -def test_available_explanations(data, node_mask, edge_mask, node_feat_mask, - edge_feat_mask): - expected = [] - if node_mask: - expected.append('node_mask') - if edge_mask: - expected.append('edge_mask') - if node_feat_mask: - expected.append('node_feat_mask') - if edge_feat_mask: - expected.append('edge_feat_mask') - - explanation = create_random_explanation( - data, - node_mask=node_mask, - edge_mask=edge_mask, - node_feat_mask=node_feat_mask, - edge_feat_mask=edge_feat_mask, - ) - - assert set(explanation.available_explanations) == set(expected) diff --git a/torch_geometric/explain/__init__.py b/torch_geometric/explain/__init__.py index 5075b48b8fe1..5e7b420664e8 100644 --- a/torch_geometric/explain/__init__.py +++ b/torch_geometric/explain/__init__.py @@ -1,5 +1,5 @@ from .config import ExplainerConfig, ModelConfig, ThresholdConfig -from .explanations import Explanation +from .explanation import Explanation from .algorithm import * # noqa from .explainer import Explainer diff --git a/torch_geometric/explain/algorithm/gnn_explainer.py b/torch_geometric/explain/algorithm/gnn_explainer.py index f819bbbba0da..e820bc746706 100644 --- a/torch_geometric/explain/algorithm/gnn_explainer.py +++ b/torch_geometric/explain/algorithm/gnn_explainer.py @@ -6,6 +6,7 
@@ from torch import Tensor from torch.nn.parameter import Parameter +from torch_geometric.explain import Explanation from torch_geometric.explain.algorithm.utils import clear_masks, set_masks from torch_geometric.explain.config import ( ExplainerConfig, @@ -15,7 +16,6 @@ ModelReturnType, ModelTaskLevel, ) -from torch_geometric.explain.explanations import Explanation from .base import ExplainerAlgorithm diff --git a/torch_geometric/explain/explanation.py b/torch_geometric/explain/explanation.py new file mode 100644 index 000000000000..f885bedfbb3b --- /dev/null +++ b/torch_geometric/explain/explanation.py @@ -0,0 +1,131 @@ +import copy +from typing import List, Optional + +from torch import Tensor + +from torch_geometric.data.data import Data, warn_or_raise + + +class Explanation(Data): + r"""Holds all the obtained explanations of a homogenous graph. + + The explanation object is a :obj:`~torch_geometric.data.Data` object and + can hold node-attribution, edge-attribution, feature-attribution. It can + also hold the original graph if needed. + + Args: + node_mask (Tensor, optional): Node-level mask with shape + :obj:`[num_nodes]`. (default: :obj:`None`) + edge_mask (Tensor, optional): Edge-level mask with shape + :obj:`[num_edges]`. (default: :obj:`None`) + node_feat_mask (Tensor, optional): Node-level feature mask with shape + :obj:`[num_nodes, num_node_features]`. (default: :obj:`None`) + edge_feat_mask (Tensor, optional): Edge-level feature mask with shape + :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) + **kwargs (optional): Additional attributes. 
+ """ + def __init__( + self, + node_mask: Optional[Tensor] = None, + edge_mask: Optional[Tensor] = None, + node_feat_mask: Optional[Tensor] = None, + edge_feat_mask: Optional[Tensor] = None, + **kwargs, + ): + super().__init__( + node_mask=node_mask, + edge_mask=edge_mask, + node_feat_mask=node_feat_mask, + edge_feat_mask=edge_feat_mask, + **kwargs, + ) + + @property + def available_explanations(self) -> List[str]: + """Returns the available explanation masks.""" + return [ + key for key in self.keys + if key.endswith('_mask') and self[key] is not None + ] + + def validate(self, raise_on_error: bool = True) -> bool: + r"""Validates the correctness of the explanation""" + status = super().validate() + + if 'node_mask' in self and self.num_nodes != self.node_mask.size(0): + status = False + warn_or_raise( + f"Expected a 'node_mask' with {self.num_nodes} nodes " + f"(got {self.node_mask.size(0)} nodes)", raise_on_error) + + if 'edge_mask' in self and self.num_edges != self.edge_mask.size(0): + status = False + warn_or_raise( + f"Expected an 'edge_mask' with {self.num_edges} edges " + f"(got {self.edge_mask.size(0)} edges)", raise_on_error) + + if 'node_feat_mask' in self: + if 'x' in self and self.x.size() != self.node_feat_mask.size(): + status = False + warn_or_raise( + f"Expected a 'node_feat_mask' of shape " + f"{list(self.x.size())} (got shape " + f"{list(self.node_feat_mask.size())})", raise_on_error) + elif self.num_nodes != self.node_feat_mask.size(0): + status = False + warn_or_raise( + f"Expected a 'node_feat_mask' with {self.num_nodes} nodes " + f"(got {self.node_feat_mask.size(0)} nodes)", + raise_on_error) + + if 'edge_feat_mask' in self: + if ('edge_attr' in self + and self.edge_attr.size() != self.edge_feat_mask.size()): + status = False + warn_or_raise( + f"Expected an 'edge_feat_mask' of shape " + f"{list(self.edge_attr.size())} (got shape " + f"{list(self.edge_feat_mask.size())})", raise_on_error) + elif self.num_edges != 
self.edge_feat_mask.size(0): + status = False + warn_or_raise( + f"Expected an 'edge_feat_mask' with {self.num_edges} " + f"edges (got {self.edge_feat_mask.size(0)} edges)", + raise_on_error) + + return status + + def get_explanation_subgraph(self) -> 'Explanation': + r"""Returns the induced subgraph, in which all nodes and edges with + zero attribution are masked out.""" + return self._apply_masks( + node_mask=self.node_mask > 0 if 'node_mask' in self else None, + edge_mask=self.edge_mask > 0 if 'edge_mask' in self else None, + ) + + def get_complement_subgraph(self) -> 'Explanation': + r"""Returns the induced subgraph, in which all nodes and edges with any + attribution are masked out.""" + return self._apply_masks( + node_mask=self.node_mask == 0 if 'node_mask' in self else None, + edge_mask=self.edge_mask == 0 if 'edge_mask' in self else None, + ) + + def _apply_masks( + self, + node_mask: Optional[Tensor] = None, + edge_mask: Optional[Tensor] = None, + ) -> 'Explanation': + out = copy.copy(self) + + if edge_mask is not None: + for key, value in self.items(): + if key == 'edge_index': + out.edge_index = value[:, edge_mask] + elif self.is_edge_attr(key): + out[key] = value[edge_mask] + + if node_mask is not None: + out = out.subgraph(node_mask) + + return out diff --git a/torch_geometric/explain/explanations.py b/torch_geometric/explain/explanations.py deleted file mode 100644 index 2eb24a64a9a9..000000000000 --- a/torch_geometric/explain/explanations.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import List, Optional - -from torch import Tensor - -from torch_geometric.data import Data - - -class Explanation(Data): - r"""Holds all the obtained explanations of a homogenous graph. - - The explanation object is a :obj:`~torch_geometric.data.Data` object and - can hold node-attribution, edge-attribution, feature-attribution. It can - also hold the original graph if needed. - - Args: - node_mask (Tensor, optional): Node-level mask with shape - :obj:`[num_nodes]`. 
(default: :obj:`None`) - edge_mask (Tensor, optional): Edge-level mask with shape - :obj:`[num_edges]`. (default: :obj:`None`) - node_feat_mask (Tensor, optional): Node-level feature mask with shape - :obj:`[num_nodes, num_node_features]`. (default: :obj:`None`) - edge_feat_mask (Tensor, optional): Edge-level feature mask with shape - :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) - **kwargs (optional): Additional attributes. - """ - def __init__( - self, - node_mask: Optional[Tensor] = None, - edge_mask: Optional[Tensor] = None, - node_feat_mask: Optional[Tensor] = None, - edge_feat_mask: Optional[Tensor] = None, - **kwargs, - ): - super().__init__( - node_mask=node_mask, - edge_mask=edge_mask, - node_feat_mask=node_feat_mask, - edge_feat_mask=edge_feat_mask, - **kwargs, - ) - - @property - def available_explanations(self) -> List[str]: - """Returns the available explanation masks.""" - return [ - key for key in self.keys - if key.endswith('_mask') and self[key] is not None - ]