Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance XAI (CLI, ExplainParameters, test) #1941

Merged
merged 8 commits into from
Apr 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ All notable changes to this project will be documented in this file.

- Clean up and refactor the output of the OTX CLI (<https://github.com/openvinotoolkit/training_extensions/pull/1946>)
- Enhance DetCon logic and SupCon for semantic segmentation(<https://github.com/openvinotoolkit/training_extensions/pull/1958>)
- Extend OTX explain CLI (<https://github.com/openvinotoolkit/training_extensions/pull/1941>)

### Bug fixes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ The command below will evaluate the trained model on the provided dataset:
Explanation
***********

``otx explain`` runs the explainable AI (XAI) algorithm of a model on the specific dataset. It helps explain the model's decision-making process in a way that is easily understood by humans.
``otx explain`` runs the explainable AI (XAI) algorithm on a specific model-dataset pair. It helps explain the model's decision-making process in a way that is easily understood by humans.

With the ``--help`` command, you can list additional information, such as its parameters common to all model templates:

Expand All @@ -422,8 +422,12 @@ With the ``--help`` command, you can list additional information, such as its pa
Load model weights from previously saved checkpoint.
--explain-algorithm EXPLAIN_ALGORITHM
Explain algorithm name, currently support ['activationmap', 'eigencam', 'classwisesaliencymap']. For Openvino task, default method will be selected.
--process-saliency-maps PROCESS_SALIENCY_MAPS
Processing of saliency map includes (1) resizing to input image resolution and (2) applying a colormap. Depending on the number of targets to explain, this might take significant time.
--explain-all-classes EXPLAIN_ALL_CLASSES
Provides explanations for all classes. Otherwise, explains only predicted classes. This feature is supported by algorithms that can generate explanations per each class.
--overlay-weight OVERLAY_WEIGHT
Weight of the saliency map when overlaying the saliency map.
Weight of the saliency map when overlaying the input image with saliency map.


The command below will generate saliency maps (heatmaps with red colored areas of focus) of the trained model on the provided dataset and save the resulting images to ``save-explanation-to`` path:
Expand Down
3 changes: 2 additions & 1 deletion otx/algorithms/classification/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from otx.algorithms.common.utils import embed_ir_model_data
from otx.algorithms.common.utils.logger import get_logger
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import (
InferenceParameters,
default_progress_callback,
Expand Down Expand Up @@ -163,7 +164,7 @@ def infer(
def explain(
self,
dataset: DatasetEntity,
explain_parameters: Optional[InferenceParameters] = None,
explain_parameters: Optional[ExplainParameters] = None,
) -> DatasetEntity:
"""Main explain function of OTX Classification Task."""
logger.info("called explain()")
Expand Down
9 changes: 8 additions & 1 deletion otx/algorithms/classification/tasks/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from otx.api.entities.annotation import AnnotationSceneEntity
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import (
InferenceParameters,
default_progress_callback,
Expand Down Expand Up @@ -254,7 +255,7 @@ def infer(
def explain(
self,
dataset: DatasetEntity,
explain_parameters: Optional[InferenceParameters] = None,
explain_parameters: Optional[ExplainParameters] = None,
) -> DatasetEntity:
"""Explain function of ClassificationOpenVINOTask."""

Expand All @@ -269,6 +270,12 @@ def explain(
dataset_size = len(dataset)
for i, dataset_item in enumerate(dataset, 1):
predicted_scene, _, saliency_map, _, _ = self.inferencer.predict(dataset_item.numpy)
if saliency_map is None:
raise RuntimeError(
"There is no Saliency Map in OpenVINO IR model output. "
"Please export model to OpenVINO IR with dump_features"
)

item_labels = predicted_scene.annotations[0].get_labels()
dataset_item.append_labels(item_labels)
add_saliency_maps_to_dataset_item(
Expand Down
3 changes: 2 additions & 1 deletion otx/algorithms/common/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from otx.algorithms.common.utils import UncopiableDefaultDict
from otx.algorithms.common.utils.logger import get_logger
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.label import LabelEntity
from otx.api.entities.metrics import MetricsGroup
Expand Down Expand Up @@ -208,7 +209,7 @@ def export(
def explain(
self,
dataset: DatasetEntity,
explain_parameters: Optional[InferenceParameters] = None,
explain_parameters: Optional[ExplainParameters] = None,
) -> DatasetEntity:
"""Main explain function of OTX Task."""
raise NotImplementedError
Expand Down
4 changes: 2 additions & 2 deletions otx/algorithms/detection/adapters/mmdet/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
# SPDX-License-Identifier: Apache-2.0
#

from .det_saliency_map_hook import DetSaliencyMapHook
from .det_class_probability_map_hook import DetClassProbabilityMapHook

__all__ = ["DetSaliencyMapHook"]
__all__ = ["DetClassProbabilityMapHook"]
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# pylint: disable=too-many-locals


class DetSaliencyMapHook(BaseRecordingForwardHook):
class DetClassProbabilityMapHook(BaseRecordingForwardHook):
"""Saliency map hook for object detection models."""

def __init__(self, module: torch.nn.Module) -> None:
Expand Down Expand Up @@ -116,7 +116,7 @@ def forward_single(x, cls_convs, conv_cls):
else:
raise NotImplementedError(
"Not supported detection head provided. "
"DetSaliencyMapHook supports only the following single stage detectors: "
"DetClassProbabilityMap supports only the following single stage detectors: "
"YOLOXHead, ATSSHead, SSDHead, VFNetHead."
)
return cls_scores
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.common.utils.task_adapt import map_class_names
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
DetSaliencyMapHook,
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMapHook,
)

from .l2sp_detector_mixin import L2SPDetectorMixin
Expand Down Expand Up @@ -99,7 +99,7 @@ def custom_atss__simple_test(ctx, self, img, img_metas, **kwargs):
if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetSaliencyMapHook(self).func(feature_map=cls_scores, cls_scores_provided=True)
saliency_map = DetClassProbabilityMapHook(self).func(feature_map=cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

return bbox_results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.common.utils.task_adapt import map_class_names
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
DetSaliencyMapHook,
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMapHook,
)

from .l2sp_detector_mixin import L2SPDetectorMixin
Expand Down Expand Up @@ -156,7 +156,7 @@ def custom_single_stage_detector__simple_test(ctx, self, img, img_metas, **kwarg
if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetSaliencyMapHook(self).func(cls_scores, cls_scores_provided=True)
saliency_map = DetClassProbabilityMapHook(self).func(cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

return bbox_results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
from otx.algorithms.common.utils.logger import get_logger
from otx.algorithms.common.utils.task_adapt import map_class_names
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
DetSaliencyMapHook,
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMapHook,
)

from .l2sp_detector_mixin import L2SPDetectorMixin
Expand Down Expand Up @@ -137,7 +137,7 @@ def custom_yolox__simple_test(ctx, self, img, img_metas, **kwargs):
if ctx.cfg["dump_features"]:
feature_vector = FeatureVectorHook.func(feat)
cls_scores = outs[0]
saliency_map = DetSaliencyMapHook(self).func(cls_scores, cls_scores_provided=True)
saliency_map = DetClassProbabilityMapHook(self).func(cls_scores, cls_scores_provided=True)
return (*bbox_results, feature_vector, saliency_map)

return bbox_results
Expand Down
11 changes: 6 additions & 5 deletions otx/algorithms/detection/adapters/mmdet/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@
SemiSLDetectionConfigurer,
)
from otx.algorithms.detection.adapters.mmdet.datasets import ImageTilingDataset
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
DetSaliencyMapHook,
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
DetClassProbabilityMapHook,
)
from otx.algorithms.detection.adapters.mmdet.utils.builder import build_detector
from otx.algorithms.detection.adapters.mmdet.utils.config_utils import (
Expand All @@ -74,6 +74,7 @@
from otx.api.configuration import cfg_helper
from otx.api.configuration.helper.utils import config_to_bytes, ids_to_strings
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.model import (
ModelEntity,
Expand Down Expand Up @@ -392,7 +393,7 @@ def hook(module, inp, outp):
if isinstance(raw_model, TwoStageDetector):
saliency_hook = ActivationMapHook(feature_model)
else:
saliency_hook = DetSaliencyMapHook(feature_model)
saliency_hook = DetClassProbabilityMapHook(feature_model)

if not dump_features:
feature_vector_hook: Union[nullcontext, BaseRecordingForwardHook] = nullcontext()
Expand Down Expand Up @@ -540,12 +541,12 @@ def export(
def explain(
self,
dataset: DatasetEntity,
explain_parameters: Optional[InferenceParameters] = None,
explain_parameters: Optional[ExplainParameters] = None,
) -> DatasetEntity:
"""Main explain function of MMDetectionTask."""

explainer_hook_selector = {
"classwisesaliencymap": DetSaliencyMapHook,
"classwisesaliencymap": DetClassProbabilityMapHook,
"eigencam": EigenCamHook,
"activationmap": ActivationMapHook,
}
Expand Down
8 changes: 7 additions & 1 deletion otx/algorithms/detection/adapters/openvino/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from otx.api.configuration.helper.utils import config_to_bytes
from otx.api.entities.annotation import AnnotationSceneEntity
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.inference_parameters import (
InferenceParameters,
default_progress_callback,
Expand Down Expand Up @@ -434,7 +435,7 @@ def infer(
def explain(
self,
dataset: DatasetEntity,
explain_parameters: Optional[InferenceParameters] = None,
explain_parameters: Optional[ExplainParameters] = None,
) -> DatasetEntity:
"""Explain function of OpenVINODetectionTask."""
logger.info("Start OpenVINO explain")
Expand All @@ -453,6 +454,11 @@ def explain(
dataset_item.append_annotations(predicted_scene.annotations)
update_progress_callback(int(i / dataset_size * 100), None)
_, saliency_map = features
if saliency_map is None:
raise RuntimeError(
"There is no Saliency Map in OpenVINO IR model output. "
"Please export model to OpenVINO IR with dump_features"
)

labels = self.task_environment.get_labels().copy()
if saliency_map.shape[0] == len(labels) + 1:
Expand Down
3 changes: 2 additions & 1 deletion otx/algorithms/detection/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from otx.api.configuration.helper.utils import ids_to_strings
from otx.api.entities.annotation import Annotation
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.explain_parameters import ExplainParameters
from otx.api.entities.id import ID
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.label import Domain, LabelEntity
Expand Down Expand Up @@ -259,7 +260,7 @@ def export(
def explain(
self,
dataset: DatasetEntity,
explain_parameters: Optional[InferenceParameters] = None,
explain_parameters: Optional[ExplainParameters] = None,
) -> DatasetEntity:
"""Main explain function of OTX Task."""
raise NotImplementedError
Expand Down
33 changes: 33 additions & 0 deletions otx/api/entities/explain_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""This module define the Explain entity."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#


from dataclasses import dataclass
from typing import Any, Callable, Optional


# pylint: disable=unused-argument
def default_progress_callback(progress: int, score: Optional[float] = None):
"""This is the default progress callback for OptimizationParameters."""


@dataclass
class ExplainParameters:
"""Explain parameters.

Attributes:
explainer: Explain algorithm to be used in explanation mode.
Will be converted automatically to lowercase.
process_saliency_maps: Processing of saliency map includes (1) resize to input image resolution
and (2) apply a colormap.
explain_predicted_classes: Provides explanations only for predicted classes.
Otherwise, explain all classes.
"""

update_progress: Callable[[int, Optional[float]], Any] = default_progress_callback

explainer: str = ""
process_saliency_maps: bool = False
explain_predicted_classes: bool = True
1 change: 0 additions & 1 deletion otx/api/entities/inference_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class InferenceParameters:
is_evaluation: bool = False
update_progress: Callable[[int, Optional[float]], Any] = default_progress_callback

# TODO(negvet): use separate ExplainParameters dataclass for this
explainer: str = ""
process_saliency_maps: bool = False
explain_predicted_classes: bool = True
4 changes: 2 additions & 2 deletions otx/api/usecases/tasks/interfaces/explain_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import abc

from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.explain_parameters import ExplainParameters


class IExplainTask(metaclass=abc.ABCMeta):
Expand All @@ -18,7 +18,7 @@ class IExplainTask(metaclass=abc.ABCMeta):
def explain(
self,
dataset: DatasetEntity,
explain_parameters: InferenceParameters,
explain_parameters: ExplainParameters,
) -> DatasetEntity:
"""This is the method that is called upon explanation.

Expand Down
Loading