
Commit 36c1e51

Enable otx deploy

1 parent e9924e6 commit 36c1e51

File tree

2 files changed: +99 -16 lines

src/otx/algorithms/visual_prompting/tasks/openvino.py (+62 -2)

@@ -14,19 +14,24 @@
 # See the License for the specific language governing permissions
 # and limitations under the License.
 
+import io
+import json
 import os
 import time
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
+from zipfile import ZipFile
 
 import attr
 import numpy as np
 from openvino.model_zoo.model_api.adapters import OpenvinoAdapter, create_core
 from openvino.model_zoo.model_api.models import Model
 
-import otx.algorithms.visual_prompting.adapters.openvino.model_wrappers  # noqa: F401
 from otx.algorithms.common.utils.logger import get_logger
 from otx.algorithms.common.utils.utils import get_default_async_reqs_num
+from otx.algorithms.visual_prompting.adapters.openvino import (  # noqa: F401
+    model_wrappers,
+)
 from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets.dataset import (
     OTXVisualPromptingDataset,
     get_transform,
@@ -45,7 +50,9 @@
 from otx.api.entities.optimization_parameters import OptimizationParameters
 from otx.api.entities.resultset import ResultSetEntity
 from otx.api.entities.task_environment import TaskEnvironment
+from otx.api.serialization.label_mapper import LabelSchemaMapper
 from otx.api.usecases.evaluation.metrics_helper import MetricsHelper
+from otx.api.usecases.exportable_code import demo
 from otx.api.usecases.exportable_code.inference import BaseInferencer
 from otx.api.usecases.exportable_code.prediction_to_annotation_converter import (
     VisualPromptingToAnnotationConverter,
@@ -260,7 +267,60 @@ def evaluate(self, output_resultset: ResultSetEntity, evaluation_metric: Optiona
 
     def deploy(self, output_model: ModelEntity) -> None:
         """Deploy function of OpenVINOVisualPromptingTask."""
-        print("deploy start!")
+        logger.info("Deploying the model")
+        if self.model is None:
+            raise RuntimeError("deploy failed, model is None")
+
+        work_dir = os.path.dirname(demo.__file__)
+        parameters = {}
+        parameters["converter_type"] = f"{self.task_type}"
+        parameters["model_parameters"] = self.inferencer.configuration  # type: ignore
+        parameters["model_parameters"]["labels"] = LabelSchemaMapper.forward(self.task_environment.label_schema)  # type: ignore  # noqa: E501
+
+        zip_buffer = io.BytesIO()
+        with ZipFile(zip_buffer, "w") as arch:
+            # model files
+            arch.writestr(
+                os.path.join("model", "visual_prompting_image_encoder.xml"),
+                self.model.get_data("visual_prompting_image_encoder.xml"),
+            )
+            arch.writestr(
+                os.path.join("model", "visual_prompting_image_encoder.bin"),
+                self.model.get_data("visual_prompting_image_encoder.bin"),
+            )
+            arch.writestr(
+                os.path.join("model", "visual_prompting_decoder.xml"),
+                self.model.get_data("visual_prompting_decoder.xml"),
+            )
+            arch.writestr(
+                os.path.join("model", "visual_prompting_decoder.bin"),
+                self.model.get_data("visual_prompting_decoder.bin"),
+            )
+            arch.writestr(
+                os.path.join("model", "config.json"),
+                json.dumps(parameters, ensure_ascii=False, indent=4),
+            )
+            # model_wrappers files
+            for root, _, files in os.walk(os.path.dirname(model_wrappers.__file__)):
+                if "__pycache__" in root:
+                    continue
+                for file in files:
+                    file_path = os.path.join(root, file)
+                    arch.write(
+                        file_path,
+                        os.path.join(
+                            "python",
+                            "model_wrappers",
+                            file_path.split("model_wrappers/")[0],
+                        ),
+                    )
+            # other python files
+            arch.write(os.path.join(work_dir, "requirements.txt"), os.path.join("python", "requirements.txt"))
+            arch.write(os.path.join(work_dir, "LICENSE"), os.path.join("python", "LICENSE"))
+            arch.write(os.path.join(work_dir, "demo.py"), os.path.join("python", "demo.py"))
+            arch.write(os.path.join(work_dir, "README.md"), os.path.join(".", "README.md"))
+        output_model.exportable_code = zip_buffer.getvalue()
+        logger.info("Deploying completed")
 
     def optimize(
         self,
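The deploy() change above packages the exported OpenVINO IR files, a config.json with converter and label information, the model_wrappers package, and the exportable-code demo files into an in-memory zip, then stores the raw bytes on output_model.exportable_code. A minimal sketch of how a caller might persist and inspect that archive after deploy() has run; the deployment.zip filename and the output_model variable are illustrative, not part of this commit:

import io
from zipfile import ZipFile

# output_model is the ModelEntity that was passed to task.deploy();
# writing the bytes to disk yields a self-contained deployment package.
with open("deployment.zip", "wb") as f:
    f.write(output_model.exportable_code)

# Listing the archive shows the layout created by deploy():
# model/*.xml, model/*.bin, model/config.json, python/..., README.md
with ZipFile(io.BytesIO(output_model.exportable_code)) as arch:
    print(arch.namelist())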

tests/unit/algorithms/visual_prompting/tasks/test_openvino.py (+37 -14)

@@ -4,40 +4,43 @@
 # SPDX-License-Identifier: Apache-2.0
 #
 
+from copy import deepcopy
+
 import numpy as np
-import torch
 import pytest
-from otx.api.entities.dataset_item import DatasetItemEntity
-from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets.dataset import OTXVisualPromptingDataset
+import torch
 from openvino.model_zoo.model_api.models import Model
+
+from otx.algorithms.visual_prompting.adapters.pytorch_lightning.datasets.dataset import (
+    OTXVisualPromptingDataset,
+)
 from otx.algorithms.visual_prompting.configs.base import VisualPromptingBaseConfig
 from otx.algorithms.visual_prompting.tasks.openvino import (
     OpenVINOVisualPromptingInferencer,
     OpenVINOVisualPromptingTask,
 )
-from otx.api.entities.annotation import (
-    Annotation,
-    AnnotationSceneEntity,
-    AnnotationSceneKind,
-)
+from otx.api.configuration.configurable_parameters import ConfigurableParameters
+from otx.api.entities.annotation import Annotation
+from otx.api.entities.dataset_item import DatasetItemEntity
 from otx.api.entities.datasets import DatasetEntity
 from otx.api.entities.inference_parameters import InferenceParameters
 from otx.api.entities.label import LabelEntity
+from otx.api.entities.label_schema import LabelSchemaEntity
 from otx.api.entities.metrics import Performance, ScoreMetric
+from otx.api.entities.model import ModelConfiguration, ModelEntity
 from otx.api.entities.resultset import ResultSetEntity
 from otx.api.entities.scored_label import ScoredLabel
 from otx.api.entities.shapes.polygon import Point, Polygon
 from otx.api.usecases.evaluation.metrics_helper import MetricsHelper
+from otx.api.usecases.exportable_code.prediction_to_annotation_converter import (
+    VisualPromptingToAnnotationConverter,
+)
 from otx.api.utils.shape_factory import ShapeFactory
-
 from tests.test_suite.e2e_test_system import e2e_pytest_unit
 from tests.unit.algorithms.visual_prompting.test_helpers import (
     generate_visual_prompting_dataset,
     init_environment,
 )
-from otx.api.usecases.exportable_code.prediction_to_annotation_converter import (
-    VisualPromptingToAnnotationConverter,
-)
 
 
 class TestOpenVINOVisualPromptingInferencer:
@@ -159,8 +162,16 @@ def test_forward_decoder(self):
 
 
 class TestOpenVINOVisualPromptingTask:
+    @pytest.fixture
+    def otx_model(self):
+        model_configuration = ModelConfiguration(
+            configurable_parameters=ConfigurableParameters(header="header", description="description"),
+            label_schema=LabelSchemaEntity(),
+        )
+        return ModelEntity(train_dataset=DatasetEntity(), configuration=model_configuration)
+
     @pytest.fixture(autouse=True)
-    def setup(self, mocker):
+    def setup(self, mocker, otx_model):
         """Load the OpenVINOVisualPromptingTask."""
         mocker.patch("otx.algorithms.visual_prompting.tasks.openvino.OpenvinoAdapter")
         mocker.patch.object(Model, "create_model")
@@ -174,7 +185,8 @@ def setup(self, mocker):
             {"image_encoder": "", "decoder": ""},
         )
 
-        self.task_environment.model = mocker.patch("otx.api.entities.model.ModelEntity")
+        # self.task_environment.model = mocker.patch("otx.api.entities.model.ModelEntity")
+        self.task_environment.model = otx_model
         mocker.patch.object(OpenVINOVisualPromptingTask, "load_inferencer", return_value=visual_prompting_ov_inferencer)
         self.visual_prompting_ov_task = OpenVINOVisualPromptingTask(task_environment=self.task_environment)
 
@@ -220,3 +232,14 @@ def test_evaluate(self, mocker):
         self.visual_prompting_ov_task.evaluate(result_set)
 
         assert result_set.performance.score.value == 0.1
+
+    @e2e_pytest_unit
+    def test_deploy(self):
+        output_model = deepcopy(self.task_environment.model)
+        self.visual_prompting_ov_task.model.set_data("visual_prompting_image_encoder.bin", b"image_encoder_bin")
+        self.visual_prompting_ov_task.model.set_data("visual_prompting_image_encoder.xml", b"image_encoder_xml")
+        self.visual_prompting_ov_task.model.set_data("visual_prompting_decoder.bin", b"decoder_bin")
+        self.visual_prompting_ov_task.model.set_data("visual_prompting_decoder.xml", b"decoder_xml")
+        self.visual_prompting_ov_task.deploy(output_model)
+
+        assert output_model.exportable_code is not None
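The new test_deploy asserts only that exportable_code gets populated. A hypothetical follow-up check, not part of this commit, could open the generated archive and verify that the expected entries are present, using only the standard library:

import io
from zipfile import ZipFile

# Hypothetical extension of test_deploy: inspect the archive written by deploy().
names = ZipFile(io.BytesIO(output_model.exportable_code)).namelist()
assert "model/visual_prompting_image_encoder.xml" in names
assert "model/visual_prompting_decoder.bin" in names
assert "model/config.json" in names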
