Skip to content

Commit

Permalink
Add tests for the detect_dataset_format function
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Donchenko committed Jan 25, 2022
1 parent 3839ebd commit 7a2eada
Showing 1 changed file with 50 additions and 1 deletion.
51 changes: 50 additions & 1 deletion tests/test_format_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import os.path as osp

from datumaro.components.format_detection import (
FormatDetectionConfidence, FormatRequirementsUnmet, apply_format_detector,
FormatDetectionConfidence, FormatRequirementsUnmet, RejectionReason,
apply_format_detector, detect_dataset_format,
)
from datumaro.util.test_utils import TestDir

Expand All @@ -15,6 +16,7 @@ def setUp(self) -> None:
self._dataset_root = test_dir_context.__enter__()
self.addCleanup(test_dir_context.__exit__, None, None, None)

class ApplyFormatDetectorTest(FormatDetectionTest):
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_empty_detector(self):
result = apply_format_detector(self._dataset_root, lambda c: None)
Expand Down Expand Up @@ -199,3 +201,50 @@ def detect(context):

self.assertEqual(result.exception.failed_alternatives,
('bad alternative 1', 'bad alternative 2'))

class DetectDatasetFormat(FormatDetectionTest):
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_no_input_formats(self):
detected_datasets = detect_dataset_format((), self._dataset_root)
self.assertEqual(detected_datasets, [])

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_general_case(self):
formats = [
# aaa should be rejected after bbb is checked
("aaa", lambda context: FormatDetectionConfidence.LOW),
("bbb", lambda context: None),
("ccc", lambda context: context.fail("test unmet requirement")),
# ddd should be rejected immediately
("ddd", lambda context: FormatDetectionConfidence.LOW),
("eee", lambda context: None),
]

rejected_formats = {}

def rejection_callback(format, reason, message):
rejected_formats[format] = (reason, message)

detected_datasets = detect_dataset_format(formats, self._dataset_root,
rejection_callback=rejection_callback)

self.assertEqual(set(detected_datasets), {"bbb", "eee"})

self.assertEqual(rejected_formats.keys(), {"aaa", "ccc", "ddd"})
for name in ("aaa", "ddd"):
self.assertEqual(rejected_formats[name][0],
RejectionReason.insufficient_confidence)

self.assertEqual(rejected_formats["ccc"][0],
RejectionReason.unmet_requirements)
self.assertIn("test unmet requirement", rejected_formats["ccc"][1])

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_no_callback(self):
formats = [
("bbb", lambda context: None),
("ccc", lambda context: context.fail("test unmet requirement")),
]

detected_datasets = detect_dataset_format(formats, self._dataset_root)
self.assertEqual(detected_datasets, ["bbb"])

0 comments on commit 7a2eada

Please sign in to comment.