forked from openvinotoolkit/training_extensions
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_detection.py
135 lines (110 loc) · 4.51 KB
/
test_detection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
"""Unit tests for detection model module."""
from __future__ import annotations
from typing import TYPE_CHECKING
from unittest.mock import create_autospec
import pytest
import torch
from importlib_resources import files
from lightning.pytorch.cli import ReduceLROnPlateau
from omegaconf import OmegaConf
from otx.algo.detection.atss import MobileNetV2ATSS
from otx.algo.explain.explain_algo import feature_vector_fn
from otx.core.metrics.fmeasure import FMeasureCallable
from otx.core.model.detection import OTXDetectionModel
from otx.core.types.export import TaskLevelExportParameters
from torch.optim import Optimizer
if TYPE_CHECKING:
from omegaconf.dictconfig import DictConfig
class TestOTXDetectionModel:
    """Unit tests for ``OTXDetectionModel`` and the ``MobileNetV2ATSS`` detector."""

    @pytest.fixture()
    def mock_optimizer(self):
        """Return an optimizer factory that ignores its argument and yields an autospec'd ``Optimizer``."""
        return lambda _: create_autospec(Optimizer)

    @pytest.fixture()
    def mock_scheduler(self):
        """Return a scheduler factory that yields a one-element list of scheduler mocks.

        The list is built outside ``create_autospec`` on purpose: when given a
        list instance as spec, ``create_autospec`` substitutes ``type(spec)``
        (i.e. ``list``), so ``create_autospec([ReduceLROnPlateau])`` would mock
        ``list`` itself and drop ``ReduceLROnPlateau`` entirely.
        """
        return lambda _: [create_autospec(ReduceLROnPlateau)]

    @pytest.fixture(
        params=[
            {
                "confidence_threshold": 0.35,
                "state_dict": {},
            },
            {
                "hyper_parameters": {"best_confidence_threshold": 0.35},
                "state_dict": {},
            },
        ],
        ids=["v1", "v2"],
    )
    def mock_ckpt(self, request):
        """Yield checkpoint dicts in two layouts carrying the same threshold (0.35).

        - ``v1``: ``confidence_threshold`` at the top level of the checkpoint.
        - ``v2``: ``best_confidence_threshold`` nested under ``hyper_parameters``.
        """
        return request.param

    @pytest.fixture()
    def config(self) -> DictConfig:
        """Load the ``yolox_tiny.yaml`` config from the installed ``otx`` package."""
        # NOTE(review): no test in this class requests this fixture.
        cfg_path = files("otx") / "algo" / "detection" / "mmconfigs" / "yolox_tiny.yaml"
        return OmegaConf.load(cfg_path)

    @pytest.fixture()
    def otx_model(self) -> MobileNetV2ATSS:
        """Build a ``MobileNetV2ATSS`` detector (``label_info=1``, presumably one class)."""
        return MobileNetV2ATSS(label_info=1)

    def test_configure_metric_with_ckpt(
        self,
        mock_optimizer,
        mock_scheduler,
        mock_ckpt,
    ) -> None:
        """Both checkpoint layouts restore ``best_confidence_threshold`` into ``hparams``."""
        model = OTXDetectionModel(
            label_info=1,
            torch_compile=False,
            optimizer=mock_optimizer,
            scheduler=mock_scheduler,
            metric=FMeasureCallable,
        )

        model.on_load_checkpoint(mock_ckpt)

        assert model.hparams["best_confidence_threshold"] == 0.35

    def test_create_model(self, otx_model) -> None:
        """``_create_model`` returns a ``torch.nn.Module``."""
        mmdet_model = otx_model._create_model()
        assert mmdet_model is not None
        assert isinstance(mmdet_model, torch.nn.Module)

    def test_get_num_anchors(self, otx_model):
        """``get_num_anchors`` returns a list whose items are all ``int``."""
        num_anchors = otx_model.get_num_anchors()
        assert isinstance(num_anchors, list)
        assert all(isinstance(n, int) for n in num_anchors)

    def test_get_explain_fn(self, otx_model):
        """With explain mode enabled, ``get_explain_fn`` returns a callable."""
        otx_model.explain_mode = True
        explain_fn = otx_model.get_explain_fn()
        assert callable(explain_fn)

    def test_forward_explain_detection(self, otx_model, fxt_det_data_entity):
        """``_forward_explain_detection`` returns predictions, feature vector and saliency map."""
        # Attach the explain hooks onto the inner model; presumably
        # ``_forward_explain_detection`` looks them up there — verify.
        otx_model.model.feature_vector_fn = feature_vector_fn
        otx_model.model.explain_fn = otx_model.get_explain_fn()

        # ``fxt_det_data_entity`` comes from a shared fixture (not in this file);
        # index 2 is presumably the batch entity — TODO confirm.
        inputs = fxt_det_data_entity[2]
        # Swap in a small random image tensor so the forward pass is cheap.
        inputs.images = torch.randn(1, 3, 64, 64)

        result = otx_model._forward_explain_detection(otx_model.model, inputs, mode="predict")

        assert "predictions" in result
        assert "feature_vector" in result
        assert "saliency_map" in result

    def test_customize_inputs(self, otx_model, fxt_det_data_entity) -> None:
        """``_customize_inputs`` selects ``"loss"`` mode and passes the entity through."""
        # NOTE(review): "loss" presumably follows from the model being in its
        # default training mode — confirm against ``_customize_inputs``.
        output_data = otx_model._customize_inputs(fxt_det_data_entity[2])
        assert output_data["mode"] == "loss"
        assert output_data["entity"] == fxt_det_data_entity[2]

    def test_forward_explain(self, otx_model, fxt_det_data_entity):
        """``forward_explain`` yields outputs carrying a feature vector and saliency map."""
        inputs = fxt_det_data_entity[2]
        # NOTE(review): only the root module's flag is flipped (not ``eval()``),
        # so submodules stay in training mode.
        otx_model.training = False
        otx_model.explain_mode = True

        outputs = otx_model.forward_explain(inputs)

        assert outputs.has_xai_outputs
        assert outputs.feature_vector is not None
        assert outputs.saliency_map is not None

    def test_reset_restore_model_forward(self, otx_model):
        """``_reset_model_forward`` swaps ``model.forward``; ``_restore_model_forward`` undoes it."""
        otx_model.explain_mode = True
        initial_model_forward = otx_model.model.forward

        otx_model._reset_model_forward()
        assert otx_model.original_model_forward is not None
        # Forwards are compared via ``str()``; NOTE(review): for a bound method
        # this string embeds the module's repr.
        assert str(otx_model.model.forward) != str(otx_model.original_model_forward)

        otx_model._restore_model_forward()
        assert otx_model.original_model_forward is None
        assert str(otx_model.model.forward) == str(initial_model_forward)

    def test_export_parameters(self, otx_model):
        """``_export_parameters`` is a ``TaskLevelExportParameters`` for the detection task."""
        parameters = otx_model._export_parameters
        assert isinstance(parameters, TaskLevelExportParameters)
        assert parameters.task_type == "detection"