Skip to content

Commit 10fce79

Browse files
authored
Split export function and _export_model function for detection task (#1997)
* Split export function and _export_model function for detection task * Add basic unit test for detection task * Update change log * Reflect changes from #1976 * Fix unit test failure
1 parent 5652811 commit 10fce79

File tree

6 files changed

+431
-211
lines changed

6 files changed

+431
-211
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ All notable changes to this project will be documented in this file.
1414

1515
- Clean up and refactor the output of the OTX CLI (<https://github.com/openvinotoolkit/training_extensions/pull/1946>)
1616
- Enhance DetCon logic and SupCon for semantic segmentation(<https://github.com/openvinotoolkit/training_extensions/pull/1958>)
17+
- Detection task refactoring (<https://github.com/openvinotoolkit/training_extensions/pull/1955>)
1718
- Classification task refactoring (<https://github.com/openvinotoolkit/training_extensions/pull/1972>)
1819
- Extend OTX explain CLI (<https://github.com/openvinotoolkit/training_extensions/pull/1941>)
1920
- Segmentation task refactoring (<https://github.com/openvinotoolkit/training_extensions/pull/1977>)

otx/algorithms/detection/adapters/mmdet/task.py

+6-55
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from copy import deepcopy
2323
from typing import Any, Dict, Optional, Union
2424

25-
import numpy as np
2625
import torch
2726
from mmcv.runner import wrap_fp16_model
2827
from mmcv.utils import Config, ConfigDict, get_git_hash
@@ -52,7 +51,6 @@
5251
from otx.algorithms.common.utils import set_random_seed
5352
from otx.algorithms.common.utils.callback import InferenceProgressCallback
5453
from otx.algorithms.common.utils.data import get_dataset
55-
from otx.algorithms.common.utils.ir import embed_ir_model_data
5654
from otx.algorithms.common.utils.logger import get_logger
5755
from otx.algorithms.detection.adapters.mmdet.configurer import (
5856
DetectionConfigurer,
@@ -69,24 +67,20 @@
6967
)
7068
from otx.algorithms.detection.adapters.mmdet.utils.exporter import DetectionExporter
7169
from otx.algorithms.detection.task import OTXDetectionTask
72-
from otx.algorithms.detection.utils import get_det_model_api_configuration
7370
from otx.algorithms.detection.utils.data import adaptive_tile_params
7471
from otx.api.configuration import cfg_helper
75-
from otx.api.configuration.helper.utils import config_to_bytes, ids_to_strings
72+
from otx.api.configuration.helper.utils import ids_to_strings
7673
from otx.api.entities.datasets import DatasetEntity
7774
from otx.api.entities.explain_parameters import ExplainParameters
7875
from otx.api.entities.inference_parameters import InferenceParameters
7976
from otx.api.entities.model import (
8077
ModelEntity,
81-
ModelFormat,
82-
ModelOptimizationType,
8378
ModelPrecision,
8479
)
8580
from otx.api.entities.subset import Subset
8681
from otx.api.entities.task_environment import TaskEnvironment
8782
from otx.api.entities.train_parameters import default_progress_callback
8883
from otx.api.serialization.label_mapper import label_schema_to_bytes
89-
from otx.api.usecases.tasks.interfaces.export_interface import ExportType
9084
from otx.core.data import caching
9185

9286
logger = get_logger()
@@ -461,21 +455,12 @@ def hook(module, inp, outp):
461455
return prediction_results, metric
462456

463457
# pylint: disable=too-many-statements
464-
def export(
458+
def _export_model(
465459
self,
466-
export_type: ExportType,
467-
output_model: ModelEntity,
468-
precision: ModelPrecision = ModelPrecision.FP32,
469-
dump_features: bool = True,
460+
precision: ModelPrecision,
461+
dump_features: bool,
470462
):
471-
"""Export function of OTX Detection Task."""
472-
# copied from OTX inference_task.py
473-
logger.info("Exporting the model")
474-
if export_type != ExportType.OPENVINO:
475-
raise RuntimeError(f"not supported export type {export_type}")
476-
output_model.model_format = ModelFormat.OPENVINO
477-
output_model.optimization_type = ModelOptimizationType.MO
478-
463+
"""Main export function of OTX MMDetection Task."""
479464
self._init_task(export=True)
480465

481466
cfg = self.configure(False, "test", None)
@@ -506,41 +491,7 @@ def export(
506491
**export_options,
507492
)
508493

509-
outputs = results.get("outputs")
510-
logger.debug(f"results of run_task = {outputs}")
511-
if outputs is None:
512-
raise RuntimeError(results.get("msg"))
513-
514-
bin_file = outputs.get("bin")
515-
xml_file = outputs.get("xml")
516-
onnx_file = outputs.get("onnx")
517-
518-
ir_extra_data = get_det_model_api_configuration(
519-
self._task_environment.label_schema, self._task_type, self.confidence_threshold
520-
)
521-
embed_ir_model_data(xml_file, ir_extra_data)
522-
523-
if xml_file is None or bin_file is None or onnx_file is None:
524-
raise RuntimeError("invalid status of exporting. bin and xml or onnx should not be None")
525-
with open(bin_file, "rb") as f:
526-
output_model.set_data("openvino.bin", f.read())
527-
with open(xml_file, "rb") as f:
528-
output_model.set_data("openvino.xml", f.read())
529-
with open(onnx_file, "rb") as f:
530-
output_model.set_data("model.onnx", f.read())
531-
output_model.set_data(
532-
"confidence_threshold",
533-
np.array([self.confidence_threshold], dtype=np.float32).tobytes(),
534-
)
535-
output_model.set_data("config.json", config_to_bytes(self._hyperparams))
536-
output_model.precision = self._precision
537-
output_model.optimization_methods = self._optimization_methods
538-
output_model.has_xai = dump_features
539-
output_model.set_data(
540-
"label_schema.json",
541-
label_schema_to_bytes(self._task_environment.label_schema),
542-
)
543-
logger.info("Exporting completed")
494+
return results
544495

545496
def explain(
546497
self,

otx/algorithms/detection/task.py

+56-4
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@
3030
InferenceProgressCallback,
3131
TrainingProgressCallback,
3232
)
33+
from otx.algorithms.common.utils.ir import embed_ir_model_data
3334
from otx.algorithms.common.utils.logger import get_logger
3435
from otx.algorithms.detection.configs.base import DetectionConfig
36+
from otx.algorithms.detection.utils import get_det_model_api_configuration
3537
from otx.api.configuration import cfg_helper
36-
from otx.api.configuration.helper.utils import ids_to_strings
38+
from otx.api.configuration.helper.utils import config_to_bytes, ids_to_strings
3739
from otx.api.entities.annotation import Annotation
3840
from otx.api.entities.datasets import DatasetEntity
3941
from otx.api.entities.explain_parameters import ExplainParameters
@@ -50,7 +52,12 @@
5052
ScoreMetric,
5153
VisualizationType,
5254
)
53-
from otx.api.entities.model import ModelEntity, ModelPrecision
55+
from otx.api.entities.model import (
56+
ModelEntity,
57+
ModelFormat,
58+
ModelOptimizationType,
59+
ModelPrecision,
60+
)
5461
from otx.api.entities.model_template import TaskType
5562
from otx.api.entities.resultset import ResultSetEntity
5663
from otx.api.entities.scored_label import ScoredLabel
@@ -246,15 +253,60 @@ def _infer_model(
246253
"""Get inference results from dataset."""
247254
raise NotImplementedError
248255

249-
@abstractmethod
250256
def export(
251257
self,
252258
export_type: ExportType,
253259
output_model: ModelEntity,
254260
precision: ModelPrecision = ModelPrecision.FP32,
255261
dump_features: bool = True,
256262
):
257-
"""Export function of OTX Task."""
263+
"""Export function of OTX Detection Task."""
264+
logger.info("Exporting the model")
265+
if export_type != ExportType.OPENVINO:
266+
raise RuntimeError(f"not supported export type {export_type}")
267+
output_model.model_format = ModelFormat.OPENVINO
268+
output_model.optimization_type = ModelOptimizationType.MO
269+
270+
results = self._export_model(precision, dump_features)
271+
outputs = results.get("outputs")
272+
logger.debug(f"results of run_task = {outputs}")
273+
if outputs is None:
274+
raise RuntimeError(results.get("msg"))
275+
276+
bin_file = outputs.get("bin")
277+
xml_file = outputs.get("xml")
278+
onnx_file = outputs.get("onnx")
279+
280+
ir_extra_data = get_det_model_api_configuration(
281+
self._task_environment.label_schema, self._task_type, self.confidence_threshold
282+
)
283+
embed_ir_model_data(xml_file, ir_extra_data)
284+
285+
if xml_file is None or bin_file is None or onnx_file is None:
286+
raise RuntimeError("invalid status of exporting. bin and xml or onnx should not be None")
287+
with open(bin_file, "rb") as f:
288+
output_model.set_data("openvino.bin", f.read())
289+
with open(xml_file, "rb") as f:
290+
output_model.set_data("openvino.xml", f.read())
291+
with open(onnx_file, "rb") as f:
292+
output_model.set_data("model.onnx", f.read())
293+
output_model.set_data(
294+
"confidence_threshold",
295+
np.array([self.confidence_threshold], dtype=np.float32).tobytes(),
296+
)
297+
output_model.set_data("config.json", config_to_bytes(self._hyperparams))
298+
output_model.precision = self._precision
299+
output_model.optimization_methods = self._optimization_methods
300+
output_model.has_xai = dump_features
301+
output_model.set_data(
302+
"label_schema.json",
303+
label_schema_to_bytes(self._task_environment.label_schema),
304+
)
305+
logger.info("Exporting completed")
306+
307+
@abstractmethod
308+
def _export_model(self, precision: ModelPrecision, dump_features: bool):
309+
"""Main export function using training backend."""
258310
raise NotImplementedError
259311

260312
@abstractmethod

0 commit comments

Comments
 (0)