279 embeddings #321

Closed · wants to merge 34 commits

Commits (34)
9d953dd
first attempt at rise with embeddings
cwmeijer Apr 12, 2022
d89ed71
update RISE embeddings experimentation notebook
egpbos Apr 14, 2022
89c6398
finish RISE embeddings WIP notebook 1
egpbos Apr 20, 2022
636d567
move RISE embedding test notebook
egpbos Apr 20, 2022
acc00cb
ignore temporary jpgs
egpbos Apr 20, 2022
dad11db
add new WIP notebook for RISE-style embedding explainer
egpbos Apr 20, 2022
9c2e921
add set_all_the_seeds function
egpbos Apr 21, 2022
ce64457
fixup: remove divide by p1 in explain5
egpbos Apr 21, 2022
0005a1e
WIP: continue with second RISE-embeddings notebook
egpbos Apr 21, 2022
acefe8b
Add two labradoodle images for testing
egpbos May 10, 2022
8d754c5
WIP embeddings: new notebook to get consistent results
egpbos May 10, 2022
4b7f62e
Add files via upload
cwmeijer May 12, 2022
6d20878
Second version of 2-class try out, now with some more hope
egpbos May 12, 2022
70c4c32
add doggiedog image
egpbos Jun 14, 2022
ccfac9d
add another dog
egpbos Jun 14, 2022
6497b99
add new embeddings notebook with cool results
egpbos Jun 14, 2022
7147899
add percentage filter to distances in power20 notebook
egpbos Jun 14, 2022
92e05df
refactor generate masks for images to static method
cwmeijer Jun 29, 2022
3ca342f
add distance explanation function, refs #279
cwmeijer Jun 29, 2022
e65e6c0
add notebook to work with new dianna impl, refs #279
cwmeijer Jun 29, 2022
5c3e749
fix typo, refs #279
cwmeijer Jun 30, 2022
0b225b9
fix typo in notebook, refs #279
cwmeijer Jun 30, 2022
f539145
add refactored notebook using dianna's implementation, refs #279
cwmeijer Jun 30, 2022
43537e7
add first distance test
cwmeijer Jul 12, 2022
c2b661e
add tests for distance explainer functionality
cwmeijer Jul 13, 2022
f89ef6b
add exact result test for distance
cwmeijer Jul 14, 2022
95c6f65
refactor distance
cwmeijer Jul 14, 2022
d52f462
notebook wip
cwmeijer Sep 8, 2022
8bd6324
add neutral value to return values of explain distance
cwmeijer Sep 14, 2022
d8709e8
add range parameters for distance in dianna
cwmeijer Sep 20, 2022
785a058
add support for 2 ranges wrt distance weights
cwmeijer Oct 5, 2022
f5a1ae4
add support for PIL images
cwmeijer Oct 5, 2022
531d848
no more mask weighting for distance
cwmeijer Oct 25, 2022
5df55ac
log statistics to explainer object instead of print
cwmeijer Nov 2, 2022
3 changes: 3 additions & 0 deletions .gitignore
@@ -34,3 +34,6 @@ venv
venv3

.python-version

# testing jpgs
embedding_WIP/*.jpg*
7 changes: 7 additions & 0 deletions dianna/__init__.py
@@ -24,6 +24,7 @@
import importlib
import logging
from . import utils
from .methods.distance import DistanceExplainer


logging.getLogger(__name__).addHandler(logging.NullHandler())
@@ -75,6 +76,12 @@ def explain_text(model_or_function, input_data, method, labels=(1,), **kwargs):
    explain_text_kwargs = utils.get_kwargs_applicable_to_function(explainer.explain_text, kwargs)
    return explainer.explain_text(model_or_function, input_data, labels, **explain_text_kwargs)

def explain_image_distance(model_or_function, input_data, embedded_reference, **kwargs):
    """Explain an image by the distance of its embedding to a reference embedding."""
    method_kwargs = utils.get_kwargs_applicable_to_function(DistanceExplainer.__init__, kwargs)
    explainer = DistanceExplainer(**method_kwargs)
    explain_distance_kwargs = utils.get_kwargs_applicable_to_function(explainer.explain_image_distance, kwargs)
    return explainer.explain_image_distance(model_or_function, input_data, embedded_reference, **explain_distance_kwargs)


def _get_explainer(method, kwargs):
    method_submodule = importlib.import_module(f'dianna.methods.{method.lower()}')
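To illustrate the new top-level entry point added to dianna/__init__.py above, here is a minimal usage sketch. The embedding model, file names and parameter values are assumptions for illustration only and are not part of this diff; any model or callable that maps a batch of RGB images to one embedding vector per image should fit.

# usage sketch (assumed model and files, not part of this PR)
import numpy as np
from PIL import Image

import dianna


def embedding_model(batch):
    # stand-in for a real embedding model returning one vector per image;
    # here it simply flattens each image
    return batch.reshape(len(batch), -1)


reference = np.asarray(Image.open('labradoodle1.jpg').resize((224, 224)), dtype=np.float32)
query = np.asarray(Image.open('labradoodle2.jpg').resize((224, 224)), dtype=np.float32)

embedded_reference = embedding_model(reference[None, ...])

saliency, neutral_value = dianna.explain_image_distance(
    embedding_model, query, embedded_reference,
    axis_labels={2: 'channels'}, n_masks=500, p_keep=0.5)
# saliency is a per-pixel map over the input image; per the explain_image_distance
# docstring, values around neutral_value mark regions that neither pull the input
# towards nor push it away from the reference embedding.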
156 changes: 156 additions & 0 deletions dianna/methods/distance.py
@@ -0,0 +1,156 @@
import numpy as np
from sklearn.metrics import pairwise_distances
from tqdm import tqdm
from dianna import utils
from dianna.methods.rise import generate_masks_for_images


class DistanceExplainer:
    """Explains an image by how masking it changes its embedding's distance to a reference embedding."""

    # axis labels required to be present in input image data
    required_labels = ('channels',)

    def __init__(self, n_masks=1000, feature_res=8, p_keep=.5,  # pylint: disable=too-many-arguments
                 mask_selection_range_max=0.2, mask_selection_range_min=0, mask_selection_negative_range_max=1,
                 mask_selection_negative_range_min=0.8, axis_labels=None, batch_size=10,
                 preprocess_function=None):
        """Creates a distance explainer.

        Args:
            n_masks (int): Number of random masks to generate.
            feature_res (int): Resolution of the mask grid before upsampling to the image size.
            p_keep (float): Fraction of the image kept in each mask; 0.5 is used when None.
            mask_selection_range_min/max (float): Fraction range of masks, sorted by distance to the
                reference, counted as bringing the input closer to it.
            mask_selection_negative_range_min/max (float): Fraction range of masks counted as moving
                the input away from the reference.
            axis_labels (dict or list, optional): Axis labels of the input data (must include 'channels').
            batch_size (int): Number of masked inputs per model call.
            preprocess_function (callable, optional): User preprocessing applied before inference.
        """
        self.n_masks = n_masks
        self.feature_res = feature_res
        self.p_keep = p_keep
        self.preprocess_function = preprocess_function
        self.masks = None
        self.predictions = None
        self.axis_labels = axis_labels if axis_labels is not None else []
        self.mask_selection_range_max = mask_selection_range_max
        self.mask_selection_range_min = mask_selection_range_min
        self.mask_selection_negative_range_max = mask_selection_negative_range_max
        self.mask_selection_negative_range_min = mask_selection_negative_range_min
        self.batch_size = batch_size

    def explain_image_distance(self, model_or_function, input_data, embedded_reference, **explain_distance_kwargs):
        """Explain an image by the distance of its embedding to a reference embedding.

        :param model_or_function: The model or function that maps a batch of images to embeddings
        :param input_data: Image to be explained
        :param embedded_reference: Embedding of the reference against which distances are computed
        :param explain_distance_kwargs: Further keyword arguments (currently unused)
        :return: saliency map and the neutral value within the saliency map which indicates the parts of the image that
            neither bring the image closer nor further away from the embedded reference.
        """
        full_preprocess_function, input_data = self._prepare_input_data(input_data)
        runner = utils.get_function(model_or_function, preprocess_function=full_preprocess_function)
        active_p_keep = 0.5 if self.p_keep is None else self.p_keep  # Could autotune here (See #319)

        # data shape without batch axis and channel axis
        img_shape = input_data.shape[1:3]
        # Expose masks to make user inspection possible
        self.masks = generate_masks_for_images(img_shape, active_p_keep, self.n_masks, self.feature_res)
        # Make sure multiplication is being done for correct axes
        masked = input_data * self.masks

        batch_predictions = []

        for i in tqdm(range(0, self.n_masks, self.batch_size), desc='Explaining'):
            new_predictions = runner(masked[i:i + self.batch_size])
            batch_predictions.append(new_predictions)

        self.predictions = np.concatenate(batch_predictions)

        lowest_distances_masks, lowest_mask_weights = self._get_lowest_distance_masks_and_weights(
            embedded_reference, self.predictions, self.masks,
            self.mask_selection_range_min, self.mask_selection_range_max)
        highest_distances_masks, highest_mask_weights = self._get_lowest_distance_masks_and_weights(
            embedded_reference, self.predictions, self.masks,
            self.mask_selection_negative_range_min, self.mask_selection_negative_range_max)

        def describe(x, name):
            return f'Description of {name}\nmean:{np.mean(x)}\nstd:{np.std(x)}\nmin:{np.min(x)}\nmax:{np.max(x)}'

        self.statistics = (describe(highest_mask_weights, 'highest_mask_weights') + '\n'
                           + describe(lowest_mask_weights, 'lowest_mask_weights'))

        unnormalized_sal_lowest = np.mean(lowest_distances_masks, axis=0)
        unnormalized_sal_highest = np.mean(highest_distances_masks, axis=0)
        unnormalized_sal = unnormalized_sal_lowest - unnormalized_sal_highest

        saliency = unnormalized_sal

        input_prediction = runner(input_data)
        input_distance = pairwise_distances(input_prediction, embedded_reference, metric='cosine') / 2
        neutral_value = np.exp(-input_distance)

        return saliency, neutral_value

    @staticmethod
    def _get_lowest_distance_masks_and_weights(embedded_reference, predictions, masks, mask_selection_range_min,
                                               mask_selection_range_max):
        distances = pairwise_distances(predictions, embedded_reference,
                                       metric='cosine') / 2  # divide by 2 to have [0, 1] output range
        lowest_distances_indices = np.argsort(distances, axis=0)[
            int(len(predictions) * mask_selection_range_min)
            :int(len(predictions) * mask_selection_range_max)]
        mask_weights = np.exp(-distances[lowest_distances_indices])
        lowest_distances_masks = masks[lowest_distances_indices]
        return lowest_distances_masks, mask_weights

    def _prepare_input_data(self, input_data):
        input_data_xarray = utils.to_xarray(input_data, self.axis_labels, DistanceExplainer.required_labels)
        input_data_xarray_expanded = input_data_xarray.expand_dims('batch', 0)
        # ensure channels axis is last and keep track of where it was so we can move it back
        channels_axis_index = input_data_xarray_expanded.dims.index('channels')
        prepared_input_data = utils.move_axis(input_data_xarray_expanded, 'channels', -1)
        # create preprocessing function that puts model input generated by RISE into the right shape and dtype,
        # followed by running the user's preprocessing function
        full_preprocess_function = self._get_full_preprocess_function(channels_axis_index, prepared_input_data.dtype)
        return full_preprocess_function, prepared_input_data

    def _prepare_image_data(self, input_data):
        """Transforms the data to be of the shape and type RISE expects.

        Args:
            input_data (xarray): Data to be explained

        Returns:
            transformed input data, preprocessing function to use with utils.get_function()
        """
        # ensure channels axis is last and keep track of where it was so we can move it back
        channels_axis_index = input_data.dims.index('channels')
        input_data = utils.move_axis(input_data, 'channels', -1)
        # create preprocessing function that puts model input generated by RISE into the right shape and dtype,
        # followed by running the user's preprocessing function
        full_preprocess_function = self._get_full_preprocess_function(channels_axis_index, input_data.dtype)
        return input_data, full_preprocess_function

    def _get_full_preprocess_function(self, channel_axis_index, dtype):
        """Creates a full preprocessing function.

        Creates a preprocessing function that incorporates both the (optional) user's
        preprocessing function, as well as any needed dtype and shape conversions.

        Args:
            channel_axis_index (int): Axis index of the channels in the input data
            dtype (type): Data type of the input data (e.g. np.float32)

        Returns:
            Function that first ensures the data has the same shape and type as the input data,
            then runs the user's preprocessing function
        """

        def moveaxis_function(data):
            return utils.move_axis(data, 'channels', channel_axis_index).astype(dtype).values

        if self.preprocess_function is None:
            return moveaxis_function
        return lambda data: self.preprocess_function(moveaxis_function(data))
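To make the selection step in _get_lowest_distance_masks_and_weights above easier to follow, here is a small numpy sketch of the same logic with toy shapes and the default positive range; all names and sizes below are illustrative and not part of this diff.

# illustrative sketch of the mask selection logic (toy data, not part of this PR)
import numpy as np
from sklearn.metrics import pairwise_distances

rng = np.random.default_rng(0)
predictions = rng.normal(size=(200, 128))        # one embedding per masked input
embedded_reference = rng.normal(size=(1, 128))   # embedding of the reference
masks = rng.random(size=(200, 64, 64, 1))

# cosine distance lies in [0, 2]; dividing by 2 maps it to [0, 1]
distances = pairwise_distances(predictions, embedded_reference, metric='cosine') / 2

# with mask_selection_range_min=0 and mask_selection_range_max=0.2 this keeps the
# 20% of masks whose masked input ended up closest to the reference
closest = np.argsort(distances, axis=0)[0:int(0.2 * len(predictions))]
selected_masks = masks[closest]
weights = np.exp(-distances[closest])  # only used for the logged statistics in this version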
59 changes: 30 additions & 29 deletions dianna/methods/rise.py
@@ -13,6 +13,33 @@ def _upscale(grid_i, up_size):
    return resize(grid_i, up_size, order=1, mode='reflect', anti_aliasing=False)


def generate_masks_for_images(input_size, p_keep, n_masks, feature_res):
    """Generates a set of random masks to mask the input data.

    Args:
        input_size (tuple): Size (height, width) of a single image, without the channel axis.
        p_keep (float): Fraction of each mask that keeps the input visible.
        n_masks (int): Number of masks to generate.
        feature_res (int): Resolution of the mask grid before upsampling.

    Returns:
        The generated masks (np.ndarray)
    """
    cell_size = np.ceil(np.array(input_size) / feature_res)
    up_size = (feature_res + 1) * cell_size

    grid = np.random.choice(a=(True, False), size=(n_masks, feature_res, feature_res),
                            p=(p_keep, 1 - p_keep))
    grid = grid.astype('float32')

    masks = np.empty((n_masks, *input_size), dtype=np.float32)

    for i in range(n_masks):
        y = np.random.randint(0, cell_size[0])
        x = np.random.randint(0, cell_size[1])
        # Linear upsampling and cropping
        masks[i, :, :] = _upscale(grid[i], up_size)[y:y + input_size[0], x:x + input_size[1]]
    masks = masks.reshape(-1, *input_size, 1)
    return masks


class RISE:
    """RISE implementation based on https://github.com/eclique/RISE/blob/master/Easy_start.ipynb."""
    # axis labels required to be present in input image data
@@ -149,7 +176,7 @@ def explain_image(self, model_or_function, input_data, labels=None, batch_size=1
        # data shape without batch axis and channel axis
        img_shape = input_data.shape[1:3]
        # Expose masks to make user inspection possible
        self.masks = self.generate_masks_for_images(img_shape, active_p_keep, self.n_masks)
        self.masks = generate_masks_for_images(img_shape, active_p_keep, self.n_masks, self.feature_res)

        # Make sure multiplication is being done for correct axes
        masked = input_data * self.masks
@@ -180,7 +207,7 @@ def _determine_p_keep_for_images(self, input_data, runner, n_masks=100):
    def _calculate_mean_class_std_for_images(self, p_keep, runner, input_data, n_masks):
        batch_size = 50
        img_shape = input_data.shape[1:3]
        masks = self.generate_masks_for_images(img_shape, p_keep, n_masks)
        masks = generate_masks_for_images(img_shape, p_keep, n_masks, self.feature_res)
        masked = input_data * masks
        predictions = []
        for i in range(0, n_masks, batch_size):
@@ -189,33 +216,7 @@ def _calculate_mean_class_std_for_images(self, p_keep, runner, input_data, n_mas
            predictions.append(current_predictions.max(axis=1))
        predictions = np.concatenate(predictions)
        std_per_class = predictions.std()
        return np.mean(std_per_class)

    def generate_masks_for_images(self, input_size, p_keep, n_masks):
        """Generates a set of random masks to mask the input data.

        Args:
            input_size (int): Size of a single sample of input data, for images without the channel axis.

        Returns:
            The generated masks (np.ndarray)
        """
        cell_size = np.ceil(np.array(input_size) / self.feature_res)
        up_size = (self.feature_res + 1) * cell_size

        grid = np.random.choice(a=(True, False), size=(n_masks, self.feature_res, self.feature_res),
                                p=(p_keep, 1 - p_keep))
        grid = grid.astype('float32')

        masks = np.empty((n_masks, *input_size), dtype=np.float32)

        for i in range(n_masks):
            y = np.random.randint(0, cell_size[0])
            x = np.random.randint(0, cell_size[1])
            # Linear upsampling and cropping
            masks[i, :, :] = _upscale(grid[i], up_size)[y:y + input_size[0], x:x + input_size[1]]
        masks = masks.reshape(-1, *input_size, 1)
        return masks
        return np.mean(std_per_class)

    def _prepare_image_data(self, input_data):
        """Transforms the data to be of the shape and type RISE expects.
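Since generate_masks_for_images is now a module-level function, it can also be called directly; a short sketch follows, where the sizes are arbitrary example values and not taken from this diff.

# illustrative call of the now module-level mask generator
from dianna.methods.rise import generate_masks_for_images

masks = generate_masks_for_images((224, 224), p_keep=0.5, n_masks=8, feature_res=6)
print(masks.shape)   # (8, 224, 224, 1)
print(masks.dtype)   # float32; values lie in [0, 1] because of the linear upsampling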
18 changes: 16 additions & 2 deletions dianna/utils/misc.py
@@ -1,5 +1,8 @@
import inspect

import PIL
import numpy as np


def get_function(model_or_function, preprocess_function=None):
"""Converts input to callable function.
Expand Down Expand Up @@ -36,9 +39,20 @@ def get_kwargs_applicable_to_function(function, kwargs):
if key in inspect.getfullargspec(function).args}


def _get_num_dims(data):
    if hasattr(data, 'ndim'):
        return data.ndim
    if hasattr(data, 'shape'):
        return len(data.shape)
    raise TypeError('Unsupported data type. Supported types are numpy arrays or PIL images and similar.')


def to_xarray(data, axis_labels, required_labels=None):
    """Converts numpy data and axes labels to an xarray object."""
    if isinstance(axis_labels, dict):
    if isinstance(data, PIL.Image.Image):
        data = np.array(data)
        labels = ['dim_0', 'dim_1', 'channels']
    elif isinstance(axis_labels, dict):
        # key = axis index, value = label
        # not all axes have to be present in the input, but we need to provide
        # a name for each axis
@@ -47,7 +61,7 @@ def to_xarray(data, axis_labels, required_labels=None):
        for index in indices:
            if index < 0:
                axis_labels[data.ndim + index] = axis_labels.pop(index)
        labels = [axis_labels[index] if index in axis_labels else f'dim_{index}' for index in range(data.ndim)]
        labels = [axis_labels[index] if index in axis_labels else f'dim_{index}' for index in range(_get_num_dims(data))]
    else:
        labels = list(axis_labels)

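The change above adds a PIL branch to to_xarray; a sketch of how it behaves follows, assuming an RGB image file (the file name is illustrative, not from this diff).

# illustrative sketch of the new PIL support in to_xarray
import PIL.Image
from dianna import utils

image = PIL.Image.open('some_dog.jpg')  # assumed RGB image
data = utils.to_xarray(image, axis_labels=[], required_labels=('channels',))
print(data.dims)  # expected: ('dim_0', 'dim_1', 'channels')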
Binary file added embedding_WIP/IMG_4531-e1549365547619.png
Binary file added embedding_WIP/Tike-Mini-Labradoodle.png
Binary file added embedding_WIP/cardog.png