22
22
import subprocess # nosec B404
23
23
import tempfile
24
24
from glob import glob
25
- from pathlib import Path
26
25
from typing import Any , Dict , List , Optional , Tuple , Union
27
26
from warnings import warn
28
27
36
35
PostProcessingConfigurationCallback ,
37
36
)
38
37
from omegaconf import DictConfig , ListConfig
39
- from openvino .runtime import Core , serialize
40
38
from pytorch_lightning import Trainer
41
39
42
40
from otx .algorithms .anomaly .adapters .anomalib .callbacks import (
47
45
from otx .algorithms .anomaly .adapters .anomalib .data import OTXAnomalyDataModule
48
46
from otx .algorithms .anomaly .adapters .anomalib .logger import get_logger
49
47
from otx .algorithms .anomaly .configs .base .configuration import BaseAnomalyConfig
48
+ from otx .algorithms .common .utils import embed_ir_model_data
49
+ from otx .algorithms .common .utils .utils import embed_onnx_model_data
50
50
from otx .api .entities .datasets import DatasetEntity
51
51
from otx .api .entities .inference_parameters import InferenceParameters
52
52
from otx .api .entities .metrics import NullPerformance , Performance , ScoreMetric
@@ -296,6 +296,8 @@ def export(
296
296
self ._export_to_onnx (onnx_path )
297
297
298
298
if export_type == ExportType .ONNX :
299
+ self ._add_metadata_to_ir (onnx_path , export_type )
300
+
299
301
with open (onnx_path , "rb" ) as file :
300
302
output_model .set_data ("model.onnx" , file .read ())
301
303
else :
@@ -306,7 +308,7 @@ def export(
306
308
bin_file = glob (os .path .join (self .config .project .path , "*.bin" ))[0 ]
307
309
xml_file = glob (os .path .join (self .config .project .path , "*.xml" ))[0 ]
308
310
309
- self ._add_metadata_to_ir (xml_file )
311
+ self ._add_metadata_to_ir (xml_file , export_type )
310
312
311
313
with open (bin_file , "rb" ) as file :
312
314
output_model .set_data ("openvino.bin" , file .read ())
@@ -319,40 +321,51 @@ def export(
319
321
output_model .set_data ("label_schema.json" , label_schema_to_bytes (self .task_environment .label_schema ))
320
322
self ._set_metadata (output_model )
321
323
322
- def _add_metadata_to_ir (self , xml_file : str ) -> None :
323
- """Adds the metadata to the model IR.
324
+ def _add_metadata_to_ir (self , model_file : str , export_type : ExportType ) -> None :
325
+ """Adds the metadata to the model IR or ONNX .
324
326
325
327
Adds the metadata to the model IR. So that it can be used with the new modelAPI.
326
328
This is because the metadata.json is not used by the new modelAPI.
327
329
# TODO CVS-114640
328
330
# TODO: Step 1. Remove metadata.json when modelAPI becomes the default inference method.
329
- # TODO: Step 2. Remove this function when Anomalib is upgraded as the model graph will contain the required ops
331
+ # TODO: Step 2. Update this function when Anomalib is upgraded as the model graph will contain the required ops
330
332
# TODO: Step 3. Update modelAPI to remove pre/post-processing steps when Anomalib version is upgraded.
331
333
"""
332
334
metadata = self ._get_metadata_dict ()
333
- core = Core ()
334
- model = core .read_model (xml_file )
335
+ extra_model_data : Dict [Tuple [str , str ], Any ] = {}
335
336
for key , value in metadata .items ():
336
- if key == "transform" :
337
+ if key in ( "transform" , "min" , "max" ) :
337
338
continue
338
- model . set_rt_info ( value , [ "model_info" , key ])
339
+ extra_model_data [( "model_info" , key )] = value
339
340
# Add transforms
340
341
if "transform" in metadata :
341
342
for transform_dict in metadata ["transform" ]["transform" ]["transforms" ]:
342
343
transform = transform_dict .pop ("__class_fullname__" )
343
344
if transform == "Normalize" :
344
- model .set_rt_info (self ._serialize_list (transform_dict ["mean" ]), ["model_info" , "mean_values" ])
345
- model .set_rt_info (self ._serialize_list (transform_dict ["std" ]), ["model_info" , "scale_values" ])
345
+ extra_model_data [("model_info" , "mean_values" )] = self ._serialize_list (
346
+ [x * 255.0 for x in transform_dict ["mean" ]]
347
+ )
348
+ extra_model_data [("model_info" , "scale_values" )] = self ._serialize_list (
349
+ [x * 255.0 for x in transform_dict ["std" ]]
350
+ )
346
351
elif transform == "Resize" :
347
- model . set_rt_info ( transform_dict [ "height" ], [ "model_info" , "orig_height" ])
348
- model . set_rt_info ( transform_dict [ "width" ], [ "model_info" , "orig_width" ])
352
+ extra_model_data [( "model_info" , "orig_height" )] = transform_dict [ "height" ]
353
+ extra_model_data [( "model_info" , "orig_width" )] = transform_dict [ "width" ]
349
354
else :
350
355
warn (f"Transform { transform } is not supported currently" )
351
- model .set_rt_info ("AnomalyDetection" , ["model_info" , "model_type" ])
352
- tmp_xml_path = Path (Path (xml_file ).parent ) / "tmp.xml"
353
- serialize (model , str (tmp_xml_path ))
354
- tmp_xml_path .rename (xml_file )
355
- Path (str (tmp_xml_path .parent / tmp_xml_path .stem ) + ".bin" ).unlink ()
356
+ # Since we only need the diff of max and min, we fuse the min and max into one op
357
+ if "min" in metadata and "max" in metadata :
358
+ extra_model_data [("model_info" , "normalization_scale" )] = metadata ["max" ] - metadata ["min" ]
359
+
360
+ extra_model_data [("model_info" , "reverse_input_channels" )] = False
361
+ extra_model_data [("model_info" , "model_type" )] = "AnomalyDetection"
362
+ extra_model_data [("model_info" , "labels" )] = "Normal Anomaly"
363
+ if export_type == ExportType .OPENVINO :
364
+ embed_ir_model_data (model_file , extra_model_data )
365
+ elif export_type == ExportType .ONNX :
366
+ embed_onnx_model_data (model_file , extra_model_data )
367
+ else :
368
+ raise RuntimeError (f"not supported export type { export_type } " )
356
369
357
370
def _serialize_list (self , arr : Union [Tuple , List ]) -> str :
358
371
"""Converts a list to space separated string."""
0 commit comments