Add MissingAnnotationDetection transform #1049

Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### New features
- Add OTX ATSS detector model interpreter & refactor interfaces
(<https://github.com/openvinotoolkit/datumaro/pull/1047>)
- Add MissingAnnotationDetection transform
(<https://github.com/openvinotoolkit/datumaro/pull/1049>)

### Enhancements
- Enhance import performance for built-in plugins
5 changes: 3 additions & 2 deletions src/datumaro/components/annotations/matcher.py
@@ -2,13 +2,14 @@
#
# SPDX-License-Identifier: MIT

from typing import Optional, Union
from typing import List, Optional, Union

import numpy as np
from attr import attrib, attrs

from datumaro.components.abstracts import IMergerContext
from datumaro.components.abstracts.merger import IMatcherContext
from datumaro.components.annotation import Annotation
from datumaro.util.annotation_util import (
OKS,
approximate_line,
@@ -176,7 +177,7 @@ class ShapeMatcher(AnnotationMatcher):
cluster_dist = attrib(converter=float, default=-1.0)
_match_segments = attrib(default=match_segments_pair)

def match_annotations(self, sources):
def match_annotations(self, sources: List[List[Annotation]]) -> List[List[Annotation]]:
distance = self.distance
label_matcher = self.label_matcher
pairwise_dist = self.pairwise_dist
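With the added type hints, the `ShapeMatcher.match_annotations` contract is explicit: it takes one annotation list per source and returns clusters of annotations matched across those sources. Below is a minimal sketch of that contract; the `AnyLabelContext` helper is a hypothetical stand-in for the label-agnostic context introduced later in this PR, and the constructor arguments mirror the PR's own usage rather than a documented API.

```python
from datumaro.components.abstracts.merger import IMatcherContext
from datumaro.components.annotation import Annotation, Bbox
from datumaro.components.annotations.matcher import BboxMatcher


class AnyLabelContext(IMatcherContext):
    # Report the same label name for every annotation, so matching ignores labels.
    def get_any_label_name(self, ann: Annotation, label_id: int) -> str:
        return ""


matcher = BboxMatcher(pairwise_dist=0.5, context=AnyLabelContext())

# One inner list per source, e.g. ground truth first, model predictions second.
clusters = matcher.match_annotations(
    [
        [Bbox(0, 0, 2, 2, label=0)],
        [Bbox(0, 0, 2, 2, label=1), Bbox(5, 5, 1, 1, label=1)],
    ]
)
# Expected result: the two overlapping boxes form one cluster, and the box at
# (5, 5) ends up in a cluster of its own.
```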
33 changes: 24 additions & 9 deletions src/datumaro/components/launcher.py
@@ -2,10 +2,13 @@
#
# SPDX-License-Identifier: MIT

from typing import Generator, List

import numpy as np

from datumaro.components.annotation import AnnotationType, LabelCategories
from datumaro.components.annotation import Annotation, AnnotationType, LabelCategories
from datumaro.components.cli_plugin import CliPlugin
from datumaro.components.dataset_base import DatasetItem, IDataset
from datumaro.components.transformer import Transform
from datumaro.util import take_by

@@ -32,13 +35,19 @@ def type_check(self, item):


class ModelTransform(Transform):
def __init__(self, extractor, launcher, batch_size=1, append_annotation=False):
def __init__(
self,
extractor: IDataset,
launcher: Launcher,
batch_size: int = 1,
append_annotation: bool = False,
):
super().__init__(extractor)
self._launcher = launcher
self._batch_size = batch_size
self._append_annotation = append_annotation

def __iter__(self):
def __iter__(self) -> Generator[DatasetItem, None, None]:
for batch in take_by(self._extractor, self._batch_size):
inputs = []
for item in batch:
@@ -48,11 +57,17 @@ def __iter__(self):
inputs = np.array(inputs)
inference = self._launcher.launch(inputs)

for item, annotations in zip(batch, inference):
self._check_annotations(annotations)
if self._append_annotation:
annotations = item.annotations + annotations
yield self.wrap_item(item, annotations=annotations)
for item in self._yield_item(batch, inference):
yield item

def _yield_item(
self, batch: List[DatasetItem], inference: List[List[Annotation]]
) -> Generator[DatasetItem, None, None]:
for item, annotations in zip(batch, inference):
self._check_annotations(annotations)
if self._append_annotation:
annotations = item.annotations + annotations
yield self.wrap_item(item, annotations=annotations)

def get_subset(self, name):
subset = self._extractor.get_subset(name)
@@ -75,7 +90,7 @@ def transform_item(self, item):
annotations = self._launcher.launch(inputs)[0]
return self.wrap_item(item, annotations=annotations)

def _check_annotations(self, annotations):
def _check_annotations(self, annotations: List[Annotation]):
labels_count = len(self.categories().get(AnnotationType.label, LabelCategories()).items)

for ann in annotations:
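The per-item logic that previously lived in `__iter__` is now factored into the `_yield_item` hook, so a subclass can reuse `ModelTransform`'s batching and inference loop and only decide what to do with each item's predictions; this is exactly the hook the new transform overrides. A minimal sketch of a hypothetical subclass follows; the class name and the 0.8 score cut-off are assumptions for illustration, not part of this PR.

```python
from typing import Generator, List

from datumaro.components.annotation import Annotation
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.launcher import ModelTransform


class KeepConfidentPredictions(ModelTransform):
    """Keep only model predictions whose 'score' attribute exceeds 0.8."""

    def _yield_item(
        self, batch: List[DatasetItem], inference: List[List[Annotation]]
    ) -> Generator[DatasetItem, None, None]:
        for item, annotations in zip(batch, inference):
            self._check_annotations(annotations)
            confident = [
                ann for ann in annotations if ann.attributes.get("score", 1.0) > 0.8
            ]
            yield self.wrap_item(item, annotations=confident)
```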
149 changes: 149 additions & 0 deletions src/datumaro/plugins/missing_annotation_detection.py
@@ -0,0 +1,149 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from typing import Generator, List, Optional

from datumaro.components.abstracts.merger import IMatcherContext
from datumaro.components.annotation import Annotation, AnnotationType, LabelCategories
from datumaro.components.annotations.matcher import BboxMatcher, match_segments_more_than_pair
from datumaro.components.dataset_base import DatasetItem, IDataset
from datumaro.components.launcher import Launcher, ModelTransform


class MissingAnnotationDetection(ModelTransform):
"""This class is used to find annotations that are missing from the ground truth annotations.

To accomplish this, it generates model predictions for the dataset using the given launcher.
However, not every model prediction can be considered a missing annotation, since the dataset
already contains ground truth annotations and some of the model predictions may duplicate them.
To identify the missing annotations among the model predictions, this class filters out the predictions
that spatially overlap with the ground truth annotations.

For example, the following code will produce ``[Bbox(1, 1, 1, 1, label=1, attributes={"score": 0.5})]`` as
the missing annotations, since ``Bbox(0, 0, 1, 1, label=1, attributes={"score": 1.0})`` overlaps with the
ground-truth annotation ``Bbox(0, 0, 1, 1, label=0)`` (``label_agnostic_matching=True``).

.. code-block:: python

ground_truth_annotations = [
Bbox(0, 0, 1, 1, label=0),
Bbox(1, 0, 1, 1, label=1),
Bbox(0, 1, 1, 1, label=2),
]
model_predictions = [
Bbox(0, 0, 1, 1, label=1, attributes={"score": 1.0}),
Bbox(1, 1, 1, 1, label=1, attributes={"score": 0.5}),
]

Args:
extractor: The dataset used to find missing labeled annotations.
launcher: The launcher used to generate model predictions from the dataset.
batch_size: The size of the batches used during processing.
pairwise_dist: The minimum distance between two annotations required to match them as spatially overlapping.
Typically, the distance is Intersection over Union (IoU), which is bounded between 0 and 1.
score_threshold: The minimum score required for an annotation to be considered
a candidate for missing annotations.
label_agnostic_matching: If set to false, annotations with different labels are not matched
to determine their spatial overlap. In the above example, ``label_agnostic_matching=False``
will produce ``model_predictions`` as is, since ``Bbox(0, 0, 1, 1, label=1, attributes={"score": 1.0})``
has a different label from ``Bbox(0, 0, 1, 1, label=0)``.
"""

def __init__(
self,
extractor: IDataset,
launcher: Launcher,
batch_size: int = 1,
pairwise_dist: float = 0.75,
score_threshold: Optional[float] = None,
label_agnostic_matching: bool = True,
):
super().__init__(extractor, launcher, batch_size, append_annotation=False)
self._score_threshold = score_threshold

class LabelAgnosticMatcherContext(IMatcherContext):
def get_any_label_name(self, ann: Annotation, label_id: int) -> str:
return ""

label_categories: LabelCategories = self.categories()[AnnotationType.label]

class LabelGnosticMatcherContext(IMatcherContext):
def get_any_label_name(self, ann: Annotation, label_id: int) -> str:
return label_categories[label_id]

self._support_matchers = {
AnnotationType.bbox: BboxMatcher(
pairwise_dist=pairwise_dist,
context=LabelAgnosticMatcherContext()
if label_agnostic_matching
else LabelGnosticMatcherContext(),
match_segments=match_segments_more_than_pair,
),
}

def _yield_item(
self, batch: List[DatasetItem], inference: List[List[Annotation]]
) -> Generator[DatasetItem, None, None]:
for item, annotations in zip(batch, inference):
self._check_annotations(annotations)
yield self.wrap_item(
item,
annotations=self._find_missing_anns(
gt_anns=item.annotations,
pseudo_anns=self._apply_score_threshold(annotations),
),
)

def _apply_score_threshold(self, annotations: List[Annotation]) -> List[Annotation]:
if self._score_threshold is None:
return annotations

return [
ann for ann in annotations if ann.attributes.get("score", 1.0) > self._score_threshold
]

def _find_missing_anns(
self, gt_anns: List[Annotation], pseudo_anns: List[Annotation]
) -> List[Annotation]:
for ann in pseudo_anns:
ann.attributes["_pseudo_label_"] = True

missing_labeled_anns = []
for ann_type, matcher in self._support_matchers.items():
clusters = matcher.match_annotations(
[
[ann for ann in gt_anns if ann.type == ann_type],
[ann for ann in pseudo_anns if ann.type == ann_type],
]
)
for cluster in clusters:
ann = self._pick_missing_ann_from_cluster(cluster)
if ann is not None:
missing_labeled_anns.append(ann)

return missing_labeled_anns

@staticmethod
def _pick_missing_ann_from_cluster(cluster: List[Annotation]) -> Optional[Annotation]:
pseudo_label_anns = []
gt_label_anns = []

for ann in cluster:
if "_pseudo_label_" in ann.attributes:
pseudo_label_anns.append(ann)
else:
gt_label_anns.append(ann)

if len(gt_label_anns) > 0:
return None

max_score = float("-inf")
max_ann = None
for ann in pseudo_label_anns:
score = ann.attributes.get("score", -1)
if score > max_score:
max_score = score
max_ann = ann

return max_ann
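Putting it together, here is a usage sketch for the new transform. The dataset contents and the mocked launcher are illustrative, modeled on the unit test further down; any datumaro `Launcher` that returns one list of scored annotations per item should behave the same way.

```python
from unittest.mock import MagicMock

import numpy as np

from datumaro.components.annotation import AnnotationType, Bbox, LabelCategories
from datumaro.components.dataset import Dataset
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.media import Image

dataset = Dataset.from_iterable(
    [
        DatasetItem(
            id="item",
            media=Image.from_numpy(np.zeros([2, 2, 3])),
            annotations=[Bbox(0, 0, 1, 1, label=0)],  # ground truth
        )
    ],
    categories={AnnotationType.label: LabelCategories.from_iterable(["cat", "dog"])},
)

# Stand-in for a real model launcher (e.g. an OpenVINO detector).
launcher = MagicMock()
launcher.categories.return_value = None
launcher.launch.return_value = [
    [
        Bbox(0, 0, 1, 1, label=1, attributes={"score": 0.9}),  # overlaps the ground truth box
        Bbox(1, 1, 1, 1, label=1, attributes={"score": 0.6}),  # no overlap -> missing annotation
    ]
]

dataset.transform(
    "missing_annotation_detection",
    launcher=launcher,
    pairwise_dist=0.75,
    score_threshold=0.5,
)

for item in dataset:
    # Only the non-overlapping, sufficiently confident prediction remains,
    # flagged with the "_pseudo_label_" attribute.
    print(item.id, item.annotations)
```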
12 changes: 12 additions & 0 deletions src/datumaro/plugins/specs.json
@@ -1262,6 +1262,18 @@
"plugin_type": "Exporter",
"extra_deps": []
},
{
"import_path": "datumaro.plugins.missing_annotation_detection.MissingAnnotationDetection",
"plugin_name": "missing_annotation_detection",
"plugin_type": "Transform",
"extra_deps": []
},
{
"import_path": "datumaro.components.launcher.ModelTransform",
"plugin_name": "model",
"plugin_type": "Transform",
"extra_deps": []
},
{
"import_path": "datumaro.plugins.transforms.AnnsToLabels",
"plugin_name": "anns_to_labels",
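This hunk also registers `ModelTransform` itself under the plugin name `model`, so plain model inference can be invoked by name. A short sketch, assuming a `dataset` and `launcher` like those in the previous example; the parameter values are illustrative.

```python
# Run inference in batches of four and append the predictions to the
# existing annotations instead of replacing them.
dataset.transform(
    "model",
    launcher=launcher,
    batch_size=4,
    append_annotation=True,
)
```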
89 changes: 89 additions & 0 deletions tests/unit/test_missing_annotation_detection.py
@@ -0,0 +1,89 @@
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT

from unittest.mock import MagicMock

import numpy as np
import pytest

from datumaro.components.annotation import AnnotationType, Bbox, LabelCategories
from datumaro.components.dataset import Dataset, eager_mode
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.media import Image

from tests.requirements import Requirements, mark_requirement


class MissingAnnotationDetectionTest:
@pytest.fixture
def fxt_dataset(self) -> Dataset:
return Dataset.from_iterable(
[
DatasetItem(
id="item",
media=Image.from_numpy(np.zeros([2, 2, 3])),
annotations=[
Bbox(0, 0, 1, 1, label=0),
Bbox(1, 0, 1, 1, label=1),
Bbox(0, 1, 1, 1, label=2),
],
)
],
categories={
AnnotationType.label: LabelCategories.from_iterable(
[f"label_{label_id}" for label_id in range(3)]
),
},
)

@pytest.fixture
def fxt_launcher(self) -> MagicMock:
gt_overlapped = Bbox(0, 0, 1, 1, label=1, attributes={"score": 1.0})
missing_label = Bbox(1, 1, 1, 1, label=1, attributes={"score": 0.5})

launcher_mock = MagicMock()
launcher_mock.categories.return_value = None
launcher_mock.launch.return_value = [[gt_overlapped, missing_label]]
return launcher_mock

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
@pytest.mark.parametrize("label_agnostic_matching, n_anns_expected", [(True, 1), (False, 2)])
def test_label_matching_flags(
self,
fxt_dataset: Dataset,
fxt_launcher: MagicMock,
label_agnostic_matching: bool,
n_anns_expected: int,
):
with eager_mode():
dataset = fxt_dataset.transform(
"missing_annotation_detection",
launcher=fxt_launcher,
label_agnostic_matching=label_agnostic_matching,
)

item = dataset.get(id="item")

assert len(item.annotations) == n_anns_expected

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
@pytest.mark.parametrize("score_threshold, n_anns_expected", [(0.6, 0), (0.4, 1)])
def test_score_threshold(
self,
fxt_dataset: Dataset,
fxt_launcher: MagicMock,
score_threshold: float,
n_anns_expected: int,
):
with eager_mode():
dataset = fxt_dataset.transform(
"missing_annotation_detection",
launcher=fxt_launcher,
score_threshold=score_threshold,
label_agnostic_matching=True,
)

item = dataset.get(id="item")

assert len(item.annotations) == n_anns_expected