Skip to content

Commit f974b41

Browse files
authored
Enable DINO to OTX - Step 2. Upgrade Deformable DETR to DINO (#2266)
* Add DINO * Modify docstrings * Add mmengine to detection requirements * Add unit tests * Add intg test * Update CHANGELOG.md * Change description of config files for DINO * Modify unit tests * Reflect reviews * Reflect Reviews * Update unit tests
1 parent bf30d6e commit f974b41

File tree

20 files changed

+2554
-9
lines changed

20 files changed

+2554
-9
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ All notable changes to this project will be documented in this file.
1010
- Add custom max iou assigner to prevent CPU OOM when large annotations are used (<https://github.com/openvinotoolkit/training_extensions/pull/2228>)
1111
- Auto train type detection for Semi-SL, Self-SL and Incremental: "--train-type" now is optional (https://github.com/openvinotoolkit/training_extensions/pull/2195)
1212
- Add new object detector Deformable DETR (<https://github.com/openvinotoolkit/training_extensions/pull/2249>)
13+
- Add new object detecotr DINO(<https://github.com/openvinotoolkit/training_extensions/pull/2266>)
1314

1415
### Enhancements
1516

otx/algorithms/detection/adapters/mmdet/models/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
# SPDX-License-Identifier: Apache-2.0
44
#
55

6-
from . import assigners, backbones, dense_heads, detectors, heads, losses, necks, roi_heads
6+
from . import assigners, backbones, dense_heads, detectors, heads, layers, losses, necks, roi_heads
77

8-
__all__ = ["assigners", "backbones", "dense_heads", "detectors", "heads", "losses", "necks", "roi_heads"]
8+
__all__ = ["assigners", "backbones", "dense_heads", "detectors", "heads", "layers", "losses", "necks", "roi_heads"]

otx/algorithms/detection/adapters/mmdet/models/detectors/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from .custom_atss_detector import CustomATSS
77
from .custom_deformable_detr_detector import CustomDeformableDETR
8+
from .custom_dino_detector import CustomDINO
89
from .custom_maskrcnn_detector import CustomMaskRCNN
910
from .custom_maskrcnn_tile_optimized import CustomMaskRCNNTileOptimized
1011
from .custom_single_stage_detector import CustomSingleStageDetector
@@ -18,6 +19,7 @@
1819
__all__ = [
1920
"CustomATSS",
2021
"CustomDeformableDETR",
22+
"CustomDINO",
2123
"CustomMaskRCNN",
2224
"CustomSingleStageDetector",
2325
"CustomTwoStageDetector",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""OTX DINO Class for mmdetection detectors."""
2+
3+
# Copyright (C) 2023 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
#
6+
7+
from mmdet.models.builder import DETECTORS
8+
9+
from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
10+
ActivationMapHook,
11+
FeatureVectorHook,
12+
)
13+
from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
14+
from otx.algorithms.common.utils.logger import get_logger
15+
from otx.algorithms.detection.adapters.mmdet.models.detectors import CustomDeformableDETR
16+
17+
logger = get_logger()
18+
19+
20+
@DETECTORS.register_module()
21+
class CustomDINO(CustomDeformableDETR):
22+
"""Custom DINO detector."""
23+
24+
def __init__(self, *args, task_adapt=None, **kwargs):
25+
super().__init__(*args, task_adapt=task_adapt, **kwargs)
26+
self._register_load_state_dict_pre_hook(
27+
self.load_state_dict_pre_hook,
28+
)
29+
30+
@staticmethod
31+
def load_state_dict_pre_hook(ckpt_dict, *args, **kwargs):
32+
"""Modify mmdet3.x version's weights before weight loading."""
33+
34+
if list(ckpt_dict.keys())[0] == "level_embed":
35+
logger.info("----------------- CustomDINO.load_state_dict_pre_hook() called")
36+
# This ckpt_dict comes from mmdet3.x
37+
ckpt_dict["bbox_head.transformer.level_embeds"] = ckpt_dict.pop("level_embed")
38+
replaced_params = {}
39+
for param in ckpt_dict:
40+
new_param = None
41+
if "encoder" in param or "decoder" in param:
42+
new_param = "bbox_head.transformer." + param
43+
new_param = new_param.replace("self_attn", "attentions.0")
44+
new_param = new_param.replace("cross_attn", "attentions.1")
45+
new_param = new_param.replace("ffn", "ffns.0")
46+
elif param == "query_embedding.weight":
47+
new_param = "bbox_head." + param
48+
elif param == "dn_query_generator.label_embedding.weight":
49+
new_param = "bbox_head.transformer." + param
50+
elif "memory_trans" in param:
51+
new_param = "bbox_head.transformer." + param
52+
new_param = new_param.replace("memory_trans_fc", "enc_output")
53+
new_param = new_param.replace("memory_trans_norm", "enc_output_norm")
54+
if new_param is not None:
55+
replaced_params[param] = new_param
56+
57+
for origin, new in replaced_params.items():
58+
ckpt_dict[new] = ckpt_dict.pop(origin)
59+
60+
61+
if is_mmdeploy_enabled():
62+
from mmdeploy.core import FUNCTION_REWRITER
63+
64+
@FUNCTION_REWRITER.register_rewriter(
65+
"otx.algorithms.detection.adapters.mmdet.models.detectors.custom_dino_detector.CustomDINO.simple_test"
66+
)
67+
def custom_dino__simple_test(ctx, self, img, img_metas, **kwargs):
68+
"""Function for custom_dino__simple_test."""
69+
height = int(img_metas[0]["img_shape"][0])
70+
width = int(img_metas[0]["img_shape"][1])
71+
img_metas[0]["batch_input_shape"] = (height, width)
72+
img_metas[0]["img_shape"] = (height, width, 3)
73+
feats = self.extract_feat(img)
74+
gt_bboxes = [None] * len(feats)
75+
gt_labels = [None] * len(feats)
76+
hidden_states, references, enc_output_class, enc_output_coord, _ = self.bbox_head.forward_transformer(
77+
feats, gt_bboxes, gt_labels, img_metas
78+
)
79+
cls_scores, bbox_preds = self.bbox_head(hidden_states, references)
80+
bbox_results = self.bbox_head.get_bboxes(
81+
cls_scores, bbox_preds, enc_output_class, enc_output_coord, img_metas=img_metas, **kwargs
82+
)
83+
84+
if ctx.cfg["dump_features"]:
85+
feature_vector = FeatureVectorHook.func(feats)
86+
saliency_map = ActivationMapHook.func(cls_scores)
87+
return (*bbox_results, feature_vector, saliency_map)
88+
89+
return bbox_results

otx/algorithms/detection/adapters/mmdet/models/heads/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,27 @@
66
from .cross_dataset_detector_head import CrossDatasetDetectorHead
77
from .custom_anchor_generator import SSDAnchorGeneratorClustered
88
from .custom_atss_head import CustomATSSHead, CustomATSSHeadTrackingLossDynamics
9+
from .custom_dino_head import CustomDINOHead
910
from .custom_fcn_mask_head import CustomFCNMaskHead
1011
from .custom_retina_head import CustomRetinaHead
1112
from .custom_roi_head import CustomRoIHead
1213
from .custom_ssd_head import CustomSSDHead
1314
from .custom_vfnet_head import CustomVFNetHead
1415
from .custom_yolox_head import CustomYOLOXHead
16+
from .detr_head import DETRHeadExtension
1517

1618
__all__ = [
1719
"CrossDatasetDetectorHead",
1820
"SSDAnchorGeneratorClustered",
1921
"CustomATSSHead",
22+
"CustomDINOHead",
2023
"CustomFCNMaskHead",
2124
"CustomRetinaHead",
2225
"CustomSSDHead",
2326
"CustomRoIHead",
2427
"CustomVFNetHead",
2528
"CustomYOLOXHead",
29+
"DETRHeadExtension",
2630
# Loss dynamics tracking
2731
"CustomATSSHeadTrackingLossDynamics",
2832
]

0 commit comments

Comments
 (0)