From 56489b3ee592d23c7465f3a7c2ec5a75f04140e9 Mon Sep 17 00:00:00 2001 From: Ben Hoff Date: Tue, 8 Oct 2019 07:39:39 -0400 Subject: [PATCH] allow security segmentation models to be used in auto annotation --- cvat/apps/auto_annotation/model_loader.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/cvat/apps/auto_annotation/model_loader.py b/cvat/apps/auto_annotation/model_loader.py index 73d33d81b4fd..15a7c792efeb 100644 --- a/cvat/apps/auto_annotation/model_loader.py +++ b/cvat/apps/auto_annotation/model_loader.py @@ -31,14 +31,19 @@ def __init__(self, model, weights): iter_inputs = iter(network.inputs) self._input_blob_name = next(iter_inputs) + self._input_info_name = '' self._output_blob_name = next(iter(network.outputs)) self._require_image_info = False + info_names = ('image_info', 'im_info') + # NOTE: handeling for the inclusion of `image_info` in OpenVino2019 - if 'image_info' in network.inputs: + if any(s in network.inputs for s in info_names): self._require_image_info = True - if self._input_blob_name == 'image_info': + self._input_info_name = set(network.inputs).intersection(info_names) + self._input_info_name = self._input_info_name.pop() + if self._input_blob_name in info_names: self._input_blob_name = next(iter_inputs) self._net = plugin.load(network=network, num_requests=2) @@ -56,7 +61,7 @@ def infer(self, image): info[0, 1] = w # frame number info[0, 2] = 1 - inputs['image_info'] = info + inputs[self._input_info_name] = info results = self._net.infer(inputs) if len(results) == 1: