Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

format_detection: add a dedicated function to report unsupported detection #665

Merged
merged 1 commit into from
Feb 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for downloading the ImageNetV2 and COCO datasets
(<https://github.com/openvinotoolkit/datumaro/pull/653>,
<https://github.com/openvinotoolkit/datumaro/pull/659>)
- A way for formats to signal that they don't support detection
(<https://github.com/openvinotoolkit/datumaro/pull/665>)

### Changed
- Allowed direct file paths in `datum import`. Such sources are imported like
Expand Down
67 changes: 52 additions & 15 deletions datumaro/components/format_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,17 @@ class FormatDetectionConfidence(IntEnum):
# * It makes sure that every confidence level is a true value.
assert all(level > 0 for level in FormatDetectionConfidence)

class FormatRequirementsUnmet(Exception):
class RejectionReason(Enum):
unmet_requirements = auto()
insufficient_confidence = auto()
detection_unsupported = auto()

class _FormatRejected(Exception):
@property
def reason(self) -> RejectionReason:
raise NotImplementedError

class FormatRequirementsUnmet(_FormatRejected):
"""
Represents a situation where a dataset does not meet the requirements
of a given dataset format.
Expand Down Expand Up @@ -74,6 +84,26 @@ def __str__(self) -> str:

return '\n'.join(lines)

@property
def reason(self) -> RejectionReason:
return RejectionReason.unmet_requirements

class FormatDetectionUnsupported(_FormatRejected):
"""
Represents a situation where detection is attempted with a format that
does not support it.

Must not be constructed or raised directly; use
`FormatDetectionContext.raise_unsupported` instead.
"""

def __str__(self) -> str:
return "Detection for this format is unsupported"

@property
def reason(self) -> RejectionReason:
return RejectionReason.detection_unsupported

class FormatDetectionContext:
"""
An instance of this class is given to a dataset format detector.
Expand Down Expand Up @@ -138,6 +168,14 @@ def _start_requirement(self, req_type: str) -> None:
f"a requirement ({req_type}) can't be placed directly within " \
"a 'require_any' block"

def raise_unsupported(self) -> NoReturn:
"""
Raises a `FormatDetectionUnsupported` exception to signal that the
current format does not support detection.
"""

raise FormatDetectionUnsupported

def fail(self, requirement_desc: str) -> NoReturn:
"""
Places a requirement that is never met. `requirement_desc` must contain
Expand Down Expand Up @@ -236,7 +274,7 @@ def probe_text_file(
try:
with open(osp.join(self._root_path, path), encoding='utf-8') as f:
yield f
except FormatRequirementsUnmet:
except _FormatRejected:
raise
except Exception:
self.fail(requirement_desc_full)
Expand Down Expand Up @@ -326,10 +364,14 @@ def alternative(self) -> Iterator[None]:
methods on that instance to place requirements that the dataset must meet
in order for it to be considered as belonging to the format.

Must return the level of confidence in the dataset belonging to the format
(or `None`, which is equivalent to the `MEDIUM` level)
or terminate via a `FormatRequirementsUnmet` exception raised by one of
the `FormatDetectionContext` methods.
Must terminate in one of the following ways:

* by returning the level of confidence in the dataset belonging to the format
(or `None`, which is equivalent to the `MEDIUM` level);
* by raising a `FormatRequirementsUnmet` exception via one of
the `FormatDetectionContext` methods;
* by raising a `FormatDetectionUnsupported` exception via
`FormatDetectionContext.raise_unsupported`.
"""

def apply_format_detector(
Expand All @@ -338,7 +380,8 @@ def apply_format_detector(
"""
Checks whether the dataset located at `dataset_root_path` belongs to the
format detected by `detector`. If it does, returns the confidence level
of the detection. Otherwise, raises a `FormatRequirementsUnmet` exception.
of the detection. Otherwise, terminates with the exception that was raised
by the detector.
"""
context = FormatDetectionContext(dataset_root_path)

Expand All @@ -347,10 +390,6 @@ def apply_format_detector(

return detector(context) or FormatDetectionConfidence.MEDIUM

class RejectionReason(Enum):
unmet_requirements = auto()
insufficient_confidence = auto()

class RejectionCallback(Protocol):
def __call__(self,
format_name: str, reason: RejectionReason, human_message: str,
Expand Down Expand Up @@ -405,12 +444,10 @@ def report_insufficient_confidence(
log.debug("Checking '%s' format...", format_name)
try:
new_confidence = apply_format_detector(path, detector)
except FormatRequirementsUnmet as ex:
except _FormatRejected as ex:
human_message = str(ex)
if rejection_callback:
rejection_callback(
format_name, RejectionReason.unmet_requirements,
human_message)
rejection_callback(format_name, ex.reason, human_message)
log.debug(human_message)
else:
log.debug("Format matched with confidence %d", new_confidence)
Expand Down
4 changes: 2 additions & 2 deletions datumaro/plugins/coco_format/importer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2019-2021 Intel Corporation
# Copyright (C) 2019-2022 Intel Corporation
#
# SPDX-License-Identifier: MIT

Expand Down Expand Up @@ -48,7 +48,7 @@ def detect(
# to use autodetection with the COCO dataset), disable autodetection
# for the single-task formats.
if len(cls._TASKS) == 1:
context.fail('this format cannot be autodetected')
context.raise_unsupported()

with context.require_any():
for task in cls._TASKS.keys():
Expand Down
4 changes: 2 additions & 2 deletions datumaro/plugins/voc_format/importer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2019-2021 Intel Corporation
# Copyright (C) 2019-2022 Intel Corporation
#
# SPDX-License-Identifier: MIT

Expand Down Expand Up @@ -26,7 +26,7 @@ def detect(cls, context: FormatDetectionContext) -> None:
# possible to use autodetection with the VOC datasets), disable
# autodetection for the single-task formats.
if len(cls._TASKS) == 1:
context.fail('this format cannot be autodetected')
context.raise_unsupported()

with context.require_any():
task_dirs = {task_dir for _, task_dir in cls._TASKS.values()}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ The format of the machine-readable report is as follows:

The `<reason-code>` can be one of:

- `"detection_unsupported"`: the corresponding format does not support
detection.

- `"insufficient_confidence"`: the dataset matched the corresponding format,
but it matched at least one other format better.

Expand Down
20 changes: 17 additions & 3 deletions tests/test_format_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import os.path as osp

from datumaro.components.format_detection import (
FormatDetectionConfidence, FormatRequirementsUnmet, RejectionReason,
apply_format_detector, detect_dataset_format,
FormatDetectionConfidence, FormatDetectionUnsupported,
FormatRequirementsUnmet, RejectionReason, apply_format_detector,
detect_dataset_format,
)
from datumaro.util.test_utils import TestDir

Expand Down Expand Up @@ -202,6 +203,14 @@ def detect(context):
self.assertEqual(result.exception.failed_alternatives,
('bad alternative 1', 'bad alternative 2'))

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_raise_unsupported(self):
def detect(context):
context.raise_unsupported()
Comment on lines +208 to +209
Copy link
Contributor

@zhiltsov-max zhiltsov-max Feb 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like it should be the default implementation (in Importer or Environment), shouldn't it?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered this, but I think the current default implementation (falling back to find_sources) is more useful.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When it's supposed to be called then?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added calls to it in this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I expect it to be redesigned quite soon though.


with self.assertRaises(FormatDetectionUnsupported):
apply_format_detector(self._dataset_root, detect)

class DetectDatasetFormat(FormatDetectionTest):
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_no_input_formats(self):
Expand All @@ -218,6 +227,7 @@ def test_general_case(self):
# ddd should be rejected immediately
("ddd", lambda context: FormatDetectionConfidence.LOW),
("eee", lambda context: None),
("fff", lambda context: context.raise_unsupported()),
]

rejected_formats = {}
Expand All @@ -230,7 +240,8 @@ def rejection_callback(format, reason, message):

self.assertEqual(set(detected_datasets), {"bbb", "eee"})

self.assertEqual(rejected_formats.keys(), {"aaa", "ccc", "ddd"})
self.assertEqual(rejected_formats.keys(), {"aaa", "ccc", "ddd", "fff"})

for name in ("aaa", "ddd"):
self.assertEqual(rejected_formats[name][0],
RejectionReason.insufficient_confidence)
Expand All @@ -239,6 +250,9 @@ def rejection_callback(format, reason, message):
RejectionReason.unmet_requirements)
self.assertIn("test unmet requirement", rejected_formats["ccc"][1])

self.assertEqual(rejected_formats["fff"][0],
RejectionReason.detection_unsupported)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_no_callback(self):
formats = [
Expand Down