Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OTX ATSS detector model interpreter & refactor interfaces #1047

Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## \[Unreleased\]
### New features
- Add OTX ATSS detector model interpreter & refactor interfaces
(<https://github.com/openvinotoolkit/datumaro/pull/1047>)

### Enhancements
- Enhance import performance for built-in plugins
Expand Down
1 change: 1 addition & 0 deletions src/datumaro/components/abstracts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
# SPDX-License-Identifier: MIT

from .merger import *
from .model_interpreter import *
11 changes: 9 additions & 2 deletions src/datumaro/components/abstracts/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,24 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional, Sequence, Type

from datumaro.components.annotation import Annotation
from datumaro.components.dataset_base import IDataset
from datumaro.components.dataset_item_storage import (
DatasetItemStorage,
DatasetItemStorageDatasetView,
)
from datumaro.components.media import MediaElement

__all__ = ["IMerger"]
__all__ = ["IMatcherContext", "IMergerContext"]


class IMerger(ABC):
class IMatcherContext(ABC):
@abstractmethod
def get_any_label_name(self, ann: Annotation, label_id: int) -> str:
raise NotImplementedError


class IMergerContext(IMatcherContext):
@abstractmethod
def merge_infos(self, sources: Sequence[IDataset]) -> Dict:
raise NotImplementedError
Expand Down
25 changes: 25 additions & 0 deletions src/datumaro/components/abstracts/model_interpreter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from abc import ABC, abstractmethod

__all__ = ["IModelInterpreter"]


class IModelInterpreter(ABC):
@abstractmethod
def get_categories(self):
raise NotImplementedError("Function should be implemented.")

@abstractmethod
def process_outputs(self, inputs, outputs):
raise NotImplementedError("Function should be implemented.")

@abstractmethod
def normalize(self, inputs):
raise NotImplementedError("Function should be implemented.")

@abstractmethod
def resize(self, inputs):
raise NotImplementedError("Function should be implemented.")
77 changes: 67 additions & 10 deletions src/datumaro/components/annotations/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
#
# SPDX-License-Identifier: MIT

from typing import Optional
from typing import Optional, Union

import numpy as np
from attr import attrib, attrs

from datumaro.components.abstracts import IMerger
from datumaro.components.abstracts import IMergerContext
from datumaro.components.abstracts.merger import IMatcherContext
from datumaro.util.annotation_util import (
OKS,
approximate_line,
Expand All @@ -18,7 +19,8 @@
)

__all__ = [
"match_segments",
"match_segments_pair",
"match_segments_more_than_pair",
"AnnotationMatcher",
"LabelMatcher",
"ShapeMatcher",
Expand All @@ -34,13 +36,15 @@
]


def match_segments(
def match_segments_pair(
a_segms,
b_segms,
distance=segment_iou,
dist_thresh=1.0,
label_matcher=lambda a, b: a.label == b.label,
):
"""Match segments and return pairs of the two matched segments"""

assert callable(distance), distance
assert callable(label_matcher), label_matcher

Expand Down Expand Up @@ -95,9 +99,61 @@ def match_segments(
return matches, mispred, a_unmatched, b_unmatched


def match_segments_more_than_pair(
a_segms,
b_segms,
distance=segment_iou,
dist_thresh=1.0,
label_matcher=lambda a, b: a.label == b.label,
):
"""Match segments and return sets of the matched segments which can be more than two"""

assert callable(distance), distance
assert callable(label_matcher), label_matcher

# a_matches: indices of b_segms matched to a bboxes
# b_matches: indices of a_segms matched to b bboxes
a_matches = -np.ones(len(a_segms), dtype=int)
b_matches = -np.ones(len(b_segms), dtype=int)

distances = np.array([[distance(a, b) for b in b_segms] for a in a_segms])

# matches: boxes we succeeded to match completely
# mispred: boxes we succeeded to match, having label mismatch
matches = []
mispred = []

# It needs len(a_segms) > 0 and len(b_segms) > 0
if len(b_segms) > 0:
for a_idx, a_segm in enumerate(a_segms):
b_indices = np.argsort(
[not label_matcher(a_segm, b_segm) for b_segm in b_segms], kind="stable"
) # prioritize those with same label, keep score order
for b_idx in b_indices:
d = distances[a_idx, b_idx]
if d < dist_thresh:
continue

a_matches[a_idx] = b_idx
b_matches[b_idx] = a_idx

b_segm = b_segms[b_idx]

if label_matcher(a_segm, b_segm):
matches.append((a_segm, b_segm))
else:
mispred.append((a_segm, b_segm))

# *_umatched: boxes of (*) we failed to match
a_unmatched = [a_segms[i] for i, m in enumerate(a_matches) if m < 0]
b_unmatched = [b_segms[i] for i, m in enumerate(b_matches) if m < 0]

return matches, mispred, a_unmatched, b_unmatched


@attrs(kw_only=True)
class AnnotationMatcher:
_context: Optional[IMerger] = attrib(default=None)
_context: Optional[Union[IMatcherContext, IMergerContext]] = attrib(default=None)

def match_annotations(self, sources):
raise NotImplementedError()
Expand All @@ -106,8 +162,8 @@ def match_annotations(self, sources):
@attrs
class LabelMatcher(AnnotationMatcher):
def distance(self, a, b):
a_label = self._context._get_any_label_name(a, a.label)
b_label = self._context._get_any_label_name(b, b.label)
a_label = self._context.get_any_label_name(a, a.label)
b_label = self._context.get_any_label_name(b, b.label)
return a_label == b_label

def match_annotations(self, sources):
Expand All @@ -118,6 +174,7 @@ def match_annotations(self, sources):
class ShapeMatcher(AnnotationMatcher):
pairwise_dist = attrib(converter=float, default=0.9)
cluster_dist = attrib(converter=float, default=-1.0)
_match_segments = attrib(default=match_segments_pair)

def match_annotations(self, sources):
distance = self.distance
Expand Down Expand Up @@ -152,7 +209,7 @@ def _has_same_source(cluster, extra_id):
adjacent = {i: [] for i in id_segm} # id(sgm) -> [id(adj_sgm1), ...]
for a_idx, src_a in enumerate(sources):
for src_b in sources[a_idx + 1 :]:
matches, _, _, _ = match_segments(
matches, _, _, _ = self._match_segments(
src_a,
src_b,
dist_thresh=pairwise_dist,
Expand Down Expand Up @@ -194,8 +251,8 @@ def distance(self, a, b):
return segment_iou(a, b)

def label_matcher(self, a, b):
a_label = self._context._get_any_label_name(a, a.label)
b_label = self._context._get_any_label_name(b, b.label)
a_label = self._context.get_any_label_name(a, a.label)
b_label = self._context.get_any_label_name(b, b.label)
return a_label == b_label


Expand Down
8 changes: 4 additions & 4 deletions src/datumaro/components/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from datumaro.cli.util.project import generate_next_file_name
from datumaro.components.annotation import AnnotationType, LabelCategories
from datumaro.components.annotations.matcher import LineMatcher, PointsMatcher, match_segments
from datumaro.components.annotations.matcher import LineMatcher, PointsMatcher, match_segments_pair
from datumaro.components.dataset import Dataset
from datumaro.components.operations import (
compute_ann_statistics,
Expand Down Expand Up @@ -69,7 +69,7 @@ def match_labels(self, item_a, item_b):
def _match_segments(self, t, item_a, item_b):
a_boxes = self._get_ann_type(t, item_a)
b_boxes = self._get_ann_type(t, item_b)
return match_segments(a_boxes, b_boxes, dist_thresh=self.iou_threshold)
return match_segments_pair(a_boxes, b_boxes, dist_thresh=self.iou_threshold)

def match_polygons(self, item_a, item_b):
return self._match_segments(AnnotationType.polygon, item_a, item_b)
Expand All @@ -93,7 +93,7 @@ def match_points(self, item_a, item_b):
instance_map[id(ann)] = [inst, inst_bbox]
matcher = PointsMatcher(instance_map=instance_map)

return match_segments(
return match_segments_pair(
a_points, b_points, dist_thresh=self.iou_threshold, distance=matcher.distance
)

Expand All @@ -103,7 +103,7 @@ def match_lines(self, item_a, item_b):

matcher = LineMatcher()

return match_segments(
return match_segments_pair(
a_lines, b_lines, dist_thresh=self.iou_threshold, distance=matcher.distance
)

Expand Down
7 changes: 5 additions & 2 deletions src/datumaro/components/merge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import OrderedDict
from typing import Dict, Optional, Sequence, Type

from datumaro.components.abstracts.merger import IMerger
from datumaro.components.abstracts.merger import IMergerContext
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset_base import IDataset
from datumaro.components.dataset_item_storage import DatasetItemStorageDatasetView
Expand All @@ -20,7 +20,7 @@
from datumaro.util import dump_json_file


class Merger(IMerger, CliPlugin):
class Merger(IMergerContext, CliPlugin):
"""Merge multiple datasets into one dataset"""

def __init__(self, **options):
Expand Down Expand Up @@ -104,3 +104,6 @@ def save_merge_report(self, path: str) -> None:
os.makedirs(os.path.dirname(path), exist_ok=True)

dump_json_file(path, errors, indent=True)

def get_any_label_name(self, ann, label_id):
raise NotImplementedError
2 changes: 1 addition & 1 deletion src/datumaro/components/merge/intersect_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ def _get_src_label_name(self, ann, label_id):
self._dataset_map[dataset_id][0].categories()[AnnotationType.label].items[label_id].name
)

def _get_any_label_name(self, ann, label_id):
def get_any_label_name(self, ann, label_id):
if label_id is None:
return None
try:
Expand Down
Loading