Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract visualization functions from Explainer to standalone function… #55

Merged
merged 11 commits into from
Feb 6, 2023
49 changes: 37 additions & 12 deletions autoxai/array_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,46 @@
import torch


def convert_float_to_uint8(array: np.ndarray) -> np.ndarray:
"""Convert numpy array with float values to uint8 with scaled values.
def standardize_array(array: np.ndarray) -> np.ndarray:
"""Standardize array values to range [0-1].

Args:
array: Numpy array with float values.
array: Numpy array of floats.

Returns:
Numpy array with scaled values in uint8.
Numpy array with scaled values.

Raises:
ValueError: if array is not of type np.float.
"""
if not array.dtype == np.dtype(float):
raise ValueError(
f"Array should be of type: np.float, current type: {array.dtype}"
)

return (array - np.min(array)) / (
adamwawrzynski marked this conversation as resolved.
Show resolved Hide resolved
(np.max(array) - np.min(array)) + sys.float_info.epsilon
)


def convert_standardized_float_to_uint8(array: np.ndarray) -> np.ndarray:
"""Convert float standardize float array to uint8 with values scaling.

Args:
array: Numpy array of floats.

Returns:
Numpy array with scaled values with type np.uint8.

Raises:
ValueError: if array is not of type np.float.
"""
return (
(
(array - np.min(array))
/ ((np.max(array) - np.min(array)) + sys.float_info.epsilon)
if not array.dtype == np.dtype(float):
raise ValueError(
f"Array should be of type: np.float, current type: {array.dtype}"
)
* 255
).astype(np.uint8)

return (array * 255).astype(np.uint8)


def retain_only_positive(array: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -86,8 +110,9 @@ def resize_attributes(
single_channel_attributes: np.ndarray = np.array(
cv2.resize(
attributes,
(dest_width, dest_height),
)
(dest_height, dest_width),
),
dtype=np.dtype(float),
)

return single_channel_attributes
Expand Down
14 changes: 9 additions & 5 deletions autoxai/callbacks/wandb_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import numpy as np
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

from autoxai.array_utils import convert_float_to_uint8
import wandb
from autoxai.array_utils import convert_standardized_float_to_uint8, standardize_array
from autoxai.context_manager import AutoXaiExplainer, ExplainerWithParams
from autoxai.explainer.base_explainer import CVExplainer
from autoxai.visualizer import mean_channels_visualization

AttributeMapType = Dict[str, List[np.ndarray]]
CaptionMapType = Dict[str, List[str]]
Expand Down Expand Up @@ -112,13 +112,17 @@ def explain( # pylint: disable = (too-many-arguments)
explainer_name: str = explainer.explainer_name.name
explainer_attributes: torch.Tensor = attributes[explainer_name]
caption_dict[explainer_name].append(f"label: {target_label}")
figure = CVExplainer.visualize(
figure = mean_channels_visualization(
attributions=explainer_attributes,
transformed_img=item,
)
figures_dict[explainer_name].append(figure)
standardized_attr = standardize_array(
explainer_attributes.detach().cpu().numpy()
)

attributes_dict[explainer_name].append(
convert_float_to_uint8(explainer_attributes.detach().cpu().numpy())
convert_standardized_float_to_uint8(standardized_attr),
)

return attributes_dict, caption_dict, figures_dict
Expand Down
89 changes: 1 addition & 88 deletions autoxai/explainer/base_explainer.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,10 @@
"""Abstract Explainer class."""
import logging
from abc import ABC, abstractmethod
from typing import Optional, Tuple, TypeVar
from typing import TypeVar

import matplotlib
import numpy as np
import torch
from captum._utils.typing import TargetType

from autoxai.array_utils import (
convert_float_to_uint8,
normalize_attributes,
resize_attributes,
retain_only_positive,
transpose_array,
)
from autoxai.logger import create_logger

_LOGGER: Optional[logging.Logger] = None


def log() -> logging.Logger:
"""Get or create logger."""
# pylint: disable = global-statement
global _LOGGER
if _LOGGER is None:
_LOGGER = create_logger(__name__)
return _LOGGER


class CVExplainer(ABC):
"""Abstract explainer class."""
Expand Down Expand Up @@ -61,70 +38,6 @@ def algorithm_name(self) -> str:
"""
return type(self).__name__

@classmethod
def visualize(
cls,
attributions: torch.Tensor,
transformed_img: torch.Tensor,
title: str = "",
figsize: Tuple[int, int] = (8, 8),
alpha: float = 0.5,
only_positive_attr: bool = True,
) -> matplotlib.pyplot.Figure:
"""Create image with calculated features.

Args:
attributions: Features.
transformed_img: Image in shape (C x H x W) or (H x W).
title: Title of the figure. Defaults to "".
figsize: Tuple with size of figure. Defaults to (8, 8).
alpha: Opacity level. Defaults to 0.5,
only_positive_attr: Whether to display only positive or all attributes.
Defaults to True.

Returns:
Image with paired figures: original image and features heatmap.
"""
attributes_matrix: np.ndarray = attributions.detach().cpu().numpy()
transformed_img_np: np.ndarray = transformed_img.detach().cpu().numpy()

single_channel_attributes: np.ndarray = normalize_attributes(
attributes=attributes_matrix,
)

if only_positive_attr:
single_channel_attributes = retain_only_positive(
array=single_channel_attributes
)

resized_attributes: np.ndarray = resize_attributes(
attributes=single_channel_attributes,
dest_height=transformed_img_np.shape[1],
dest_width=transformed_img_np.shape[2],
)

# standardize attributes to uint8 type and back-scale them to range 0-1
grayscale_attributes = convert_float_to_uint8(resized_attributes) / 255

# transpoze image from (C x H x W) shape to (H x W x C) to matplotlib imshow
normalized_transformed_img = transpose_array(
convert_float_to_uint8(transformed_img_np)
)

figure = matplotlib.figure.Figure(figsize=figsize)
axis = figure.subplots()
axis.imshow(normalized_transformed_img)
heatmap_plot = axis.imshow(
grayscale_attributes, cmap=matplotlib.cm.jet, vmin=0, vmax=1, alpha=alpha
)

figure.colorbar(heatmap_plot, label="Pixel relevance")
axis.get_xaxis().set_visible(False)
axis.get_yaxis().set_visible(False)
axis.set_title(title)

return figure


CVExplainerT = TypeVar("CVExplainerT", bound=CVExplainer)
"""CVExplainer subclass type."""
Loading