|
| 1 | +"""OTX DINO Class for mmdetection detectors.""" |
| 2 | + |
| 3 | +# Copyright (C) 2023 Intel Corporation |
| 4 | +# SPDX-License-Identifier: Apache-2.0 |
| 5 | +# |
| 6 | + |
| 7 | +from mmdet.models.builder import DETECTORS |
| 8 | + |
| 9 | +from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import ( |
| 10 | + ActivationMapHook, |
| 11 | + FeatureVectorHook, |
| 12 | +) |
| 13 | +from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled |
| 14 | +from otx.algorithms.common.utils.logger import get_logger |
| 15 | +from otx.algorithms.detection.adapters.mmdet.models.detectors import CustomDeformableDETR |
| 16 | + |
| 17 | +logger = get_logger() |
| 18 | + |
| 19 | + |
@DETECTORS.register_module()
class CustomDINO(CustomDeformableDETR):
    """Custom DINO detector.

    On construction it registers a load-state-dict pre-hook that renames the
    keys of an mmdet3.x style DINO checkpoint to the key layout expected by
    this detector.
    """

    def __init__(self, *args, task_adapt=None, **kwargs):
        """Build the detector and register the checkpoint key-renaming pre-hook.

        Args:
            *args: Positional arguments forwarded to ``CustomDeformableDETR``.
            task_adapt: Forwarded unchanged to ``CustomDeformableDETR``.
            **kwargs: Keyword arguments forwarded to ``CustomDeformableDETR``.
        """
        super().__init__(*args, task_adapt=task_adapt, **kwargs)
        self._register_load_state_dict_pre_hook(
            self.load_state_dict_pre_hook,
        )

    @staticmethod
    def load_state_dict_pre_hook(ckpt_dict, *args, **kwargs):
        """Modify mmdet3.x version's weights before weight loading.

        The checkpoint is treated as an mmdet3.x one when it contains the
        top-level ``level_embed`` key. In that case the keys are renamed
        in place:

        * ``level_embed`` -> ``bbox_head.transformer.level_embeds``
        * keys containing ``encoder`` / ``decoder`` get the
          ``bbox_head.transformer.`` prefix, with ``self_attn`` ->
          ``attentions.0``, ``cross_attn`` -> ``attentions.1`` and
          ``ffn`` -> ``ffns.0``
        * ``query_embedding.weight`` gets the ``bbox_head.`` prefix
        * ``dn_query_generator.label_embedding.weight`` gets the
          ``bbox_head.transformer.`` prefix
        * keys containing ``memory_trans`` get the ``bbox_head.transformer.``
          prefix, with ``memory_trans_fc`` -> ``enc_output`` and
          ``memory_trans_norm`` -> ``enc_output_norm``

        Any other checkpoint (including an empty one) is left untouched.

        Args:
            ckpt_dict: State dict about to be loaded; modified in place.
            *args: Remaining pre-hook arguments, unused.
            **kwargs: Remaining pre-hook keyword arguments, unused.
        """
        # Membership test instead of inspecting the first key: it does not
        # raise on an empty state dict and does not depend on key order.
        if "level_embed" not in ckpt_dict:
            return

        logger.info("----------------- CustomDINO.load_state_dict_pre_hook() called")
        # This ckpt_dict comes from mmdet3.x
        ckpt_dict["bbox_head.transformer.level_embeds"] = ckpt_dict.pop("level_embed")

        # Collect the renames first; the dict must not change size while it
        # is being iterated.
        replaced_params = {}
        for param in ckpt_dict:
            new_param = None
            if "encoder" in param or "decoder" in param:
                new_param = "bbox_head.transformer." + param
                new_param = new_param.replace("self_attn", "attentions.0")
                new_param = new_param.replace("cross_attn", "attentions.1")
                new_param = new_param.replace("ffn", "ffns.0")
            elif param == "query_embedding.weight":
                new_param = "bbox_head." + param
            elif param == "dn_query_generator.label_embedding.weight":
                new_param = "bbox_head.transformer." + param
            elif "memory_trans" in param:
                new_param = "bbox_head.transformer." + param
                new_param = new_param.replace("memory_trans_fc", "enc_output")
                new_param = new_param.replace("memory_trans_norm", "enc_output_norm")
            if new_param is not None:
                replaced_params[param] = new_param

        for origin, new in replaced_params.items():
            ckpt_dict[new] = ckpt_dict.pop(origin)
| 59 | + |
| 60 | + |
if is_mmdeploy_enabled():
    from mmdeploy.core import FUNCTION_REWRITER

    @FUNCTION_REWRITER.register_rewriter(
        "otx.algorithms.detection.adapters.mmdet.models.detectors.custom_dino_detector.CustomDINO.simple_test"
    )
    def custom_dino__simple_test(ctx, self, img, img_metas, **kwargs):
        """Rewritten ``CustomDINO.simple_test`` used when mmdeploy is enabled.

        Rewrites the first image meta so that ``img_shape`` is a plain
        ``(height, width, 3)`` tuple of ints and ``batch_input_shape`` is
        ``(height, width)``, then runs feature extraction, the head's
        transformer and box decoding.

        Args:
            ctx: Rewriter context; ``ctx.cfg["dump_features"]`` selects
                whether extra outputs are appended.
            self: The detector instance being rewritten.
            img: Input passed to ``self.extract_feat``.
            img_metas: Image meta dicts; the first entry is modified in place.
            **kwargs: Forwarded to ``self.bbox_head.get_bboxes``.

        Returns:
            The result of ``self.bbox_head.get_bboxes``; when
            ``ctx.cfg["dump_features"]`` is truthy, a tuple of that result's
            items followed by the feature vector and the saliency map.
        """
        first_meta = img_metas[0]
        rows = int(first_meta["img_shape"][0])
        cols = int(first_meta["img_shape"][1])
        first_meta["batch_input_shape"] = (rows, cols)
        first_meta["img_shape"] = (rows, cols, 3)

        feats = self.extract_feat(img)

        # No ground truth at test time: pass None placeholders.
        num_placeholders = len(feats)
        gt_bboxes = [None] * num_placeholders
        gt_labels = [None] * num_placeholders

        head = self.bbox_head
        hidden_states, references, enc_cls, enc_coord, _ = head.forward_transformer(
            feats, gt_bboxes, gt_labels, img_metas
        )
        cls_scores, bbox_preds = head(hidden_states, references)
        bbox_results = head.get_bboxes(cls_scores, bbox_preds, enc_cls, enc_coord, img_metas=img_metas, **kwargs)

        if not ctx.cfg["dump_features"]:
            return bbox_results

        feature_vector = FeatureVectorHook.func(feats)
        saliency_map = ActivationMapHook.func(cls_scores)
        return (*bbox_results, feature_vector, saliency_map)
0 commit comments