Skip to content

Commit 40094c2

Browse files
author
Evgeny Tsykunov
authored
Enhance XAI (CLI, ExplainParameters, test) (#1941)
* xai_enhance * Update IExplainTask signature * pylint fix * logger * reduce num_iters, rename DetClassProbMapHook * reply to comments * rebase + docs * rebase
1 parent 9d6edc7 commit 40094c2

File tree

25 files changed

+810
-344
lines changed

25 files changed

+810
-344
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ All notable changes to this project will be documented in this file.
1212

1313
- Clean up and refactor the output of the OTX CLI (<https://github.com/openvinotoolkit/training_extensions/pull/1946>)
1414
- Enhance DetCon logic and SupCon for semantic segmentation(<https://github.com/openvinotoolkit/training_extensions/pull/1958>)
15+
- Extend OTX explain CLI (<https://github.com/openvinotoolkit/training_extensions/pull/1941>)
1516

1617
### Bug fixes
1718

docs/source/guide/get_started/quick_start_guide/cli_commands.rst

+6-2
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ The command below will evaluate the trained model on the provided dataset:
397397
Explanation
398398
***********
399399

400-
``otx explain`` runs the explainable AI (XAI) algorithm of a model on the specific dataset. It helps explain the model's decision-making process in a way that is easily understood by humans.
400+
``otx explain`` runs the explainable AI (XAI) algorithm on a specific model-dataset pair. It helps explain the model's decision-making process in a way that is easily understood by humans.
401401

402402
With the ``--help`` command, you can list additional information, such as its parameters common to all model templates:
403403

@@ -422,8 +422,12 @@ With the ``--help`` command, you can list additional information, such as its pa
422422
Load model weights from previously saved checkpoint.
423423
--explain-algorithm EXPLAIN_ALGORITHM
424424
Explain algorithm name, currently support ['activationmap', 'eigencam', 'classwisesaliencymap']. For Openvino task, default method will be selected.
425+
--process-saliency-maps PROCESS_SALIENCY_MAPS
426+
Processing of saliency map includes (1) resizing to input image resolution and (2) applying a colormap. Depending on the number of targets to explain, this might take significant time.
427+
--explain-all-classes EXPLAIN_ALL_CLASSES
428+
Provides explanations for all classes. Otherwise, explains only predicted classes. This feature is supported by algorithms that can generate explanations per each class.
425429
--overlay-weight OVERLAY_WEIGHT
426-
Weight of the saliency map when overlaying the saliency map.
430+
Weight of the saliency map when overlaying the input image with saliency map.
427431
428432
429433
The command below will generate saliency maps (heatmaps with red colored areas of focus) of the trained model on the provided dataset and save the resulting images to ``save-explanation-to`` path:

otx/algorithms/classification/tasks/inference.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from otx.algorithms.common.utils import embed_ir_model_data
3737
from otx.algorithms.common.utils.logger import get_logger
3838
from otx.api.entities.datasets import DatasetEntity
39+
from otx.api.entities.explain_parameters import ExplainParameters
3940
from otx.api.entities.inference_parameters import (
4041
InferenceParameters,
4142
default_progress_callback,
@@ -162,7 +163,7 @@ def infer(
162163
def explain(
163164
self,
164165
dataset: DatasetEntity,
165-
explain_parameters: Optional[InferenceParameters] = None,
166+
explain_parameters: Optional[ExplainParameters] = None,
166167
) -> DatasetEntity:
167168
"""Main explain function of OTX Classification Task."""
168169
logger.info("called explain()")

otx/algorithms/classification/tasks/openvino.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
)
4040
from otx.api.entities.annotation import AnnotationSceneEntity
4141
from otx.api.entities.datasets import DatasetEntity
42+
from otx.api.entities.explain_parameters import ExplainParameters
4243
from otx.api.entities.inference_parameters import (
4344
InferenceParameters,
4445
default_progress_callback,
@@ -254,7 +255,7 @@ def infer(
254255
def explain(
255256
self,
256257
dataset: DatasetEntity,
257-
explain_parameters: Optional[InferenceParameters] = None,
258+
explain_parameters: Optional[ExplainParameters] = None,
258259
) -> DatasetEntity:
259260
"""Explain function of ClassificationOpenVINOTask."""
260261

@@ -269,6 +270,12 @@ def explain(
269270
dataset_size = len(dataset)
270271
for i, dataset_item in enumerate(dataset, 1):
271272
predicted_scene, _, saliency_map, _, _ = self.inferencer.predict(dataset_item.numpy)
273+
if saliency_map is None:
274+
raise RuntimeError(
275+
"There is no Saliency Map in OpenVINO IR model output. "
276+
"Please export model to OpenVINO IR with dump_features"
277+
)
278+
272279
item_labels = predicted_scene.annotations[0].get_labels()
273280
dataset_item.append_labels(item_labels)
274281
add_saliency_maps_to_dataset_item(

otx/algorithms/common/tasks/base_task.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from otx.algorithms.common.utils import UncopiableDefaultDict
3232
from otx.algorithms.common.utils.logger import get_logger
3333
from otx.api.entities.datasets import DatasetEntity
34+
from otx.api.entities.explain_parameters import ExplainParameters
3435
from otx.api.entities.inference_parameters import InferenceParameters
3536
from otx.api.entities.label import LabelEntity
3637
from otx.api.entities.metrics import MetricsGroup
@@ -208,7 +209,7 @@ def export(
208209
def explain(
209210
self,
210211
dataset: DatasetEntity,
211-
explain_parameters: Optional[InferenceParameters] = None,
212+
explain_parameters: Optional[ExplainParameters] = None,
212213
) -> DatasetEntity:
213214
"""Main explain function of OTX Task."""
214215
raise NotImplementedError

otx/algorithms/detection/adapters/mmdet/hooks/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
# SPDX-License-Identifier: Apache-2.0
44
#
55

6-
from .det_saliency_map_hook import DetSaliencyMapHook
6+
from .det_class_probability_map_hook import DetClassProbabilityMapHook
77

8-
__all__ = ["DetSaliencyMapHook"]
8+
__all__ = ["DetClassProbabilityMapHook"]

otx/algorithms/detection/adapters/mmdet/hooks/det_saliency_map_hook.py otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
# pylint: disable=too-many-locals
2727

2828

29-
class DetSaliencyMapHook(BaseRecordingForwardHook):
29+
class DetClassProbabilityMapHook(BaseRecordingForwardHook):
3030
"""Saliency map hook for object detection models."""
3131

3232
def __init__(self, module: torch.nn.Module) -> None:
@@ -116,7 +116,7 @@ def forward_single(x, cls_convs, conv_cls):
116116
else:
117117
raise NotImplementedError(
118118
"Not supported detection head provided. "
119-
"DetSaliencyMapHook supports only the following single stage detectors: "
119+
"DetClassProbabilityMap supports only the following single stage detectors: "
120120
"YOLOXHead, ATSSHead, SSDHead, VFNetHead."
121121
)
122122
return cls_scores

otx/algorithms/detection/adapters/mmdet/models/detectors/custom_atss_detector.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
1616
from otx.algorithms.common.utils.logger import get_logger
1717
from otx.algorithms.common.utils.task_adapt import map_class_names
18-
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
19-
DetSaliencyMapHook,
18+
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
19+
DetClassProbabilityMapHook,
2020
)
2121

2222
from .l2sp_detector_mixin import L2SPDetectorMixin
@@ -99,7 +99,7 @@ def custom_atss__simple_test(ctx, self, img, img_metas, **kwargs):
9999
if ctx.cfg["dump_features"]:
100100
feature_vector = FeatureVectorHook.func(feat)
101101
cls_scores = outs[0]
102-
saliency_map = DetSaliencyMapHook(self).func(feature_map=cls_scores, cls_scores_provided=True)
102+
saliency_map = DetClassProbabilityMapHook(self).func(feature_map=cls_scores, cls_scores_provided=True)
103103
return (*bbox_results, feature_vector, saliency_map)
104104

105105
return bbox_results

otx/algorithms/detection/adapters/mmdet/models/detectors/custom_single_stage_detector.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
1616
from otx.algorithms.common.utils.logger import get_logger
1717
from otx.algorithms.common.utils.task_adapt import map_class_names
18-
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
19-
DetSaliencyMapHook,
18+
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
19+
DetClassProbabilityMapHook,
2020
)
2121

2222
from .l2sp_detector_mixin import L2SPDetectorMixin
@@ -157,7 +157,7 @@ def custom_single_stage_detector__simple_test(ctx, self, img, img_metas, **kwarg
157157
if ctx.cfg["dump_features"]:
158158
feature_vector = FeatureVectorHook.func(feat)
159159
cls_scores = outs[0]
160-
saliency_map = DetSaliencyMapHook(self).func(cls_scores, cls_scores_provided=True)
160+
saliency_map = DetClassProbabilityMapHook(self).func(cls_scores, cls_scores_provided=True)
161161
return (*bbox_results, feature_vector, saliency_map)
162162

163163
return bbox_results

otx/algorithms/detection/adapters/mmdet/models/detectors/custom_yolox_detector.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
1616
from otx.algorithms.common.utils.logger import get_logger
1717
from otx.algorithms.common.utils.task_adapt import map_class_names
18-
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
19-
DetSaliencyMapHook,
18+
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
19+
DetClassProbabilityMapHook,
2020
)
2121

2222
from .l2sp_detector_mixin import L2SPDetectorMixin
@@ -137,7 +137,7 @@ def custom_yolox__simple_test(ctx, self, img, img_metas, **kwargs):
137137
if ctx.cfg["dump_features"]:
138138
feature_vector = FeatureVectorHook.func(feat)
139139
cls_scores = outs[0]
140-
saliency_map = DetSaliencyMapHook(self).func(cls_scores, cls_scores_provided=True)
140+
saliency_map = DetClassProbabilityMapHook(self).func(cls_scores, cls_scores_provided=True)
141141
return (*bbox_results, feature_vector, saliency_map)
142142

143143
return bbox_results

otx/algorithms/detection/adapters/mmdet/task.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@
6060
SemiSLDetectionConfigurer,
6161
)
6262
from otx.algorithms.detection.adapters.mmdet.datasets import ImageTilingDataset
63-
from otx.algorithms.detection.adapters.mmdet.hooks.det_saliency_map_hook import (
64-
DetSaliencyMapHook,
63+
from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
64+
DetClassProbabilityMapHook,
6565
)
6666
from otx.algorithms.detection.adapters.mmdet.utils.builder import build_detector
6767
from otx.algorithms.detection.adapters.mmdet.utils.config_utils import (
@@ -74,6 +74,7 @@
7474
from otx.api.configuration import cfg_helper
7575
from otx.api.configuration.helper.utils import config_to_bytes, ids_to_strings
7676
from otx.api.entities.datasets import DatasetEntity
77+
from otx.api.entities.explain_parameters import ExplainParameters
7778
from otx.api.entities.inference_parameters import InferenceParameters
7879
from otx.api.entities.model import (
7980
ModelEntity,
@@ -392,7 +393,7 @@ def hook(module, inp, outp):
392393
if isinstance(raw_model, TwoStageDetector):
393394
saliency_hook = ActivationMapHook(feature_model)
394395
else:
395-
saliency_hook = DetSaliencyMapHook(feature_model)
396+
saliency_hook = DetClassProbabilityMapHook(feature_model)
396397

397398
if not dump_features:
398399
feature_vector_hook: Union[nullcontext, BaseRecordingForwardHook] = nullcontext()
@@ -541,12 +542,12 @@ def export(
541542
def explain(
542543
self,
543544
dataset: DatasetEntity,
544-
explain_parameters: Optional[InferenceParameters] = None,
545+
explain_parameters: Optional[ExplainParameters] = None,
545546
) -> DatasetEntity:
546547
"""Main explain function of MMDetectionTask."""
547548

548549
explainer_hook_selector = {
549-
"classwisesaliencymap": DetSaliencyMapHook,
550+
"classwisesaliencymap": DetClassProbabilityMapHook,
550551
"eigencam": EigenCamHook,
551552
"activationmap": ActivationMapHook,
552553
}

otx/algorithms/detection/adapters/openvino/task.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from otx.api.configuration.helper.utils import config_to_bytes
4242
from otx.api.entities.annotation import AnnotationSceneEntity
4343
from otx.api.entities.datasets import DatasetEntity
44+
from otx.api.entities.explain_parameters import ExplainParameters
4445
from otx.api.entities.inference_parameters import (
4546
InferenceParameters,
4647
default_progress_callback,
@@ -434,7 +435,7 @@ def infer(
434435
def explain(
435436
self,
436437
dataset: DatasetEntity,
437-
explain_parameters: Optional[InferenceParameters] = None,
438+
explain_parameters: Optional[ExplainParameters] = None,
438439
) -> DatasetEntity:
439440
"""Explain function of OpenVINODetectionTask."""
440441
logger.info("Start OpenVINO explain")
@@ -453,6 +454,11 @@ def explain(
453454
dataset_item.append_annotations(predicted_scene.annotations)
454455
update_progress_callback(int(i / dataset_size * 100), None)
455456
_, saliency_map = features
457+
if saliency_map is None:
458+
raise RuntimeError(
459+
"There is no Saliency Map in OpenVINO IR model output. "
460+
"Please export model to OpenVINO IR with dump_features"
461+
)
456462

457463
labels = self.task_environment.get_labels().copy()
458464
if saliency_map.shape[0] == len(labels) + 1:

otx/algorithms/detection/task.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from otx.api.configuration.helper.utils import ids_to_strings
3737
from otx.api.entities.annotation import Annotation
3838
from otx.api.entities.datasets import DatasetEntity
39+
from otx.api.entities.explain_parameters import ExplainParameters
3940
from otx.api.entities.id import ID
4041
from otx.api.entities.inference_parameters import InferenceParameters
4142
from otx.api.entities.label import Domain, LabelEntity
@@ -259,7 +260,7 @@ def export(
259260
def explain(
260261
self,
261262
dataset: DatasetEntity,
262-
explain_parameters: Optional[InferenceParameters] = None,
263+
explain_parameters: Optional[ExplainParameters] = None,
263264
) -> DatasetEntity:
264265
"""Main explain function of OTX Task."""
265266
raise NotImplementedError
+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""This module define the Explain entity."""
2+
# Copyright (C) 2023 Intel Corporation
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
6+
7+
from dataclasses import dataclass
8+
from typing import Any, Callable, Optional
9+
10+
11+
# pylint: disable=unused-argument
12+
def default_progress_callback(progress: int, score: Optional[float] = None):
13+
"""This is the default progress callback for OptimizationParameters."""
14+
15+
16+
@dataclass
17+
class ExplainParameters:
18+
"""Explain parameters.
19+
20+
Attributes:
21+
explainer: Explain algorithm to be used in explanation mode.
22+
Will be converted automatically to lowercase.
23+
process_saliency_maps: Processing of saliency map includes (1) resize to input image resolution
24+
and (2) apply a colormap.
25+
explain_predicted_classes: Provides explanations only for predicted classes.
26+
Otherwise, explain all classes.
27+
"""
28+
29+
update_progress: Callable[[int, Optional[float]], Any] = default_progress_callback
30+
31+
explainer: str = ""
32+
process_saliency_maps: bool = False
33+
explain_predicted_classes: bool = True

otx/api/entities/inference_parameters.py

-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ class InferenceParameters:
3636
is_evaluation: bool = False
3737
update_progress: Callable[[int, Optional[float]], Any] = default_progress_callback
3838

39-
# TODO(negvet): use separate ExplainParameters dataclass for this
4039
explainer: str = ""
4140
process_saliency_maps: bool = False
4241
explain_predicted_classes: bool = True

otx/api/usecases/tasks/interfaces/explain_interface.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import abc
99

1010
from otx.api.entities.datasets import DatasetEntity
11-
from otx.api.entities.inference_parameters import InferenceParameters
11+
from otx.api.entities.explain_parameters import ExplainParameters
1212

1313

1414
class IExplainTask(metaclass=abc.ABCMeta):
@@ -18,7 +18,7 @@ class IExplainTask(metaclass=abc.ABCMeta):
1818
def explain(
1919
self,
2020
dataset: DatasetEntity,
21-
explain_parameters: InferenceParameters,
21+
explain_parameters: ExplainParameters,
2222
) -> DatasetEntity:
2323
"""This is the method that is called upon explanation.
2424

0 commit comments

Comments
 (0)