diff --git a/test/test_datasets.py b/test/test_datasets.py
index 1dc6892c318..f37cf291829 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -782,32 +782,46 @@ def inject_fake_data(self, tmpdir, config):
         annotation_folder = tmpdir / self._ANNOTATIONS_FOLDER
         os.makedirs(annotation_folder)
+
+        segmentation_kind = config.pop("segmentation_kind", "list")
         info = self._create_annotation_file(
-            annotation_folder, self._ANNOTATIONS_FILE, file_names, num_annotations_per_image
+            annotation_folder,
+            self._ANNOTATIONS_FILE,
+            file_names,
+            num_annotations_per_image,
+            segmentation_kind=segmentation_kind,
         )
 
         info["num_examples"] = num_images
         return info
 
-    def _create_annotation_file(self, root, name, file_names, num_annotations_per_image):
+    def _create_annotation_file(self, root, name, file_names, num_annotations_per_image, segmentation_kind="list"):
         image_ids = [int(file_name.stem) for file_name in file_names]
         images = [dict(file_name=str(file_name), id=id) for file_name, id in zip(file_names, image_ids)]
-        annotations, info = self._create_annotations(image_ids, num_annotations_per_image)
+        annotations, info = self._create_annotations(image_ids, num_annotations_per_image, segmentation_kind)
         self._create_json(root, name, dict(images=images, annotations=annotations))
         return info
 
-    def _create_annotations(self, image_ids, num_annotations_per_image):
+    def _create_annotations(self, image_ids, num_annotations_per_image, segmentation_kind="list"):
         annotations = []
         annotion_id = 0
+
         for image_id in itertools.islice(itertools.cycle(image_ids), len(image_ids) * num_annotations_per_image):
+            segmentation = {
+                "list": [torch.rand(8).tolist()],
+                "rle": {"size": [10, 10], "counts": [1]},
+                "rle_encoded": {"size": [2400, 2400], "counts": "PQRQ2[1\\Y2f0gNVNRhMg2"},
+                "bad": 123,
+            }[segmentation_kind]
+
             annotations.append(
                 dict(
                     image_id=image_id,
                     id=annotion_id,
                     bbox=torch.rand(4).tolist(),
-                    segmentation=[torch.rand(8).tolist()],
+                    segmentation=segmentation,
                     category_id=int(torch.randint(91, ())),
                     area=float(torch.rand(1)),
                     iscrowd=int(torch.randint(2, size=(1,))),
@@ -832,11 +846,27 @@ def test_slice_error(self):
         with pytest.raises(ValueError, match="Index must be of type integer"):
             dataset[:2]
 
+    def test_segmentation_kind(self):
+        if isinstance(self, CocoCaptionsTestCase):
+            return
+
+        for segmentation_kind in ("list", "rle", "rle_encoded"):
+            config = {"segmentation_kind": segmentation_kind}
+            with self.create_dataset(config) as (dataset, _):
+                dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys="all")
+                list(dataset)
+
+        config = {"segmentation_kind": "bad"}
+        with self.create_dataset(config) as (dataset, _):
+            dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys="all")
+            with pytest.raises(ValueError, match="COCO segmentation expected to be a dict or a list"):
+                list(dataset)
+
 
 class CocoCaptionsTestCase(CocoDetectionTestCase):
     DATASET_CLASS = datasets.CocoCaptions
 
-    def _create_annotations(self, image_ids, num_annotations_per_image):
+    def _create_annotations(self, image_ids, num_annotations_per_image, segmentation_kind="list"):
         captions = [str(idx) for idx in range(num_annotations_per_image)]
         annotations = combinations_grid(image_id=image_ids, caption=captions)
         for id, annotation in enumerate(annotations):
diff --git a/torchvision/tv_tensors/_dataset_wrapper.py b/torchvision/tv_tensors/_dataset_wrapper.py
index dcdb128aa76..23683221f60 100644
--- a/torchvision/tv_tensors/_dataset_wrapper.py
+++ b/torchvision/tv_tensors/_dataset_wrapper.py
@@ -359,11 +359,14 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
     def segmentation_to_mask(segmentation, *, canvas_size):
         from pycocotools import mask
 
-        segmentation = (
-            mask.frPyObjects(segmentation, *canvas_size)
-            if isinstance(segmentation, dict)
-            else mask.merge(mask.frPyObjects(segmentation, *canvas_size))
-        )
+        if isinstance(segmentation, dict):
+            # if counts is a string, it is already an encoded RLE mask
+            if not isinstance(segmentation["counts"], str):
+                segmentation = mask.frPyObjects(segmentation, *canvas_size)
+        elif isinstance(segmentation, list):
+            segmentation = mask.merge(mask.frPyObjects(segmentation, *canvas_size))
+        else:
+            raise ValueError(f"COCO segmentation expected to be a dict or a list, got {type(segmentation)}")
         return torch.from_numpy(mask.decode(segmentation))
 
     def wrapper(idx, sample):
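
For context on the `_dataset_wrapper.py` change: pycocotools accepts three segmentation encodings, and only the already-compressed RLE form can be passed to `mask.decode` as-is; polygons and uncompressed RLE must first go through `mask.frPyObjects`. Below is a minimal standalone sketch of the three cases, assuming pycocotools is installed; the polygon coordinates and run-length values are invented for illustration and are not the fixtures used in the tests above.

```python
# Sketch of the three COCO segmentation encodings the patched
# segmentation_to_mask distinguishes. Values are illustrative only.
from pycocotools import mask as coco_mask

h, w = 10, 10

# 1) Polygon list: one or more flat [x0, y0, x1, y1, ...] rings.
#    frPyObjects returns one RLE per polygon; merge combines them.
polygons = [[1.0, 1.0, 8.0, 1.0, 8.0, 8.0, 1.0, 8.0]]
merged = coco_mask.merge(coco_mask.frPyObjects(polygons, h, w))
assert coco_mask.decode(merged).shape == (h, w)

# 2) Uncompressed RLE: "counts" is a list of run lengths (starting with
#    the zero-run) summing to h * w. It must be compressed via
#    frPyObjects before it can be decoded.
uncompressed = {"size": [h, w], "counts": [50, 25, 25]}
encoded = coco_mask.frPyObjects(uncompressed, h, w)
assert not isinstance(encoded["counts"], list)  # now compressed

# 3) Encoded RLE: "counts" is a compressed (byte) string and can be
#    decoded directly; re-running it through frPyObjects is exactly
#    what the patch avoids.
decoded = coco_mask.decode(encoded)
assert decoded.shape == (h, w) and decoded.sum() == 25  # 25 on-pixels
```

Note that pycocotools itself produces `counts` as `bytes`, but annotations loaded from a COCO JSON file carry the encoded `counts` as `str`, which is why the patch detects the already-encoded case with `isinstance(segmentation["counts"], str)`.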