Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

XAI Inverse Estimation #337

Open
wants to merge 53 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
b593012
updated two metrics and base implementation for xai inverse estimation
annahedstroem Oct 9, 2023
492ea8d
finished aggregation of the inverse implementation
annahedstroem Oct 9, 2023
15efaf6
wip inverse estimation
annahedstroem Nov 15, 2023
83c6554
fixes to inverse estimation method
annahedstroem Nov 24, 2023
ba7ab51
fixes, tests and implementation
annahedstroem Dec 1, 2023
79cf6d8
automerge issues
annahedstroem Dec 7, 2023
fb3d5ef
tests passing, most todos removed
annahedstroem Dec 7, 2023
1125384
minor fixes to InverseEstimation class
annahedstroem Dec 7, 2023
6cc1cda
Update inverse_estimation.py
annahedstroem Feb 23, 2024
0a640fb
Merge branch 'main' into xai-inverse-estimation
annahedstroem Feb 23, 2024
279fb38
remove inverse_estimation edits in region_perturbation.py
annahedstroem Feb 23, 2024
a14c109
name update region_perturbation.py
annahedstroem Feb 23, 2024
c4032d9
remove metric init requirements in inverse_estimation.py
annahedstroem Feb 23, 2024
dd74095
added a ormalise_func_kwargs attribute of base class, was missing
annahedstroem Feb 23, 2024
4b763b8
added second inverse method
annahedstroem Feb 23, 2024
86d7615
added second inverse method -v2
annahedstroem Feb 23, 2024
d9cb2f3
enable batching
annahedstroem Feb 23, 2024
07c9925
Merge branch 'main' into xai-inverse-estimation
annahedstroem Mar 1, 2024
6aeba15
update method name
annahedstroem Mar 15, 2024
02e3d5a
added inverse_method as an arg
annahedstroem Mar 15, 2024
4da473a
batch update
annahedstroem Mar 15, 2024
3c123c7
updated inverse with batch
annahedstroem Mar 15, 2024
3ccf004
Merge branch 'xai-inverse-estimation' of https://github.com/understan…
annahedstroem Mar 15, 2024
05efd24
added wrapper
annahedstroem Mar 15, 2024
3235507
added wrapper
annahedstroem Mar 15, 2024
f415b53
Merge branch 'main' into xai-inverse-estimation
annahedstroem Mar 15, 2024
4a7764f
Merge branch 'main' into xai-inverse-estimation
annahedstroem Mar 24, 2024
9b1de09
Merge branch 'main' into xai-inverse-estimation
annahedstroem Mar 24, 2024
3ec97cc
Merge branch 'main' into xai-inverse-estimation
annahedstroem Mar 24, 2024
dfcdae4
update from True to False in assert in inverse_estimation.py
annahedstroem Mar 24, 2024
3d06f90
bugfix a_batch shape inverse_estimation.py
annahedstroem Mar 24, 2024
76d9fdc
bugfix formatting on titanic dataset
annahedstroem Mar 25, 2024
bf7ab44
merge fixes
annahedstroem Mar 25, 2024
d6a9ed9
fixed tests for inverse estimation, debugged shapes
annahedstroem Mar 25, 2024
b4b0156
Merge branch 'main' into xai-inverse-estimation
annahedstroem Mar 25, 2024
138dc46
merge post-fixes, remove old files
annahedstroem Mar 25, 2024
b2c7f27
add print statement
annahedstroem Mar 25, 2024
b7215ac
added base call kwargs as class attributes to base and rewrite invers…
annahedstroem Mar 25, 2024
77eb1a8
replace assert with warning
annahedstroem Mar 25, 2024
66d75af
bugfix evaluate_batch
annahedstroem Mar 25, 2024
6f20997
small fixes wrt assert/ warns
annahedstroem Mar 25, 2024
0cddcd6
added plotting
annahedstroem Mar 26, 2024
e248ad4
fixes
annahedstroem Mar 26, 2024
e7e4493
eval func check and channel first fix
annahedstroem Mar 26, 2024
7214701
added feature for mean/ AUC calc for localisation metric for inverse …
annahedstroem Mar 27, 2024
a516928
add tests
annahedstroem Mar 27, 2024
a55ef15
plotting updates, tiny
annahedstroem Mar 27, 2024
99b2fa1
include area_between_curves flag inverse_estimation.py
annahedstroem Apr 4, 2024
3d27467
Update consistency.py
annahedstroem Apr 14, 2024
7ae5ecf
Update region_perturbation.py
annahedstroem Apr 16, 2024
4dd68ce
Update region_perturbation.py
annahedstroem Apr 16, 2024
c823d4a
Update return typing region_perturbation.py
annahedstroem Apr 16, 2024
24ec8fa
Update region_perturbation.py
annahedstroem Apr 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ It is possible to limit the scope of testing to specific sections of the codebas
Faithfulness metrics using python3.9 (make sure the python versions match in your environment):

```bash
python3 -m tox run -e py39 -- -m evaluate_func -s
python3 -m tox run -e py39 -- -m faithfulness -s
```

For a complete overview of the possible testing scopes, please refer to `pytest.ini`.
Expand Down
2 changes: 1 addition & 1 deletion quantus/functions/explanation_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def f_reduce_axes(a):
inputs = inputs.cpu()

inputs_numpy = inputs.detach().numpy()

for i in range(len(explanation)):
explanation[i] = torch.Tensor(
np.clip(scipy.ndimage.sobel(inputs_numpy[i]), 0, 1)
Expand Down
11 changes: 10 additions & 1 deletion quantus/helpers/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"Faithfulness Correlation": FaithfulnessCorrelation,
"Faithfulness Estimate": FaithfulnessEstimate,
"Pixel-Flipping": PixelFlipping,
"Region Segmentation": RegionPerturbation,
"Region Perturbation": RegionPerturbation,
"Monotonicity-Arya": Monotonicity,
"Monotonicity-Nguyen": MonotonicityCorrelation,
"Selectivity": Selectivity,
Expand Down Expand Up @@ -74,6 +74,15 @@
},
}

# Quantus metrics that include a step-wise 'masking'/ perturbation that is
# based on attribution order/ ranking (and not magnitude).
AVAILABLE_INVERSE_ESTIMATION_METRICS = {
"Pixel-Flipping": PixelFlipping,
"Region Perturbation": RegionPerturbation, # order = 'morf'
"ROAD": ROAD, # return_only_values = True
"Selectivity": Selectivity,
}
#

AVAILABLE_PERTURBATION_FUNCTIONS = {
"baseline_replacement_by_indices": baseline_replacement_by_indices,
Expand Down
32 changes: 32 additions & 0 deletions quantus/helpers/model/pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,38 @@ def sample(
)
return model_copy

def perturb_layer_weights(self, layer_idx: int, noise: float):
"""
Perturb the weights of a specific layer in a PyTorch model.

Parameters
----------
model : torch.nn.Module
The PyTorch model.
layer_idx : int
The index of the layer to perturb.
noise : float
The standard deviation of the Gaussian noise to add to the weights.

Returns
-------
None
"""
original_parameters = self.state_dict()
model_copy = deepcopy(self.model)
model_copy.load_state_dict(original_parameters)

# Get the specific layer.
layer = list(model_copy.modules())[layer_idx]

# Generate Gaussian noise.
noise_tensor = torch.randn_like(layer.weight) * noise

# Add the noise to the layer's weights.
layer.weight.data.add_(noise_tensor)

return model_copy

def add_mean_shift_to_first_layer(
self,
input_shift: Union[int, float],
Expand Down
36 changes: 32 additions & 4 deletions quantus/helpers/perturbation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,23 @@ def __call__(
def make_perturb_func(
perturb_func: PerturbFunc, perturb_func_kwargs: Mapping[str, ...] | None, **kwargs
) -> PerturbFunc | functools.partial:
"""A utility function to save few lines of code during perturbation metric initialization."""
"""
A utility function to save few lines of code during perturbation metric initialization.

Parameters
----------
perturb_func: callable
Perturbation function.
perturb_func_kwargs: dict
Perturbation function kwargs.
kwargs: dict
Perturbation metric kwargs.

Returns
-------
perturb_func: callable
Perturbation function.
"""
if perturb_func_kwargs is not None:
func_kwargs = kwargs.copy()
func_kwargs.update(perturb_func_kwargs)
Expand All @@ -41,7 +57,19 @@ def make_perturb_func(
def make_changed_prediction_indices_func(
return_nan_when_prediction_changes: bool,
) -> Callable[[ModelInterface, np.ndarray, np.ndarray], List[int]]:
"""A utility function to improve static analysis."""
"""
A utility function to improve static analysis.

Parameters
----------
return_nan_when_prediction_changes: boolean
Indicates if metric should return NaN when model prediction changes due to perturbation.

Returns
-------
changed_prediction_indices: callable
Function that returns indices in batch, for which predicted label has changed after applying perturbation.
"""
return functools.partial(
changed_prediction_indices,
return_nan_when_prediction_changes=return_nan_when_prediction_changes,
Expand All @@ -62,15 +90,15 @@ def changed_prediction_indices(
----------
return_nan_when_prediction_changes:
Instance attribute of perturbation metrics.
model:
model: ModelInterface
Model to be used for prediction.
x_batch:
Batch of original inputs provided by user.
x_perturbed:
Batch of inputs after applying perturbation.

Returns
-------

changed_idx:
List of indices in batch, for which predicted label has changed afer.

Expand Down
1 change: 1 addition & 0 deletions quantus/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
from quantus.metrics.localisation import *
from quantus.metrics.randomisation import *
from quantus.metrics.robustness import *
from quantus.metrics.inverse_estimation import InverseEstimation
28 changes: 14 additions & 14 deletions quantus/metrics/axiomatic/completeness.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,32 +219,32 @@ def __call__(
Examples:
--------
# Minimal imports.
>> import quantus
>> from quantus import LeNet
>> import torch
>>> import quantus
>>> from quantus import LeNet
>>> import torch

# Enable GPU.
>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
>>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).
>> model = LeNet()
>> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model"))
>>> model = LeNet()
>>> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model"))

# Load MNIST datasets and make loaders.
>> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True)
>> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)
>>> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True)
>>> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)

# Load a batch of inputs and outputs to use for XAI evaluation.
>> x_batch, y_batch = iter(test_loader).next()
>> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()
>>> x_batch, y_batch = iter(test_loader).next()
>>> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()

# Generate Saliency attributions of the test set batch of the test set.
>> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1)
>> a_batch_saliency = a_batch_saliency.cpu().numpy()
>>> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1)
>>> a_batch_saliency = a_batch_saliency.cpu().numpy()

# Initialise the metric and evaluate explanations by calling the metric instance.
>> metric = Metric(abs=True, normalise=False)
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
>>> metric = Metric(abs=True, normalise=False)
>>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
"""
return super().__call__(
model=model,
Expand Down
28 changes: 14 additions & 14 deletions quantus/metrics/axiomatic/input_invariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,32 +199,32 @@ def __call__(
Examples:
--------
# Minimal imports.
>> import quantus
>> from quantus import LeNet
>> import torch
>>> import quantus
>>> from quantus import LeNet
>>> import torch

# Enable GPU.
>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
>>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).
>> model = LeNet()
>> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model"))
>>> model = LeNet()
>>> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model"))

# Load MNIST datasets and make loaders.
>> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True)
>> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)
>>> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True)
>>> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)

# Load a batch of inputs and outputs to use for XAI evaluation.
>> x_batch, y_batch = iter(test_loader).next()
>> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()
>>> x_batch, y_batch = iter(test_loader).next()
>>> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()

# Generate Saliency attributions of the test set batch of the test set.
>> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1)
>> a_batch_saliency = a_batch_saliency.cpu().numpy()
>>> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1)
>>> a_batch_saliency = a_batch_saliency.cpu().numpy()

# Initialise the metric and evaluate explanations by calling the metric instance.
>> metric = Metric(abs=True, normalise=False)
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
>>> metric = Metric(abs=True, normalise=False)
>>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
"""
return super().__call__(
model=model,
Expand Down
28 changes: 14 additions & 14 deletions quantus/metrics/axiomatic/non_sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,32 +219,32 @@ def __call__(
Examples:
--------
# Minimal imports.
>> import quantus
>> from quantus import LeNet
>> import torch
>>> import quantus
>>> from quantus import LeNet
>>> import torch

# Enable GPU.
>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
>>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).
>> model = LeNet()
>> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model"))
>>> model = LeNet()
>>> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model"))

# Load MNIST datasets and make loaders.
>> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True)
>> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)
>>> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True)
>>> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)

# Load a batch of inputs and outputs to use for XAI evaluation.
>> x_batch, y_batch = iter(test_loader).next()
>> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()
>>> x_batch, y_batch = iter(test_loader).next()
>>> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()

# Generate Saliency attributions of the test set batch of the test set.
>> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1)
>> a_batch_saliency = a_batch_saliency.cpu().numpy()
>>> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1)
>>> a_batch_saliency = a_batch_saliency.cpu().numpy()

# Initialise the metric and evaluate explanations by calling the metric instance.
>> metric = Metric(abs=True, normalise=False)
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
>>> metric = Metric(abs=True, normalise=False)
>>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
"""
return super().__call__(
model=model,
Expand Down
30 changes: 15 additions & 15 deletions quantus/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(

if normalise_func_kwargs is not None:
normalise_func = functools.partial(normalise_func, **normalise_func_kwargs)

# Run deprecation warnings.
warn.deprecation_warnings(kwargs)
warn.check_kwargs(kwargs)
Expand Down Expand Up @@ -233,32 +233,32 @@ def __call__(
Examples:
--------
# Minimal imports.
>> import quantus
>> from quantus import LeNet
>> import torch
>>> import quantus
>>> from quantus import LeNet
>>> import torch

# Enable GPU.
>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
>>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load a pre-trained LeNet classification model (architecture at quantus/helpers/models).
>> model = LeNet()
>> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model"))
>>> model = LeNet()
>>> model.load_state_dict(torch.load("tutorials/assets/pytests/mnist_model"))

# Load MNIST datasets and make loaders.
>> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True)
>> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)
>>> test_set = torchvision.datasets.MNIST(root='./sample_data', download=True)
>>> test_loader = torch.utils.data.DataLoader(test_set, batch_size=24)

# Load a batch of inputs and outputs to use for XAI evaluation.
>> x_batch, y_batch = iter(test_loader).next()
>> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()
>>> x_batch, y_batch = iter(test_loader).next()
>>> x_batch, y_batch = x_batch.cpu().numpy(), y_batch.cpu().numpy()

# Generate Saliency attributions of the test set batch of the test set.
>> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1)
>> a_batch_saliency = a_batch_saliency.cpu().numpy()
>>> a_batch_saliency = Saliency(model).attribute(inputs=x_batch, target=y_batch, abs=True).sum(axis=1)
>>> a_batch_saliency = a_batch_saliency.cpu().numpy()

# Initialise the metric and evaluate explanations by calling the metric instance.
>> metric = Metric(abs=True, normalise=False)
>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
>>> metric = Metric(abs=True, normalise=False)
>>> scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_saliency)
"""
# Run deprecation warnings.
warn.deprecation_warnings(kwargs)
Expand Down
Loading
Loading