diff --git a/CHANGELOG.md b/CHANGELOG.md
index f0e966c5fc..2f8971614a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
 ### Added
 - Support for escaping in attribute values in LabelMe format ()
+- Support for Segmentation Splitting ()
 ### Changed
 - LabelMe format saves dataset items with their relative paths by subsets without changing names ()
diff --git a/datumaro/plugins/splitter.py b/datumaro/plugins/splitter.py
index c9e19fa8b9..abc391ab19 100644
--- a/datumaro/plugins/splitter.py
+++ b/datumaro/plugins/splitter.py
@@ -5,6 +5,7 @@ import logging as log
 import numpy as np
 from math import gcd
+from enum import Enum

 from datumaro.components.extractor import (Transform, AnnotationType,
     DEFAULT_SUBSET_NAME)
@@ -13,34 +14,173 @@
 NEAR_ZERO = 1e-7

+SplitTask = Enum(
+    "split", ["classification", "detection", "segmentation", "reid"]
+)

-class _TaskSpecificSplit(Transform, CliPlugin):
-    _default_split = [('train', 0.5), ('val', 0.2), ('test', 0.3)]
+
+class Split(Transform, CliPlugin):
+    """
+    - classification split |n
+    Splits dataset into subsets (train/val/test) in class-wise manner. |n
+    Splits dataset images in the specified ratio, keeping the initial class
+    distribution.|n
+    |n
+    - detection & segmentation split |n
+    Each image can have multiple object annotations
+    (bbox, mask, polygon). Since an image shouldn't be included
+    in multiple subsets at the same time, and image annotations
+    shouldn't be split, in general, dataset annotations are unlikely
+    to be split exactly in the specified ratio. |n
+    This split tries to split dataset images as closely as possible
+    to the specified ratio, keeping the initial class distribution.|n
+    |n
+    - reidentification split |n
+    In this task, the test set should consist of images of people or
+    objects unseen during the training phase. |n
+    This function splits a dataset in the following way:|n
+    1. Splits the dataset into 'train + val' and 'test' sets|n
+    |s|sbased on person or object ID.|n
+    2. Splits 'test' set into 'test-gallery' and 'test-query' sets|n
+    |s|sin class-wise manner.|n
+    3. Splits the 'train + val' set into 'train' and 'val' sets|n
+    |s|sin the same way.|n
+    The final subsets would be
+    'train', 'val', 'test-gallery' and 'test-query'. |n
+    |n
+    Notes:|n
+    - Each image is expected to have only one Annotation. Unlabeled or
+      multi-labeled images will be split into subsets randomly. |n
+    - If Labels also have attributes, the dataset is also split by attribute values.|n
+    - If there are not enough images in some class or attribute group,
+      the split ratio can't be guaranteed.|n
+    In the reidentification task, |n
+    - Object ID can be described by Label, or by attribute (--attr parameter)|n
+    - The splits of the test set are controlled by the '--query' parameter |n
+    |s|sGallery ratio would be 1.0 - query.|n
+    |n
+    Example:|n
+    |s|s%(prog)s -t classification --subset train:.5 --subset val:.2 --subset test:.3 |n
+    |s|s%(prog)s -t detection --subset train:.5 --subset val:.2 --subset test:.3 |n
+    |s|s%(prog)s -t segmentation --subset train:.5 --subset val:.2 --subset test:.3 |n
+    |s|s%(prog)s -t reid --subset train:.5 --subset val:.2 --subset test:.3 --query .5 |n
+    Example: use 'person_id' attribute for splitting|n
+    |s|s%(prog)s --attr person_id
+    """
+
+    _default_split = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
+    _default_query_ratio = 0.5

     @classmethod
     def build_cmdline_parser(cls, **kwargs):
         parser = super().build_cmdline_parser(**kwargs)
-        parser.add_argument('-s', '--subset', action='append',
-            type=cls._split_arg, dest='splits',
+        parser.add_argument(
+            "-t",
+            "--task",
+            default=SplitTask.classification.name,
+            choices=[t.name for t in SplitTask],
+            help="(one of {}; default: %(default)s)".format(
+                ", ".join(t.name for t in SplitTask)
+            ),
+        )
+        parser.add_argument(
+            "-s",
+            "--subset",
+            action="append",
+            type=cls._split_arg,
+            dest="splits",
             help="Subsets in the form: '<subset>:<ratio>' "
-            "(repeatable, default: %s)" % dict(cls._default_split))
-        parser.add_argument('--seed', type=int, help="Random seed")
+            "(repeatable, default: %s)" % dict(cls._default_split),
+        )
+        parser.add_argument(
+            "--query",
+            type=float,
+            default=None,
+            help="Query ratio in the test set (default: %.3f)"
+            % cls._default_query_ratio,
+        )
+        parser.add_argument(
+            "--attr",
+            type=str,
+            dest="attr_for_id",
+            default=None,
+            help="Attribute name representing the ID (default: use label)",
+        )
+        parser.add_argument("--seed", type=int, help="Random seed")
         return parser

     @staticmethod
     def _split_arg(s):
-        parts = s.split(':')
+        parts = s.split(":")
         if len(parts) != 2:
             import argparse
+
             raise argparse.ArgumentTypeError()
         return (parts[0], float(parts[1]))

-    def __init__(self, dataset, splits, seed, restrict=False):
+    def __init__(self, dataset, task, splits, query=None, attr_for_id=None, seed=None):
         super().__init__(dataset)

         if splits is None:
             splits = self._default_split

+        self.task = task
+        self.splitter = self._get_splitter(
+            task, dataset, splits, seed, query, attr_for_id
+        )
+        self._initialized = False
+        self._subsets = self.splitter._subsets
+
+    @staticmethod
+    def _get_splitter(task, dataset, splits, seed, query, attr_for_id):
+        if task == SplitTask.classification.name:
+            splitter = _ClassificationSplit(dataset=dataset, splits=splits, seed=seed)
+        elif task in {SplitTask.detection.name, SplitTask.segmentation.name}:
+            splitter = _InstanceSpecificSplit(
+                dataset=dataset, splits=splits, seed=seed, task=task
+            )
+        elif task == SplitTask.reid.name:
+            splitter = _ReidentificationSplit(
+                dataset=dataset,
+                splits=splits,
+                seed=seed,
+                query=query,
+                attr_for_id=attr_for_id,
+            )
+        else:
+            raise Exception(
+                f"Unknown task '{task}', available "
+                f"tasks: {[a.name for a in SplitTask]}"
+            )
+        return splitter
+
+    def __iter__(self):
+        # lazy splitting
+        if self._initialized is False:
+            self.splitter._split_dataset()
+            self._initialized = True
+        for i, item in enumerate(self._extractor):
+            yield 
self.wrap_item(item, subset=self.splitter._find_split(i)) + + def get_subset(self, name): + # lazy splitting + if self._initialized is False: + self.splitter._split_dataset() + self._initialized = True + return super().get_subset(name) + + def subsets(self): + # lazy splitting + if self._initialized is False: + self.splitter._split_dataset() + self._initialized = True + return super().subsets() + + +class _TaskSpecificSplit: + def __init__(self, dataset, splits, seed, restrict=False): + self._extractor = dataset + snames, sratio, subsets = self._validate_splits(splits, restrict) self._snames = snames @@ -67,8 +207,7 @@ def _get_uniq_annotations(dataset): unlabeled_or_multi = [] for idx, item in enumerate(dataset): - labels = [a for a in item.annotations - if a.type == AnnotationType.label] + labels = [a for a in item.annotations if a.type == AnnotationType.label] if len(labels) == 1: annotations.append(labels[0]) else: @@ -86,11 +225,16 @@ def _validate_splits(splits, restrict=False): # remove subset name restriction # https://github.com/openvinotoolkit/datumaro/issues/194 if restrict: - assert subset in valid, \ - "Subset name must be one of %s, got %s" % (valid, subset) - assert 0.0 <= ratio and ratio <= 1.0, \ - "Ratio is expected to be in the range " \ - "[0, 1], but got %s for %s" % (ratio, subset) + assert subset in valid, "Subset name must be one of %s, got %s" % ( + valid, + subset, + ) + assert ( + 0.0 <= ratio and ratio <= 1.0 + ), "Ratio is expected to be in the range " "[0, 1], but got %s for %s" % ( + ratio, + subset, + ) # ignore near_zero ratio because it may produce partition error. if ratio > NEAR_ZERO: # handling duplication @@ -185,9 +329,9 @@ def _is_float(value): return by_attributes - def _split_by_attr(self, datasets, snames, ratio, out_splits, - merge_small_classes=True): - + def _split_by_attr( + self, datasets, snames, ratio, out_splits, merge_small_classes=True + ): def _split_indice(indice): sections, _ = self._get_sections(len(indice), ratio) splits = np.array_split(indice, sections) @@ -254,16 +398,8 @@ def _find_split(self, index): def _split_dataset(self): raise NotImplementedError() - def __iter__(self): - # lazy splitting - if self._initialized is False: - self._split_dataset() - self._initialized = True - for i, item in enumerate(self._extractor): - yield self.wrap_item(item, subset=self._find_split(i)) - -class ClassificationSplit(_TaskSpecificSplit): +class _ClassificationSplit(_TaskSpecificSplit): """ Splits dataset into subsets(train/val/test) in class-wise manner. 
|n
     Splits dataset images in the specified ratio, keeping the initial class
@@ -277,8 +413,9 @@ class ClassificationSplit(_TaskSpecificSplit):
     the split ratio can't be guaranteed.|n
     |n
     Example:|n
-    |s|s%(prog)s --subset train:.5 --subset val:.2 --subset test:.3
+    |s|s%(prog)s -t classification --subset train:.5 --subset val:.2 --subset test:.3
     """
+
     def __init__(self, dataset, splits, seed=None):
         """
         Parameters
@@ -300,7 +437,7 @@ def _split_dataset(self):
         annotations, unlabeled = self._get_uniq_annotations(self._extractor)

         for idx, ann in enumerate(annotations):
-            label = getattr(ann, 'label', None)
+            label = getattr(ann, "label", None)
             if label not in by_labels:
                 by_labels[label] = []
             by_labels[label].append((idx, ann))
@@ -320,7 +457,7 @@ def _split_dataset(self):

         self._set_parts(by_splits)


-class ReidentificationSplit(_TaskSpecificSplit):
+class _ReidentificationSplit(_TaskSpecificSplit):
     """
     Splits a dataset for re-identification task.|n
     Produces a split with a specified ratio of images, avoiding having same
@@ -347,25 +484,14 @@ class ReidentificationSplit(_TaskSpecificSplit):
     |n
     Example: split a dataset in the specified ratio, split the test set|n
     |s|s|s|sinto gallery and query in 1:1 ratio|n
-    |s|s%(prog)s --subset train:.5 --subset val:.2 --subset test:.3 --query .5|n
+    |s|s%(prog)s -t reid --subset train:.5 --subset val:.2 --subset test:.3 --query .5|n
     Example: use 'person_id' attribute for splitting|n
     |s|s%(prog)s --attr person_id
     """

     _default_query_ratio = 0.5

-    @classmethod
-    def build_cmdline_parser(cls, **kwargs):
-        parser = super().build_cmdline_parser(**kwargs)
-        parser.add_argument('--query', type=float,
-            help="Query ratio in the test set (default: %.3f)"
-            % cls._default_query_ratio)
-        parser.add_argument('--attr', type=str, dest='attr_for_id',
-            help="Attribute name representing the ID (default: use label)")
-        return parser
-
-    def __init__(self, dataset, splits, query=None,
-            attr_for_id=None, seed=None):
+    def __init__(self, dataset, splits, query=None, attr_for_id=None, seed=None):
         """
         Parameters
         ----------
@@ -387,10 +513,10 @@ def __init__(self, dataset, splits, query=None,
         if query is None:
             query = self._default_query_ratio

-        assert 0.0 <= query and query <= 1.0, \
-            "Query ratio is expected to be in the range " \
-            "[0, 1], but got %f" % query
-        test_splits = [('test-query', query), ('test-gallery', 1.0 - query)]
+        assert 0.0 <= query and query <= 1.0, (
+            "Query ratio is expected to be in the range " "[0, 1], but got %f" % query
+        )
+        test_splits = [("test-query", query), ("test-gallery", 1.0 - query)]

         # remove subset name restriction
         self._subsets = {"train", "val", "test-gallery", "test-query"}
@@ -410,15 +536,16 @@ def _split_dataset(self):
         annotations, unlabeled = self._get_uniq_annotations(dataset)
         if attr_for_id is None:  # use label
             for idx, ann in enumerate(annotations):
-                ID = getattr(ann, 'label', None)
+                ID = getattr(ann, "label", None)
                 if ID not in by_id:
                     by_id[ID] = []
                 by_id[ID].append((idx, ann))
         else:  # use attr_for_id
             for idx, ann in enumerate(annotations):
                 attributes = dict(ann.attributes.items())
-                assert attr_for_id in attributes, \
+                assert attr_for_id in attributes, (
                     "'%s' is expected as an attribute name" % attr_for_id
+                )
                 ID = attributes[attr_for_id]
                 if ID not in by_id:
                     by_id[ID] = []
@@ -426,9 +553,9 @@
         required = self._get_required(id_ratio)
         if len(by_id) < required:
-            log.warning("There's not enough IDs, which is %s, "
-                "so train/val/test ratio can't be guaranteed."
-                % len(by_id)
+            log.warning(
+                "There are not enough IDs (%s), "
+                "so train/val/test ratio can't be guaranteed." % len(by_id)
             )

         # 1. split dataset into trval and test
@@ -444,7 +571,9 @@ def _split_dataset(self):
             trval = {pid: by_id[pid] for pid in splits[1]}
             # follow the ratio of datasetitems as possible.
             # naive heuristic: exchange the best item one by one.
-            expected_count = int(len(self._extractor) * split_ratio[0])
+            expected_count = int(
+                (len(self._extractor) - len(unlabeled)) * split_ratio[0]
+            )
             testset_total = int(np.sum([len(v) for v in testset.values()]))
             self._rebalancing(testset, trval, expected_count, testset_total)
         else:
@@ -463,8 +592,9 @@
             test_snames.append(sname)
             test_ratio.append(float(ratio))

-        self._split_by_attr(testset, test_snames, test_ratio, by_splits,
-            merge_small_classes=False)
+        self._split_by_attr(
+            testset, test_snames, test_ratio, by_splits, merge_small_classes=False
+        )

         # 3. split 'trval' into 'train' and 'val'
         trval_snames = ["train", "val"]
@@ -479,14 +609,15 @@
         total_ratio = np.sum(trval_ratio)
         if total_ratio < NEAR_ZERO:
             trval_splits = list(zip(["train", "val"], trval_ratio))
-            log.warning("Sum of ratios is expected to be positive, "
-                "got %s, which is %s"
-                % (trval_splits, total_ratio)
+            log.warning(
+                "Sum of ratios is expected to be positive, "
+                "got %s, which is %s" % (trval_splits, total_ratio)
             )
         else:
             trval_ratio /= total_ratio  # normalize
-            self._split_by_attr(trval, trval_snames, trval_ratio, by_splits,
-                merge_small_classes=False)
+            self._split_by_attr(
+                trval, trval_snames, trval_ratio, by_splits, merge_small_classes=False
+            )

         # split unlabeled data into 'not-supported'.
         if len(unlabeled) > 0:
@@ -541,30 +672,16 @@
                     test[id_trval] = trval.pop(id_trval)
                     trval[id_test] = test.pop(id_test)

-    def get_subset(self, name):
-        # lazy splitting
-        if self._initialized is False:
-            self._split_dataset()
-            self._initialized = True
-        return super().get_subset(name)
-
-    def subsets(self):
-        # lazy splitting
-        if self._initialized is False:
-            self._split_dataset()
-            self._initialized = True
-        return super().subsets()
-

-class DetectionSplit(_TaskSpecificSplit):
+class _InstanceSpecificSplit(_TaskSpecificSplit):
     """
-    Splits a dataset into subsets(train/val/test) for detection task,
+    Splits a dataset into subsets (train/val/test),
     using object annotations as a basis for splitting.|n
     Tries to produce an image split with the specified ratio, keeping the
     initial distribution of class objects.|n
     |n
-    In a detection dataset, each image can have multiple object annotations -
-    instance bounding boxes. Since an image shouldn't be included
+    Each image can have multiple object annotations
+    (instance bounding boxes, masks, polygons). Since an image shouldn't be included
     in multiple subsets at the same time, and image annotations
     shouldn't be split, in general, dataset annotations are unlikely
     to be split exactly in the specified ratio. |n
     This split tries to split dataset images as closely as possible
     to the specified ratio, keeping the initial class distribution.|n
     |n
     Notes:|n
-    - Each image is expected to have one or more Bbox annotations.|n
-    - Only Bbox annotations are considered.|n
+    - Each image is expected to have one or more annotations.|n
+    - Only bbox annotations are considered in the detection task.|n
+    - Mask or Polygon annotations are considered in the segmentation task.|n
     |n
     Example: split dataset so that each object class annotations were split|n
     |s|s|s|sin the specified ratio between subsets|n
-    |s|s%(prog)s --subset train:.5 --subset val:.2 --subset test:.3
+    |s|s%(prog)s -t detection --subset train:.5 --subset val:.2 --subset test:.3 |n
+    |s|s%(prog)s -t segmentation --subset train:.5 --subset val:.2 --subset test:.3
     """
-    def __init__(self, dataset, splits, seed=None):
+
+    def __init__(self, dataset, splits, task, seed=None):
         """
         Parameters
         ----------
@@ -591,18 +711,21 @@ def __init__(self, dataset, splits, seed=None):
         """
         super().__init__(dataset, splits, seed)

-    @staticmethod
-    def _group_by_bbox_labels(dataset):
+        if task == SplitTask.detection.name:
+            self.annotation_type = [AnnotationType.bbox]
+        elif task == SplitTask.segmentation.name:
+            self.annotation_type = [AnnotationType.mask, AnnotationType.polygon]
+
+    def _group_by_labels(self, dataset):
         by_labels = dict()
         unlabeled = []

         for idx, item in enumerate(dataset):
-            bbox_anns = [a for a in item.annotations
-                if a.type == AnnotationType.bbox]
+            bbox_anns = [a for a in item.annotations if a.type in self.annotation_type]
             if len(bbox_anns) == 0:
                 unlabeled.append(idx)
                 continue
             for ann in bbox_anns:
-                label = getattr(ann, 'label', None)
+                label = getattr(ann, "label", None)
                 if label not in by_labels:
                     by_labels[label] = [(idx, ann)]
                 else:
@@ -615,7 +738,7 @@ def _split_dataset(self):
         subsets, sratio = self._snames, self._sratio

         # 1. group by bbox label
-        by_labels, unlabeled = self._group_by_bbox_labels(self._extractor)
+        by_labels, unlabeled = self._group_by_labels(self._extractor)

         # 2. group by attributes
         required = self._get_required(sratio)
@@ -672,7 +795,7 @@ def _split_dataset(self):
         target_size = dict()
         expected = []  # expected numbers of per split GT samples
         for sname, ratio in zip(subsets, sratio):
-            target_size[sname] = total * ratio
+            target_size[sname] = (total - len(unlabeled)) * ratio
             expected.append([sname, np.array(n_combs) * ratio])

         # functions for keep the # of annotations not exceed the expected num
diff --git a/docs/user_manual.md b/docs/user_manual.md
index 06585d36fd..07954edda1 100644
--- a/docs/user_manual.md
+++ b/docs/user_manual.md
@@ -1035,17 +1035,21 @@ datum transform -t random_split -- --subset train:.67 --subset test:.33
 ```

 Example: split a dataset in task-specific manner. Supported tasks are
-classification, detection, and re-identification.
+classification, detection, re-identification, and segmentation.
``` bash
-datum transform -t classification_split -- \
-    --subset train:.5 --subset val:.2 --subset test:.3
+datum transform -t split -- \
+    -t classification --subset train:.5 --subset val:.2 --subset test:.3

-datum transform -t detection_split -- \
-    --subset train:.5 --subset val:.2 --subset test:.3
+datum transform -t split -- \
+    -t detection --subset train:.5 --subset val:.2 --subset test:.3

-datum transform -t reidentification_split -- \
-    --subset train:.5 --subset val:.2 --subset test:.3 --query .5
+datum transform -t split -- \
+    -t segmentation --subset train:.5 --subset val:.2 --subset test:.3
+
+datum transform -t split -- \
+    -t reid --subset train:.5 --subset val:.2 --subset test:.3 \
+    --query .5
 ```

 Example: convert polygons to masks, masks to boxes etc.:
@@ -1076,7 +1080,7 @@ datum transform -t rename -- -e '|frame_(\d+)|\\1|'
 Example: Sampling dataset items, subset `train` is divided into `sampled`(sampled_subset) and `unsampled`
 - `train` has 100 items, and 20 samples are selected. There are `sampled` (20 items) and `unsampled` (80 items) subsets.
-- Remove `train` subset (if sample_name=`train` or unsample_name=`train`, still remain)
+- Remove `train` subset (if sampled_subset=`train` or unsampled_name=`train`, it still remains)
 - There are five sampling methods, selected with the `-m` option.
   - `topk`: Return the k items with the highest uncertainty
   - `lowk`: Return the k items with the lowest uncertainty
@@ -1087,9 +1091,9 @@
 ``` bash
 datum transform -t sampler -- \
   -a entropy \
-  -subset_name train \
-  -sample_name sampled \
-  -unsample_name unsampled \
+  -i train \
+  -o sampled \
+  -u unsampled \
   -m topk \
   -k 20
 ```
diff --git a/tests/test_splitter.py b/tests/test_splitter.py
index 351162b9c9..4c233f0eb2 100644
--- a/tests/test_splitter.py
+++ b/tests/test_splitter.py
@@ -3,8 +3,15 @@
 from unittest import TestCase

 from datumaro.components.project import Dataset
-from datumaro.components.extractor import (DatasetItem, Label, Bbox,
-    LabelCategories, AnnotationType)
+from datumaro.components.extractor import (
+    DatasetItem,
+    Label,
+    Bbox,
+    Mask,
+    Polygon,
+    LabelCategories,
+    AnnotationType,
+)

 import datumaro.plugins.splitter as splitter
 from datumaro.components.operations import compute_ann_statistics
@@ -40,20 +47,23 @@ def _generate_dataset(self, config):
                 for _ in range(count):
                     idx += 1
                     iterable.append(
-                        DatasetItem(idx, subset=self._get_subset(idx),
-                            annotations=[
-                                Label(label_id, attributes=attributes)
-                            ],
-                            image=np.ones((1, 1, 3))
+                        DatasetItem(
+                            idx,
+                            subset=self._get_subset(idx),
+                            annotations=[Label(label_id, attributes=attributes)],
+                            image=np.ones((1, 1, 3)),
                         )
                     )
             else:
                 for _ in range(counts):
                     idx += 1
                     iterable.append(
-                        DatasetItem(idx, subset=self._get_subset(idx),
+                        DatasetItem(
+                            idx,
+                            subset=self._get_subset(idx),
                             annotations=[Label(label_id)],
-                            image=np.ones((1, 1, 3)))
+                            image=np.ones((1, 1, 3)),
+                        )
                     )
         categories = {AnnotationType.label: label_cat}
         dataset = Dataset.from_iterable(iterable, categories)
@@ -66,9 +76,10 @@ def test_split_for_classification_multi_class_no_attr(self):
             "label3": {"attrs": None, "counts": 30},
         }
         source = self._generate_dataset(config)
+        task = splitter.SplitTask.classification.name
         splits = [("train", 0.7), ("test", 0.3)]

-        actual = splitter.ClassificationSplit(source, splits)
+        actual = splitter.Split(source, task, splits)

         self.assertEqual(42, len(actual.get_subset("train")))
         self.assertEqual(18, len(actual.get_subset("test")))
@@ -91,9 +102,10 @@ def 
test_split_for_classification_single_class_single_attr(self): counts = {0: 10, 1: 20, 2: 30} config = {"label": {"attrs": ["attr"], "counts": counts}} source = self._generate_dataset(config) + task = splitter.SplitTask.classification.name splits = [("train", 0.7), ("test", 0.3)] - actual = splitter.ClassificationSplit(source, splits) + actual = splitter.Split(source, task, splits) self.assertEqual(42, len(actual.get_subset("train"))) self.assertEqual(18, len(actual.get_subset("test"))) @@ -124,10 +136,11 @@ def test_split_for_classification_single_class_multi_attr(self): attrs = ["attr1", "attr2"] config = {"label": {"attrs": attrs, "counts": counts}} source = self._generate_dataset(config) + task = splitter.SplitTask.classification.name with self.subTest("zero remainder"): splits = [("train", 0.7), ("test", 0.3)] - actual = splitter.ClassificationSplit(source, splits) + actual = splitter.Split(source, task, splits) self.assertEqual(84, len(actual.get_subset("train"))) self.assertEqual(36, len(actual.get_subset("test"))) @@ -152,7 +165,7 @@ def test_split_for_classification_single_class_multi_attr(self): with self.subTest("non-zero remainder"): splits = [("train", 0.95), ("test", 0.05)] - actual = splitter.ClassificationSplit(source, splits) + actual = splitter.Split(source, task, splits) self.assertEqual(114, len(actual.get_subset("train"))) self.assertEqual(6, len(actual.get_subset("test"))) @@ -173,9 +186,10 @@ def test_split_for_classification_multi_label_with_attr(self): "label2": {"attrs": attr2, "counts": counts}, } source = self._generate_dataset(config) + task = splitter.SplitTask.classification.name splits = [("train", 0.7), ("test", 0.3)] - actual = splitter.ClassificationSplit(source, splits) + actual = splitter.Split(source, task, splits) train = actual.get_subset("train") test = actual.get_subset("test") @@ -213,12 +227,10 @@ def test_split_for_classification_multi_label_with_attr(self): self.assertEqual(15, attr_test["attr3"]["distribution"]["2"][0]) with self.subTest("random seed test"): - r1 = splitter.ClassificationSplit(source, splits, seed=1234) - r2 = splitter.ClassificationSplit(source, splits, seed=1234) - r3 = splitter.ClassificationSplit(source, splits, seed=4321) - self.assertEqual( - list(r1.get_subset("test")), list(r2.get_subset("test")) - ) + r1 = splitter.Split(source, task, splits, seed=1234) + r2 = splitter.Split(source, task, splits, seed=1234) + r3 = splitter.Split(source, task, splits, seed=4321) + self.assertEqual(list(r1.get_subset("test")), list(r2.get_subset("test"))) self.assertNotEqual( list(r1.get_subset("test")), list(r3.get_subset("test")) ) @@ -229,8 +241,9 @@ def test_split_for_classification_zero_ratio(self): } source = self._generate_dataset(config) splits = [("train", 0.1), ("val", 0.9), ("test", 0.0)] + task = splitter.SplitTask.classification.name - actual = splitter.ClassificationSplit(source, splits) + actual = splitter.Split(source, task, splits) self.assertEqual(1, len(actual.get_subset("train"))) self.assertEqual(4, len(actual.get_subset("val"))) @@ -241,7 +254,8 @@ def test_split_for_classification_unlabeled(self): iterable = [DatasetItem(i, annotations=[]) for i in range(10)] source = Dataset.from_iterable(iterable, categories=["a", "b"]) splits = [("train", 0.7), ("test", 0.3)] - actual = splitter.ClassificationSplit(source, splits) + task = splitter.SplitTask.classification.name + actual = splitter.Split(source, task, splits) self.assertEqual(7, len(actual.get_subset("train"))) self.assertEqual(3, len(actual.get_subset("test"))) 
@@ -251,35 +265,41 @@ def test_split_for_classification_unlabeled(self): iterable = [DatasetItem(i, annotations=anns) for i in range(10)] source = Dataset.from_iterable(iterable, categories=["a", "b"]) splits = [("train", 0.7), ("test", 0.3)] - actual = splitter.ClassificationSplit(source, splits) + task = splitter.SplitTask.classification.name + actual = splitter.Split(source, task, splits) self.assertEqual(7, len(actual.get_subset("train"))) self.assertEqual(3, len(actual.get_subset("test"))) def test_split_for_classification_gives_error(self): - source = Dataset.from_iterable([ - DatasetItem(1, annotations=[Label(0)]), - DatasetItem(2, annotations=[Label(1)]), - ], categories=["a", "b", "c"]) + source = Dataset.from_iterable( + [ + DatasetItem(1, annotations=[Label(0)]), + DatasetItem(2, annotations=[Label(1)]), + ], + categories=["a", "b", "c"], + ) + task = splitter.SplitTask.classification.name with self.subTest("wrong ratio"): with self.assertRaisesRegex(Exception, "in the range"): splits = [("train", -0.5), ("test", 1.5)] - splitter.ClassificationSplit(source, splits) + splitter.Split(source, task, splits) with self.assertRaisesRegex(Exception, "Sum of ratios"): splits = [("train", 0.5), ("test", 0.5), ("val", 0.5)] - splitter.ClassificationSplit(source, splits) + splitter.Split(source, task, splits) with self.subTest("duplicated subset name"): with self.assertRaisesRegex(Exception, "duplicated"): splits = [("train", 0.5), ("train", 0.2), ("test", 0.3)] - splitter.ClassificationSplit(source, splits) + splitter.Split(source, task, splits) def test_split_for_reidentification(self): - ''' + """ Test ReidentificationSplit using Dataset with label (ImageNet style) - ''' + """ + def _get_present(stat): values_present = [] for label, dist in stat["distribution"].items(): @@ -303,9 +323,9 @@ def _get_present(stat): attr_for_id = None source = self._generate_dataset(config) splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] + task = splitter.SplitTask.reid.name query = 0.4 / 0.7 - actual = splitter.ReidentificationSplit(source, - splits, query, attr_for_id) + actual = splitter.Split(source, task, splits, query, attr_for_id) stats = dict() for sname in ["train", "val", "test-query", "test-gallery"]: @@ -353,9 +373,9 @@ def _get_present(stat): self.assertEqual(int(total * 0.4 / 0.7), dist_query[pid][0]) def test_split_for_reidentification_randomseed(self): - ''' + """ Test randomseed for reidentification - ''' + """ counts = {} config = dict() for i in range(10): @@ -364,30 +384,28 @@ def test_split_for_reidentification_randomseed(self): counts[label] = count config[label] = {"attrs": None, "counts": count} source = self._generate_dataset(config) + task = splitter.SplitTask.reid.name splits = [("train", 0.5), ("test", 0.5)] query = 0.4 / 0.7 - r1 = splitter.ReidentificationSplit(source, splits, query, seed=1234) - r2 = splitter.ReidentificationSplit(source, splits, query, seed=1234) - r3 = splitter.ReidentificationSplit(source, splits, query, seed=4321) - self.assertEqual( - list(r1.get_subset("train")), list(r2.get_subset("train")) - ) - self.assertNotEqual( - list(r1.get_subset("train")), list(r3.get_subset("train")) - ) + r1 = splitter.Split(source, task, splits, query, seed=1234) + r2 = splitter.Split(source, task, splits, query, seed=1234) + r3 = splitter.Split(source, task, splits, query, seed=4321) + self.assertEqual(list(r1.get_subset("train")), list(r2.get_subset("train"))) + self.assertNotEqual(list(r1.get_subset("train")), list(r3.get_subset("train"))) def 
test_split_for_reidentification_rebalance(self):
-        '''
+        """
        rebalance function shouldn't give an error when there's no exchange
-        '''
+        """
        config = dict()
        for i in range(100):
            label = "label%03d" % i
            config[label] = {"attrs": None, "counts": 7}
        source = self._generate_dataset(config)
+        task = splitter.SplitTask.reid.name
        splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
        query = 0.4 / 0.7
-        actual = splitter.ReidentificationSplit(source, splits, query)
+        actual = splitter.Split(source, task, splits, query)

        self.assertEqual(350, len(actual.get_subset("train")))
        self.assertEqual(140, len(actual.get_subset("val")))
@@ -396,12 +414,13 @@ def test_split_for_reidentification_rebalance(self):

     def test_split_for_reidentification_unlabeled(self):
         query = 0.5
+        task = splitter.SplitTask.reid.name

         with self.subTest("no label"):
             iterable = [DatasetItem(i, annotations=[]) for i in range(10)]
             source = Dataset.from_iterable(iterable, categories=["a", "b"])
             splits = [("train", 0.6), ("test", 0.4)]
-            actual = splitter.ReidentificationSplit(source, splits, query)
+            actual = splitter.Split(source, task, splits, query)
             self.assertEqual(10, len(actual.get_subset("not-supported")))

         with self.subTest("multi label"):
@@ -409,12 +428,13 @@ def test_split_for_reidentification_unlabeled(self):
             iterable = [DatasetItem(i, annotations=anns) for i in range(10)]
             source = Dataset.from_iterable(iterable, categories=["a", "b"])
             splits = [("train", 0.6), ("test", 0.4)]
-            actual = splitter.ReidentificationSplit(source, splits, query)
+            actual = splitter.Split(source, task, splits, query)
             self.assertEqual(10, len(actual.get_subset("not-supported")))

     def test_split_for_reidentification_gives_error(self):
         query = 0.4 / 0.7  # valid query ratio
+        task = splitter.SplitTask.reid.name

         counts = {i: (i % 3 + 1) * 7 for i in range(10)}
         config = {"person": {"attrs": ["PID"], "counts": counts}}
@@ -422,35 +442,35 @@ def test_split_for_reidentification_gives_error(self):
         with self.subTest("wrong ratio"):
             with self.assertRaisesRegex(Exception, "in the range"):
                 splits = [("train", -0.5), ("val", 0.2), ("test", 0.3)]
-                splitter.ReidentificationSplit(source, splits, query)
+                splitter.Split(source, task, splits, query)

             with self.assertRaisesRegex(Exception, "Sum of ratios"):
                 splits = [("train", 0.6), ("val", 0.2), ("test", 0.3)]
-                splitter.ReidentificationSplit(source, splits, query)
+                splitter.Split(source, task, splits, query)

             with self.assertRaisesRegex(Exception, "in the range"):
                 splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
-                actual = splitter.ReidentificationSplit(source, splits, -query)
+                actual = splitter.Split(source, task, splits, -query)

         with self.subTest("duplicated subset name"):
             with self.assertRaisesRegex(Exception, "duplicated"):
                 splits = [("train", 0.5), ("train", 0.2), ("test", 0.3)]
-                splitter.ReidentificationSplit(source, splits, query)
+                splitter.Split(source, task, splits, query)

         with self.subTest("wrong subset name"):
             with self.assertRaisesRegex(Exception, "Subset name"):
                 splits = [("_train", 0.5), ("val", 0.2), ("test", 0.3)]
-                splitter.ReidentificationSplit(source, splits, query)
+                splitter.Split(source, task, splits, query)

         with self.subTest("wrong attribute name for person id"):
             splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
-            actual = splitter.ReidentificationSplit(source, splits, query)
+            actual = splitter.Split(source, task, splits, query)

             with self.assertRaisesRegex(Exception, "Unknown subset"):
                 actual.get_subset("test")

-    def _generate_detection_dataset(self, **kwargs):
-        append_bbox = 
kwargs.get("append_bbox") + def _generate_detection_segmentation_dataset(self, **kwargs): + annotation_type = kwargs.get("annotation_type") with_attr = kwargs.get("with_attr", False) nimages = kwargs.get("nimages", 10) @@ -479,10 +499,18 @@ def _generate_detection_dataset(self, **kwargs): attributes["attr0"] = attr_val % 3 attributes["attr%d" % (label_id + 1)] = attr_val % 2 for ann_id in range(count): - append_bbox(annotations, label_id=label_id, ann_id=ann_id, - attributes=attributes) - item = DatasetItem(img_id, subset=self._get_subset(img_id), - annotations=annotations, attributes={"id": img_id}) + annotation_type( + annotations, + label_id=label_id, + ann_id=ann_id, + attributes=attributes, + ) + item = DatasetItem( + img_id, + subset=self._get_subset(img_id), + annotations=annotations, + attributes={"id": img_id}, + ) iterable.append(item) dataset = Dataset.from_iterable(iterable, categories) @@ -492,7 +520,12 @@ def _generate_detection_dataset(self, **kwargs): def _get_append_bbox(dataset_type): def append_bbox_coco(annotations, **kwargs): annotations.append( - Bbox(1, 1, 2, 2, label=kwargs["label_id"], + Bbox( + 1, + 1, + 2, + 2, + label=kwargs["label_id"], id=kwargs["ann_id"], attributes=kwargs["attributes"], group=kwargs["ann_id"], @@ -504,7 +537,12 @@ def append_bbox_coco(annotations, **kwargs): def append_bbox_voc(annotations, **kwargs): annotations.append( - Bbox(1, 1, 2, 2, label=kwargs["label_id"], + Bbox( + 1, + 1, + 2, + 2, + label=kwargs["label_id"], id=kwargs["ann_id"] + 1, attributes=kwargs["attributes"], group=kwargs["ann_id"], @@ -514,7 +552,12 @@ def append_bbox_voc(annotations, **kwargs): Label(kwargs["label_id"], attributes=kwargs["attributes"]) ) annotations.append( - Bbox(1, 1, 2, 2, label=kwargs["label_id"] + 3, + Bbox( + 1, + 1, + 2, + 2, + label=kwargs["label_id"] + 3, group=kwargs["ann_id"], ) ) # part @@ -530,7 +573,12 @@ def append_bbox_yolo(annotations, **kwargs): def append_bbox_cvat(annotations, **kwargs): annotations.append( - Bbox(1, 1, 2, 2, label=kwargs["label_id"], + Bbox( + 1, + 1, + 2, + 2, + label=kwargs["label_id"], id=kwargs["ann_id"], attributes=kwargs["attributes"], group=kwargs["ann_id"], @@ -543,7 +591,12 @@ def append_bbox_cvat(annotations, **kwargs): def append_bbox_labelme(annotations, **kwargs): annotations.append( - Bbox(1, 1, 2, 2, label=kwargs["label_id"], + Bbox( + 1, + 1, + 2, + 2, + label=kwargs["label_id"], id=kwargs["ann_id"], attributes=kwargs["attributes"], ) @@ -554,7 +607,12 @@ def append_bbox_labelme(annotations, **kwargs): def append_bbox_mot(annotations, **kwargs): annotations.append( - Bbox(1, 1, 2, 2, label=kwargs["label_id"], + Bbox( + 1, + 1, + 2, + 2, + label=kwargs["label_id"], attributes=kwargs["attributes"], ) ) @@ -563,9 +621,7 @@ def append_bbox_mot(annotations, **kwargs): ) def append_bbox_widerface(annotations, **kwargs): - annotations.append( - Bbox(1, 1, 2, 2, attributes=kwargs["attributes"]) - ) + annotations.append(Bbox(1, 1, 2, 2, attributes=kwargs["attributes"])) annotations.append(Label(0, attributes=kwargs["attributes"])) functions = { @@ -581,8 +637,169 @@ def append_bbox_widerface(annotations, **kwargs): func = functions.get(dataset_type, append_bbox_cvat) return func + @staticmethod + def _get_append_mask(dataset_type): + def append_mask_coco(annotations, **kwargs): + annotations.append( + Mask( + np.array([[0, 0, 0, 1, 0]]), + label=kwargs["label_id"], + id=kwargs["ann_id"], + attributes=kwargs["attributes"], + group=kwargs["ann_id"], + ) + ) + annotations.append( + Label(kwargs["label_id"], 
attributes=kwargs["attributes"]) + ) + + def append_mask_voc(annotations, **kwargs): + annotations.append( + Mask( + np.array([[0, 0, 0, 1, 0]]), + label=kwargs["label_id"], + id=kwargs["ann_id"] + 1, + attributes=kwargs["attributes"], + group=kwargs["ann_id"], + ) + ) # obj + annotations.append( + Label(kwargs["label_id"], attributes=kwargs["attributes"]) + ) + annotations.append( + Mask( + np.array([[0, 0, 0, 1, 0]]), + label=kwargs["label_id"] + 3, + group=kwargs["ann_id"], + ) + ) # part + annotations.append( + Label(kwargs["label_id"] + 3, attributes=kwargs["attributes"]) + ) + + def append_mask_labelme(annotations, **kwargs): + annotations.append( + Mask( + np.array([[0, 0, 0, 1, 0]]), + label=kwargs["label_id"], + id=kwargs["ann_id"], + attributes=kwargs["attributes"], + ) + ) + annotations.append( + Label(kwargs["label_id"], attributes=kwargs["attributes"]) + ) + + def append_mask_mot(annotations, **kwargs): + annotations.append( + Mask( + np.array([[0, 0, 0, 1, 0]]), + label=kwargs["label_id"], + attributes=kwargs["attributes"], + ) + ) + annotations.append( + Label(kwargs["label_id"], attributes=kwargs["attributes"]) + ) + + functions = { + "coco": append_mask_coco, + "voc": append_mask_voc, + "labelme": append_mask_labelme, + "mot": append_mask_mot, + } + + func = functions.get(dataset_type, append_mask_coco) + return func + + @staticmethod + def _get_append_polygon(dataset_type): + def append_polygon_coco(annotations, **kwargs): + annotations.append( + Polygon( + [0, 0, 1, 0, 1, 2, 0, 2], + label=kwargs["label_id"], + id=kwargs["ann_id"], + attributes=kwargs["attributes"], + group=kwargs["ann_id"], + ) + ) + annotations.append( + Label(kwargs["label_id"], attributes=kwargs["attributes"]) + ) + + def append_polygon_voc(annotations, **kwargs): + annotations.append( + Polygon( + [0, 0, 1, 0, 1, 2, 0, 2], + label=kwargs["label_id"], + id=kwargs["ann_id"] + 1, + attributes=kwargs["attributes"], + group=kwargs["ann_id"], + ) + ) # obj + annotations.append( + Label(kwargs["label_id"], attributes=kwargs["attributes"]) + ) + annotations.append( + Polygon( + [0, 0, 1, 0, 1, 2, 0, 2], + label=kwargs["label_id"] + 3, + group=kwargs["ann_id"], + ) + ) # part + annotations.append( + Label(kwargs["label_id"] + 3, attributes=kwargs["attributes"]) + ) + + def append_polygon_yolo(annotations, **kwargs): + annotations.append(Bbox(1, 1, 2, 2, label=kwargs["label_id"])) + annotations.append( + Label(kwargs["label_id"], attributes=kwargs["attributes"]) + ) + + def append_polygon_cvat(annotations, **kwargs): + annotations.append( + Polygon( + [0, 0, 1, 0, 1, 2, 0, 2], + label=kwargs["label_id"], + id=kwargs["ann_id"], + attributes=kwargs["attributes"], + group=kwargs["ann_id"], + z_order=kwargs["ann_id"], + ) + ) + annotations.append( + Label(kwargs["label_id"], attributes=kwargs["attributes"]) + ) + + def append_polygon_labelme(annotations, **kwargs): + annotations.append( + Polygon( + [0, 0, 1, 0, 1, 2, 0, 2], + label=kwargs["label_id"], + id=kwargs["ann_id"], + attributes=kwargs["attributes"], + ) + ) + annotations.append( + Label(kwargs["label_id"], attributes=kwargs["attributes"]) + ) + + functions = { + "coco": append_polygon_coco, + "voc": append_polygon_voc, + "yolo": append_polygon_yolo, + "cvat": append_polygon_cvat, + "labelme": append_polygon_labelme, + } + + func = functions.get(dataset_type, append_polygon_coco) + return func + def test_split_for_detection(self): dtypes = ["coco", "voc", "yolo", "cvat", "labelme", "mot", "widerface"] + task = splitter.SplitTask.detection.name params = 
[] for dtype in dtypes: for with_attr in [False, True]: @@ -590,8 +807,8 @@ def test_split_for_detection(self): params.append((dtype, with_attr, 10, 7, 0, 3)) for dtype, with_attr, nimages, train, val, test in params: - source, _ = self._generate_detection_dataset( - append_bbox=self._get_append_bbox(dtype), + source, _ = self._generate_detection_segmentation_dataset( + annotation_type=self._get_append_bbox(dtype), with_attr=with_attr, nimages=nimages, ) @@ -608,34 +825,31 @@ def test_split_for_detection(self): train=train, val=val, test=test, + task=task, ): - actual = splitter.DetectionSplit(source, splits) + actual = splitter.Split(source, task, splits) self.assertEqual(train, len(actual.get_subset("train"))) self.assertEqual(val, len(actual.get_subset("val"))) self.assertEqual(test, len(actual.get_subset("test"))) # random seed test - source, _ = self._generate_detection_dataset( - append_bbox=self._get_append_bbox("cvat"), + source, _ = self._generate_detection_segmentation_dataset( + annotation_type=self._get_append_bbox("cvat"), with_attr=True, nimages=10, ) splits = [("train", 0.5), ("test", 0.5)] - r1 = splitter.DetectionSplit(source, splits, seed=1234) - r2 = splitter.DetectionSplit(source, splits, seed=1234) - r3 = splitter.DetectionSplit(source, splits, seed=4321) - self.assertEqual( - list(r1.get_subset("test")), list(r2.get_subset("test")) - ) - self.assertNotEqual( - list(r1.get_subset("test")), list(r3.get_subset("test")) - ) + r1 = splitter.Split(source, task, splits, seed=1234) + r2 = splitter.Split(source, task, splits, seed=1234) + r3 = splitter.Split(source, task, splits, seed=4321) + self.assertEqual(list(r1.get_subset("test")), list(r2.get_subset("test"))) + self.assertNotEqual(list(r1.get_subset("test")), list(r3.get_subset("test"))) def test_split_for_detection_with_unlabeled(self): - source, _ = self._generate_detection_dataset( - append_bbox=self._get_append_bbox("cvat"), + source, _ = self._generate_detection_segmentation_dataset( + annotation_type=self._get_append_bbox("cvat"), with_attr=True, nimages=10, ) @@ -643,42 +857,48 @@ def test_split_for_detection_with_unlabeled(self): source.put(DatasetItem(i + 10, annotations={})) splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] - actual = splitter.DetectionSplit(source, splits) + task = splitter.SplitTask.detection.name + actual = splitter.Split(source, task, splits) self.assertEqual(10, len(actual.get_subset("train"))) self.assertEqual(4, len(actual.get_subset("val"))) self.assertEqual(6, len(actual.get_subset("test"))) def test_split_for_detection_gives_error(self): - source, _ = self._generate_detection_dataset( - append_bbox=self._get_append_bbox("cvat"), + source, _ = self._generate_detection_segmentation_dataset( + annotation_type=self._get_append_bbox("cvat"), with_attr=True, nimages=5, ) + task = splitter.SplitTask.detection.name with self.subTest("wrong ratio"): with self.assertRaisesRegex(Exception, "in the range"): splits = [("train", -0.5), ("test", 1.5)] - splitter.DetectionSplit(source, splits) + splitter.Split(source, task, splits) with self.assertRaisesRegex(Exception, "Sum of ratios"): splits = [("train", 0.5), ("test", 0.5), ("val", 0.5)] - splitter.DetectionSplit(source, splits) + splitter.Split(source, task, splits) with self.subTest("duplicated subset name"): with self.assertRaisesRegex(Exception, "duplicated"): splits = [("train", 0.5), ("train", 0.2), ("test", 0.3)] - splitter.DetectionSplit(source, splits) + splitter.Split(source, task, splits) def 
test_no_subset_name_and_count_restriction(self):
-        splits = [("_train", 0.5), ("valid", 0.1), ("valid2", 0.1),
-            ("test*", 0.2), ("test2", 0.1)]
+        splits = [
+            ("_train", 0.5),
+            ("valid", 0.1),
+            ("valid2", 0.1),
+            ("test*", 0.2),
+            ("test2", 0.1),
+        ]

         with self.subTest("classification"):
-            config = {
-                "label1": {"attrs": None, "counts": 10}
-            }
+            config = {"label1": {"attrs": None, "counts": 10}}
+            task = splitter.SplitTask.classification.name
             source = self._generate_dataset(config)
-            actual = splitter.ClassificationSplit(source, splits)
+            actual = splitter.Split(source, task, splits)
             self.assertEqual(5, len(actual.get_subset("_train")))
             self.assertEqual(1, len(actual.get_subset("valid")))
             self.assertEqual(1, len(actual.get_subset("valid2")))
@@ -686,14 +906,227 @@ def test_no_subset_name_and_count_restriction(self):
             self.assertEqual(1, len(actual.get_subset("test2")))

         with self.subTest("detection"):
-            source, _ = self._generate_detection_dataset(
-                append_bbox=self._get_append_bbox("cvat"),
+            source, _ = self._generate_detection_segmentation_dataset(
+                annotation_type=self._get_append_bbox("cvat"),
+                with_attr=True,
+                nimages=10,
+            )
+            task = splitter.SplitTask.detection.name
+            actual = splitter.Split(source, task, splits)
+            self.assertEqual(5, len(actual.get_subset("_train")))
+            self.assertEqual(1, len(actual.get_subset("valid")))
+            self.assertEqual(1, len(actual.get_subset("valid2")))
+            self.assertEqual(2, len(actual.get_subset("test*")))
+            self.assertEqual(1, len(actual.get_subset("test2")))
+
+        with self.subTest("segmentation"):
+            source, _ = self._generate_detection_segmentation_dataset(
+                annotation_type=self._get_append_mask("coco"),
+                with_attr=True,
+                nimages=10,
+            )
+            task = splitter.SplitTask.segmentation.name
+            actual = splitter.Split(source, task, splits)
+            self.assertEqual(5, len(actual.get_subset("_train")))
+            self.assertEqual(1, len(actual.get_subset("valid")))
+            self.assertEqual(1, len(actual.get_subset("valid2")))
+            self.assertEqual(2, len(actual.get_subset("test*")))
+            self.assertEqual(1, len(actual.get_subset("test2")))
+
+            source, _ = self._generate_detection_segmentation_dataset(
+                annotation_type=self._get_append_polygon("coco"),
                 with_attr=True,
                 nimages=10,
             )
-            actual = splitter.DetectionSplit(source, splits)
+            actual = splitter.Split(source, task, splits)
             self.assertEqual(5, len(actual.get_subset("_train")))
             self.assertEqual(1, len(actual.get_subset("valid")))
             self.assertEqual(1, len(actual.get_subset("valid2")))
             self.assertEqual(2, len(actual.get_subset("test*")))
             self.assertEqual(1, len(actual.get_subset("test2")))
+
+    def test_split_for_segmentation(self):
+
+        with self.subTest("mask annotation"):
+            dtypes = ["coco", "voc", "labelme", "mot"]
+            task = splitter.SplitTask.segmentation.name
+            params = []
+            for dtype in dtypes:
+                for with_attr in [False, True]:
+                    params.append((dtype, with_attr, 10, 5, 3, 2))
+                    params.append((dtype, with_attr, 10, 7, 0, 3))
+
+            for dtype, with_attr, nimages, train, val, test in params:
+                source, _ = self._generate_detection_segmentation_dataset(
+                    annotation_type=self._get_append_mask(dtype),
+                    with_attr=with_attr,
+                    nimages=nimages,
+                )
+                total = np.sum([train, val, test])
+                splits = [
+                    ("train", train / total),
+                    ("val", val / total),
+                    ("test", test / total),
+                ]
+                with self.subTest(
+                    dtype=dtype,
+                    with_attr=with_attr,
+                    nimage=nimages,
+                    train=train,
+                    val=val,
+                    test=test,
+                    task=task,
+                ):
+                    actual = splitter.Split(source, task, splits)
+
+                    self.assertEqual(train, len(actual.get_subset("train")))
+                    self.assertEqual(val, 
len(actual.get_subset("val"))) + self.assertEqual(test, len(actual.get_subset("test"))) + + # random seed test + source, _ = self._generate_detection_segmentation_dataset( + annotation_type=self._get_append_mask("coco"), + with_attr=True, + nimages=10, + ) + + splits = [("train", 0.5), ("test", 0.5)] + r1 = splitter.Split(source, task, splits, seed=1234) + r2 = splitter.Split(source, task, splits, seed=1234) + r3 = splitter.Split(source, task, splits, seed=4321) + self.assertEqual(list(r1.get_subset("test")), list(r2.get_subset("test"))) + self.assertNotEqual( + list(r1.get_subset("test")), list(r3.get_subset("test")) + ) + + with self.subTest("polygon annotation"): + dtypes = ["coco", "voc", "labelme", "yolo", "cvat"] + task = splitter.SplitTask.segmentation.name + params = [] + for dtype in dtypes: + for with_attr in [False, True]: + params.append((dtype, with_attr, 10, 5, 3, 2)) + params.append((dtype, with_attr, 10, 7, 0, 3)) + + for dtype, with_attr, nimages, train, val, test in params: + source, _ = self._generate_detection_segmentation_dataset( + annotation_type=self._get_append_polygon(dtype), + with_attr=with_attr, + nimages=nimages, + ) + total = np.sum([train, val, test]) + splits = [ + ("train", train / total), + ("val", val / total), + ("test", test / total), + ] + with self.subTest( + dtype=dtype, + with_attr=with_attr, + nimage=nimages, + train=train, + val=val, + test=test, + task=task, + ): + actual = splitter.Split(source, task, splits) + + self.assertEqual(train, len(actual.get_subset("train"))) + self.assertEqual(val, len(actual.get_subset("val"))) + self.assertEqual(test, len(actual.get_subset("test"))) + + # random seed test + source, _ = self._generate_detection_segmentation_dataset( + annotation_type=self._get_append_polygon("coco"), + with_attr=True, + nimages=10, + ) + + splits = [("train", 0.5), ("test", 0.5)] + r1 = splitter.Split(source, task, splits, seed=1234) + r2 = splitter.Split(source, task, splits, seed=1234) + r3 = splitter.Split(source, task, splits, seed=4321) + self.assertEqual(list(r1.get_subset("test")), list(r2.get_subset("test"))) + self.assertNotEqual( + list(r1.get_subset("test")), list(r3.get_subset("test")) + ) + + def test_split_for_segmentation_with_unlabeled(self): + + with self.subTest("mask annotation"): + source, _ = self._generate_detection_segmentation_dataset( + annotation_type=self._get_append_mask("coco"), + with_attr=True, + nimages=10, + ) + for i in range(10): + source.put(DatasetItem(i + 10, annotations={})) + + splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] + task = splitter.SplitTask.segmentation.name + actual = splitter.Split(source, task, splits) + self.assertEqual(10, len(actual.get_subset("train"))) + self.assertEqual(4, len(actual.get_subset("val"))) + self.assertEqual(6, len(actual.get_subset("test"))) + + with self.subTest("polygon annotation"): + source, _ = self._generate_detection_segmentation_dataset( + annotation_type=self._get_append_polygon("coco"), + with_attr=True, + nimages=10, + ) + for i in range(10): + source.put(DatasetItem(i + 10, annotations={})) + + splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)] + task = splitter.SplitTask.segmentation.name + actual = splitter.Split(source, task, splits) + self.assertEqual(10, len(actual.get_subset("train"))) + self.assertEqual(4, len(actual.get_subset("val"))) + self.assertEqual(6, len(actual.get_subset("test"))) + + def test_split_for_segmentation_gives_error(self): + + with self.subTest("mask annotation"): + source, _ = 
self._generate_detection_segmentation_dataset( + annotation_type=self._get_append_mask("coco"), + with_attr=True, + nimages=5, + ) + task = splitter.SplitTask.segmentation.name + + with self.subTest("wrong ratio"): + with self.assertRaisesRegex(Exception, "in the range"): + splits = [("train", -0.5), ("test", 1.5)] + splitter.Split(source, task, splits) + + with self.assertRaisesRegex(Exception, "Sum of ratios"): + splits = [("train", 0.5), ("test", 0.5), ("val", 0.5)] + splitter.Split(source, task, splits) + + with self.subTest("duplicated subset name"): + with self.assertRaisesRegex(Exception, "duplicated"): + splits = [("train", 0.5), ("train", 0.2), ("test", 0.3)] + splitter.Split(source, task, splits) + + with self.subTest("polygon annotation"): + source, _ = self._generate_detection_segmentation_dataset( + annotation_type=self._get_append_polygon("coco"), + with_attr=True, + nimages=5, + ) + task = splitter.SplitTask.segmentation.name + + with self.subTest("wrong ratio"): + with self.assertRaisesRegex(Exception, "in the range"): + splits = [("train", -0.5), ("test", 1.5)] + splitter.Split(source, task, splits) + + with self.assertRaisesRegex(Exception, "Sum of ratios"): + splits = [("train", 0.5), ("test", 0.5), ("val", 0.5)] + splitter.Split(source, task, splits) + + with self.subTest("duplicated subset name"): + with self.assertRaisesRegex(Exception, "duplicated"): + splits = [("train", 0.5), ("train", 0.2), ("test", 0.3)] + splitter.Split(source, task, splits)
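Reviewer note: for anyone trying the new API outside the CLI, here is a minimal usage sketch based on the tests in this patch. The dataset built below is illustrative (two hypothetical labels "a"/"b" and trivial 1x4 masks); only `SplitTask`, `splitter.Split(...)`, and `get_subset()` come from this change.

```python
import numpy as np

from datumaro.components.project import Dataset
from datumaro.components.extractor import DatasetItem, Mask
import datumaro.plugins.splitter as splitter

# Illustrative source dataset: 10 items, one mask annotation each (label 0 or 1).
items = [
    DatasetItem(
        i,
        annotations=[Mask(np.array([[0, 1, 1, 0]]), label=i % 2)],
        image=np.ones((1, 4, 3)),
    )
    for i in range(10)
]
source = Dataset.from_iterable(items, categories=["a", "b"])

# Segmentation-aware split in a 50/20/30 ratio; the seed makes it repeatable.
task = splitter.SplitTask.segmentation.name
splits = [("train", 0.5), ("val", 0.2), ("test", 0.3)]
actual = splitter.Split(source, task, splits, seed=1234)

for name in ("train", "val", "test"):
    print(name, len(actual.get_subset(name)))
```

This mirrors the call pattern the tests use (`splitter.Split(source, task, splits)`); for the `reid` task, additionally pass `query=` and optionally `attr_for_id=`, as described in the `Split` class docstring.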