|
22 | 22 | from copy import deepcopy
|
23 | 23 | from typing import Any, Dict, Optional, Union
|
24 | 24 |
|
25 |
| -import numpy as np |
26 | 25 | import torch
|
27 | 26 | from mmcv.runner import wrap_fp16_model
|
28 | 27 | from mmcv.utils import Config, ConfigDict, get_git_hash
|
|
52 | 51 | from otx.algorithms.common.utils import set_random_seed
|
53 | 52 | from otx.algorithms.common.utils.callback import InferenceProgressCallback
|
54 | 53 | from otx.algorithms.common.utils.data import get_dataset
|
55 |
| -from otx.algorithms.common.utils.ir import embed_ir_model_data |
56 | 54 | from otx.algorithms.common.utils.logger import get_logger
|
57 | 55 | from otx.algorithms.detection.adapters.mmdet.configurer import (
|
58 | 56 | DetectionConfigurer,
|
|
69 | 67 | )
|
70 | 68 | from otx.algorithms.detection.adapters.mmdet.utils.exporter import DetectionExporter
|
71 | 69 | from otx.algorithms.detection.task import OTXDetectionTask
|
72 |
| -from otx.algorithms.detection.utils import get_det_model_api_configuration |
73 | 70 | from otx.algorithms.detection.utils.data import adaptive_tile_params
|
74 | 71 | from otx.api.configuration import cfg_helper
|
75 |
| -from otx.api.configuration.helper.utils import config_to_bytes, ids_to_strings |
| 72 | +from otx.api.configuration.helper.utils import ids_to_strings |
76 | 73 | from otx.api.entities.datasets import DatasetEntity
|
77 | 74 | from otx.api.entities.explain_parameters import ExplainParameters
|
78 | 75 | from otx.api.entities.inference_parameters import InferenceParameters
|
79 | 76 | from otx.api.entities.model import (
|
80 | 77 | ModelEntity,
|
81 |
| - ModelFormat, |
82 |
| - ModelOptimizationType, |
83 | 78 | ModelPrecision,
|
84 | 79 | )
|
85 | 80 | from otx.api.entities.subset import Subset
|
86 | 81 | from otx.api.entities.task_environment import TaskEnvironment
|
87 | 82 | from otx.api.entities.train_parameters import default_progress_callback
|
88 | 83 | from otx.api.serialization.label_mapper import label_schema_to_bytes
|
89 |
| -from otx.api.usecases.tasks.interfaces.export_interface import ExportType |
90 | 84 | from otx.core.data import caching
|
91 | 85 |
|
92 | 86 | logger = get_logger()
|
@@ -461,21 +455,12 @@ def hook(module, inp, outp):
|
461 | 455 | return prediction_results, metric
|
462 | 456 |
|
463 | 457 | # pylint: disable=too-many-statements
|
464 |
| - def export( |
| 458 | + def _export_model( |
465 | 459 | self,
|
466 |
| - export_type: ExportType, |
467 |
| - output_model: ModelEntity, |
468 |
| - precision: ModelPrecision = ModelPrecision.FP32, |
469 |
| - dump_features: bool = True, |
| 460 | + precision: ModelPrecision, |
| 461 | + dump_features: bool, |
470 | 462 | ):
|
471 |
| - """Export function of OTX Detection Task.""" |
472 |
| - # copied from OTX inference_task.py |
473 |
| - logger.info("Exporting the model") |
474 |
| - if export_type != ExportType.OPENVINO: |
475 |
| - raise RuntimeError(f"not supported export type {export_type}") |
476 |
| - output_model.model_format = ModelFormat.OPENVINO |
477 |
| - output_model.optimization_type = ModelOptimizationType.MO |
478 |
| - |
| 463 | + """Main export function of OTX MMDetection Task.""" |
479 | 464 | self._init_task(export=True)
|
480 | 465 |
|
481 | 466 | cfg = self.configure(False, "test", None)
|
@@ -506,41 +491,7 @@ def export(
|
506 | 491 | **export_options,
|
507 | 492 | )
|
508 | 493 |
|
509 |
| - outputs = results.get("outputs") |
510 |
| - logger.debug(f"results of run_task = {outputs}") |
511 |
| - if outputs is None: |
512 |
| - raise RuntimeError(results.get("msg")) |
513 |
| - |
514 |
| - bin_file = outputs.get("bin") |
515 |
| - xml_file = outputs.get("xml") |
516 |
| - onnx_file = outputs.get("onnx") |
517 |
| - |
518 |
| - ir_extra_data = get_det_model_api_configuration( |
519 |
| - self._task_environment.label_schema, self._task_type, self.confidence_threshold |
520 |
| - ) |
521 |
| - embed_ir_model_data(xml_file, ir_extra_data) |
522 |
| - |
523 |
| - if xml_file is None or bin_file is None or onnx_file is None: |
524 |
| - raise RuntimeError("invalid status of exporting. bin and xml or onnx should not be None") |
525 |
| - with open(bin_file, "rb") as f: |
526 |
| - output_model.set_data("openvino.bin", f.read()) |
527 |
| - with open(xml_file, "rb") as f: |
528 |
| - output_model.set_data("openvino.xml", f.read()) |
529 |
| - with open(onnx_file, "rb") as f: |
530 |
| - output_model.set_data("model.onnx", f.read()) |
531 |
| - output_model.set_data( |
532 |
| - "confidence_threshold", |
533 |
| - np.array([self.confidence_threshold], dtype=np.float32).tobytes(), |
534 |
| - ) |
535 |
| - output_model.set_data("config.json", config_to_bytes(self._hyperparams)) |
536 |
| - output_model.precision = self._precision |
537 |
| - output_model.optimization_methods = self._optimization_methods |
538 |
| - output_model.has_xai = dump_features |
539 |
| - output_model.set_data( |
540 |
| - "label_schema.json", |
541 |
| - label_schema_to_bytes(self._task_environment.label_schema), |
542 |
| - ) |
543 |
| - logger.info("Exporting completed") |
| 494 | + return results |
544 | 495 |
|
545 | 496 | def explain(
|
546 | 497 | self,
|
|
0 commit comments