-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Some usability updates to
Explanation
(#6054)
### Motivation When evaluating and visualizing explanations it will be very convenient to obtain the explanation and explanation complement subgraphs directly from the `Explanation` objects. Given that the class holds information of the original graph being explained this shouldn’t be difficult to achieve. I propose we add two methods to the `Explanation` class: 1. `get_explanation_subgraph` returns the explanation subgraph $G_S$, 2. `get_complement_subgraph` returns the complement of the explanation subgraph $G_{C \backslash S}$ I’ve submitted this draft PR to get opinions on this, both implementation-wise and if we should do this in the first place. ### Alternatives 1. Delegate this to other parts of code, I.e. whenewer this is needed (e.g. when evaluation explanations or visualizing them) 2. Implement the logic in `util.py` for general Data object, e.g. a `mask_graph` method, which can then be reused here or elsewhere. Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
- Loading branch information
1 parent
656fed9
commit 0fdf935
Showing
7 changed files
with
279 additions
and
113 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import pytest | ||
import torch | ||
|
||
from torch_geometric.data import Data | ||
from torch_geometric.explain import Explanation | ||
|
||
|
||
@pytest.fixture | ||
def data(): | ||
return Data( | ||
x=torch.randn(4, 3), | ||
edge_index=torch.tensor([ | ||
[0, 0, 0, 1, 1, 2], | ||
[1, 2, 3, 2, 3, 3], | ||
]), | ||
edge_attr=torch.randn(6, 3), | ||
) | ||
|
||
|
||
def create_random_explanation( | ||
data: Data, | ||
node_mask: bool = True, | ||
edge_mask: bool = True, | ||
node_feat_mask: bool = True, | ||
edge_feat_mask: bool = True, | ||
): | ||
node_mask = torch.rand(data.x.size(0)) if node_mask else None | ||
edge_mask = torch.rand(data.edge_index.size(1)) if edge_mask else None | ||
node_feat_mask = torch.rand_like(data.x) if node_feat_mask else None | ||
edge_feat_mask = (torch.rand_like(data.edge_attr) | ||
if edge_feat_mask else None) | ||
|
||
return Explanation( # Create explanation. | ||
node_mask=node_mask, | ||
edge_mask=edge_mask, | ||
node_feat_mask=node_feat_mask, | ||
edge_feat_mask=edge_feat_mask, | ||
x=data.x, | ||
edge_index=data.edge_index, | ||
edge_attr=data.edge_attr, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize('node_mask', [True, False]) | ||
@pytest.mark.parametrize('edge_mask', [True, False]) | ||
@pytest.mark.parametrize('node_feat_mask', [True, False]) | ||
@pytest.mark.parametrize('edge_feat_mask', [True, False]) | ||
def test_available_explanations(data, node_mask, edge_mask, node_feat_mask, | ||
edge_feat_mask): | ||
expected = [] | ||
if node_mask: | ||
expected.append('node_mask') | ||
if edge_mask: | ||
expected.append('edge_mask') | ||
if node_feat_mask: | ||
expected.append('node_feat_mask') | ||
if edge_feat_mask: | ||
expected.append('edge_feat_mask') | ||
|
||
explanation = create_random_explanation( | ||
data, | ||
node_mask=node_mask, | ||
edge_mask=edge_mask, | ||
node_feat_mask=node_feat_mask, | ||
edge_feat_mask=edge_feat_mask, | ||
) | ||
|
||
assert set(explanation.available_explanations) == set(expected) | ||
|
||
|
||
def test_validate_explanation(data): | ||
explanation = create_random_explanation(data) | ||
explanation.validate(raise_on_error=True) | ||
|
||
with pytest.raises(ValueError, match="with 5 nodes"): | ||
explanation = create_random_explanation(data) | ||
explanation.x = torch.randn(5, 5) | ||
explanation.validate(raise_on_error=True) | ||
|
||
with pytest.raises(ValueError, match=r"of shape \[4, 4\]"): | ||
explanation = create_random_explanation(data) | ||
explanation.x = torch.randn(4, 4) | ||
explanation.validate(raise_on_error=True) | ||
|
||
with pytest.raises(ValueError, match="with 7 edges"): | ||
explanation = create_random_explanation(data) | ||
explanation.edge_index = torch.randint(0, 4, (2, 7)) | ||
explanation.validate(raise_on_error=True) | ||
|
||
with pytest.raises(ValueError, match=r"of shape \[6, 4\]"): | ||
explanation = create_random_explanation(data) | ||
explanation.edge_attr = torch.randn(6, 4) | ||
explanation.validate(raise_on_error=True) | ||
|
||
|
||
def test_node_mask(data): | ||
node_mask = torch.tensor([1.0, 0.0, 1.0, 0.0]) | ||
|
||
explanation = Explanation( | ||
node_mask=node_mask, | ||
x=data.x, | ||
edge_index=data.edge_index, | ||
edge_attr=data.edge_attr, | ||
) | ||
explanation.validate(raise_on_error=True) | ||
|
||
out = explanation.get_explanation_subgraph() | ||
assert out.node_mask.size() == (2, ) | ||
assert (out.node_mask > 0.0).sum() == 2 | ||
assert out.x.size() == (2, 3) | ||
assert out.edge_index.size() == (2, 1) | ||
assert out.edge_attr.size() == (1, 3) | ||
|
||
out = explanation.get_complement_subgraph() | ||
assert out.node_mask.size() == (2, ) | ||
assert (out.node_mask == 0.0).sum() == 2 | ||
assert out.x.size() == (2, 3) | ||
assert out.edge_index.size() == (2, 1) | ||
assert out.edge_attr.size() == (1, 3) | ||
|
||
|
||
def test_edge_mask(data): | ||
edge_mask = torch.tensor([1.0, 0.0, 1.0, 0.0, 0.0, 1.0]) | ||
|
||
explanation = Explanation( | ||
edge_mask=edge_mask, | ||
x=data.x, | ||
edge_index=data.edge_index, | ||
edge_attr=data.edge_attr, | ||
) | ||
explanation.validate(raise_on_error=True) | ||
|
||
out = explanation.get_explanation_subgraph() | ||
assert out.x.size() == (4, 3) | ||
assert out.edge_mask.size() == (3, ) | ||
assert (out.edge_mask > 0.0).sum() == 3 | ||
assert out.edge_index.size() == (2, 3) | ||
assert out.edge_attr.size() == (3, 3) | ||
|
||
out = explanation.get_complement_subgraph() | ||
assert out.x.size() == (4, 3) | ||
assert out.edge_mask.size() == (3, ) | ||
assert (out.edge_mask == 0.0).sum() == 3 | ||
assert out.edge_index.size() == (2, 3) | ||
assert out.edge_attr.size() == (3, 3) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
import copy | ||
from typing import List, Optional | ||
|
||
from torch import Tensor | ||
|
||
from torch_geometric.data.data import Data, warn_or_raise | ||
|
||
|
||
class Explanation(Data): | ||
r"""Holds all the obtained explanations of a homogenous graph. | ||
The explanation object is a :obj:`~torch_geometric.data.Data` object and | ||
can hold node-attribution, edge-attribution, feature-attribution. It can | ||
also hold the original graph if needed. | ||
Args: | ||
node_mask (Tensor, optional): Node-level mask with shape | ||
:obj:`[num_nodes]`. (default: :obj:`None`) | ||
edge_mask (Tensor, optional): Edge-level mask with shape | ||
:obj:`[num_edges]`. (default: :obj:`None`) | ||
node_feat_mask (Tensor, optional): Node-level feature mask with shape | ||
:obj:`[num_nodes, num_node_features]`. (default: :obj:`None`) | ||
edge_feat_mask (Tensor, optional): Edge-level feature mask with shape | ||
:obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) | ||
**kwargs (optional): Additional attributes. | ||
""" | ||
def __init__( | ||
self, | ||
node_mask: Optional[Tensor] = None, | ||
edge_mask: Optional[Tensor] = None, | ||
node_feat_mask: Optional[Tensor] = None, | ||
edge_feat_mask: Optional[Tensor] = None, | ||
**kwargs, | ||
): | ||
super().__init__( | ||
node_mask=node_mask, | ||
edge_mask=edge_mask, | ||
node_feat_mask=node_feat_mask, | ||
edge_feat_mask=edge_feat_mask, | ||
**kwargs, | ||
) | ||
|
||
@property | ||
def available_explanations(self) -> List[str]: | ||
"""Returns the available explanation masks.""" | ||
return [ | ||
key for key in self.keys | ||
if key.endswith('_mask') and self[key] is not None | ||
] | ||
|
||
def validate(self, raise_on_error: bool = True) -> bool: | ||
r"""Validates the correctness of the explanation""" | ||
status = super().validate() | ||
|
||
if 'node_mask' in self and self.num_nodes != self.node_mask.size(0): | ||
status = False | ||
warn_or_raise( | ||
f"Expected a 'node_mask' with {self.num_nodes} nodes " | ||
f"(got {self.node_mask.size(0)} nodes)", raise_on_error) | ||
|
||
if 'edge_mask' in self and self.num_edges != self.edge_mask.size(0): | ||
status = False | ||
warn_or_raise( | ||
f"Expected an 'edge_mask' with {self.num_edges} edges " | ||
f"(got {self.edge_mask.size(0)} edges)", raise_on_error) | ||
|
||
if 'node_feat_mask' in self: | ||
if 'x' in self and self.x.size() != self.node_feat_mask.size(): | ||
status = False | ||
warn_or_raise( | ||
f"Expected a 'node_feat_mask' of shape " | ||
f"{list(self.x.size())} (got shape " | ||
f"{list(self.node_feat_mask.size())})", raise_on_error) | ||
elif self.num_nodes != self.node_feat_mask.size(0): | ||
status = False | ||
warn_or_raise( | ||
f"Expected a 'node_feat_mask' with {self.num_nodes} nodes " | ||
f"(got {self.node_feat_mask.size(0)} nodes)", | ||
raise_on_error) | ||
|
||
if 'edge_feat_mask' in self: | ||
if ('edge_attr' in self | ||
and self.edge_attr.size() != self.edge_feat_mask.size()): | ||
status = False | ||
warn_or_raise( | ||
f"Expected an 'edge_feat_mask' of shape " | ||
f"{list(self.edge_attr.size())} (got shape " | ||
f"{list(self.edge_feat_mask.size())})", raise_on_error) | ||
elif self.num_edges != self.edge_feat_mask.size(0): | ||
status = False | ||
warn_or_raise( | ||
f"Expected an 'edge_feat_mask' with {self.num_edges} " | ||
f"edges (got {self.edge_feat_mask.size(0)} edges)", | ||
raise_on_error) | ||
|
||
return status | ||
|
||
def get_explanation_subgraph(self) -> 'Explanation': | ||
r"""Returns the induced subgraph, in which all nodes and edges with | ||
zero attribution are masked out.""" | ||
return self._apply_masks( | ||
node_mask=self.node_mask > 0 if 'node_mask' in self else None, | ||
edge_mask=self.edge_mask > 0 if 'edge_mask' in self else None, | ||
) | ||
|
||
def get_complement_subgraph(self) -> 'Explanation': | ||
r"""Returns the induced subgraph, in which all nodes and edges with any | ||
attribution are masked out.""" | ||
return self._apply_masks( | ||
node_mask=self.node_mask == 0 if 'node_mask' in self else None, | ||
edge_mask=self.edge_mask == 0 if 'edge_mask' in self else None, | ||
) | ||
|
||
def _apply_masks( | ||
self, | ||
node_mask: Optional[Tensor] = None, | ||
edge_mask: Optional[Tensor] = None, | ||
) -> 'Explanation': | ||
out = copy.copy(self) | ||
|
||
if edge_mask is not None: | ||
for key, value in self.items(): | ||
if key == 'edge_index': | ||
out.edge_index = value[:, edge_mask] | ||
elif self.is_edge_attr(key): | ||
out[key] = value[edge_mask] | ||
|
||
if node_mask is not None: | ||
out = out.subgraph(node_mask) | ||
|
||
return out |
Oops, something went wrong.