diff --git a/CHANGELOG.md b/CHANGELOG.md index d52174b0668..426664f56c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -110,6 +110,8 @@ All notable changes to this project will be documented in this file. () - Fix wrong model name in converter & template () +- Fix RTMDet Inst Explain Mode + () ## \[v2.1.0\] diff --git a/src/otx/algo/detection/base_models/single_stage_detector.py b/src/otx/algo/detection/base_models/single_stage_detector.py index c83626aa29c..c848eaae675 100644 --- a/src/otx/algo/detection/base_models/single_stage_detector.py +++ b/src/otx/algo/detection/base_models/single_stage_detector.py @@ -11,12 +11,14 @@ from typing import TYPE_CHECKING +import torch + +from otx.algo.instance_segmentation.heads.rtmdet_inst_head import RTMDetInstSepBNHead from otx.algo.modules.base_module import BaseModule from otx.algo.utils.mmengine_utils import InstanceData from otx.core.data.entity.detection import DetBatchDataEntity if TYPE_CHECKING: - import torch from torch import Tensor, nn @@ -210,6 +212,19 @@ def export( backbone_feat = self.extract_feat(batch_inputs) bbox_head_feat = self.bbox_head.forward(backbone_feat) feature_vector = self.feature_vector_fn(backbone_feat) + + if isinstance(self.bbox_head, RTMDetInstSepBNHead): + # create dummy saliency map as its implemented in ModelAPI + saliency_map = torch.zeros(1) + bboxes, labels, masks = self.bbox_head.export(backbone_feat, batch_img_metas, rescale=rescale) # type: ignore[misc] + return { + "bboxes": bboxes, + "labels": labels, + "masks": masks, + "feature_vector": feature_vector, + "saliency_map": saliency_map, + } + saliency_map = self.explain_fn(bbox_head_feat[0]) bboxes, labels = self.bbox_head.export(backbone_feat, batch_img_metas, rescale=rescale) return { diff --git a/src/otx/algo/instance_segmentation/heads/__init__.py b/src/otx/algo/instance_segmentation/heads/__init__.py index 3b333fb1cd2..ea5515b5c9f 100644 --- a/src/otx/algo/instance_segmentation/heads/__init__.py +++ b/src/otx/algo/instance_segmentation/heads/__init__.py @@ -7,7 +7,7 @@ from .fcn_mask_head import FCNMaskHead from .roi_head_tv import TVRoIHeads from .rpn_head import RPNHead -from .rtmdet_ins_head import RTMDetInsSepBNHead +from .rtmdet_inst_head import RTMDetInstSepBNHead __all__ = [ "Shared2FCBBoxHead", @@ -16,5 +16,5 @@ "FCNMaskHead", "TVRoIHeads", "RPNHead", - "RTMDetInsSepBNHead", + "RTMDetInstSepBNHead", ] diff --git a/src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py b/src/otx/algo/instance_segmentation/heads/rtmdet_inst_head.py similarity index 99% rename from src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py rename to src/otx/algo/instance_segmentation/heads/rtmdet_inst_head.py index 77539bf22d3..f1a2a559c55 100644 --- a/src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py +++ b/src/otx/algo/instance_segmentation/heads/rtmdet_inst_head.py @@ -46,7 +46,7 @@ # mypy: disable-error-code="call-overload, index, override, attr-defined, misc" -class RTMDetInsHead(RTMDetHead): +class RTMDetInstHead(RTMDetHead): """Detection Head of RTMDet-Ins. Args: @@ -764,7 +764,7 @@ def forward(self, features: tuple[Tensor, ...]) -> Tensor: return self.projection(mask_features) -class RTMDetInsSepBNHead(RTMDetInsHead): +class RTMDetInstSepBNHead(RTMDetInstHead): """Detection Head of RTMDet-Ins with sep-bn layers. Args: diff --git a/src/otx/algo/instance_segmentation/rtmdet_inst.py b/src/otx/algo/instance_segmentation/rtmdet_inst.py index 8e7f4dd2369..4b093d0a341 100644 --- a/src/otx/algo/instance_segmentation/rtmdet_inst.py +++ b/src/otx/algo/instance_segmentation/rtmdet_inst.py @@ -18,7 +18,7 @@ from otx.algo.common.utils.samplers import PseudoSampler from otx.algo.detection.base_models import SingleStageDetector from otx.algo.detection.necks import CSPNeXtPAFPN -from otx.algo.instance_segmentation.heads import RTMDetInsSepBNHead +from otx.algo.instance_segmentation.heads import RTMDetInstSepBNHead from otx.algo.instance_segmentation.losses import DiceLoss from otx.algo.modules.norm import build_norm_layer from otx.core.config.data import TileConfig @@ -155,7 +155,7 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: activation=partial(nn.SiLU, inplace=True), ) - bbox_head = RTMDetInsSepBNHead( + bbox_head = RTMDetInstSepBNHead( num_classes=num_classes, in_channels=96, stacked_convs=2, diff --git a/tests/integration/api/test_xai.py b/tests/integration/api/test_xai.py index 91daf52e4f3..1249867db9f 100644 --- a/tests/integration/api/test_xai.py +++ b/tests/integration/api/test_xai.py @@ -110,7 +110,7 @@ def test_predict_with_explain( if "dino" in model_name: pytest.skip("DINO is not supported.") - if any(keyword in recipe for keyword in ["rtmdet_inst_tiny", "maskdino", "maskrcnn_r50_tv"]): + if any(keyword in recipe for keyword in ["maskdino", "maskrcnn_r50_tv"]): # TODO(Eugene): inst-seg models not fully support yet. pytest.skip(f"There's issue with inst-seg: {recipe}. Skip for now.")